diff --git a/.gitattributes b/.gitattributes index 0b69d72b4b0e045b50df7dfef2b0165c49aeafd9..daedcf899931fe4e542ddfe9fb151b56070362a9 100644 --- a/.gitattributes +++ b/.gitattributes @@ -126,3 +126,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/torch/_export/error.py b/.venv/lib/python3.11/site-packages/torch/_export/error.py new file mode 100644 index 0000000000000000000000000000000000000000..03b7f52fb9de435b9e58fa4a0bb141cc191e84c5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/error.py @@ -0,0 +1,56 @@ +from enum import Enum + + +class ExportErrorType(Enum): + # User providing invalid inputs to either tracer, or other public facing APIs + INVALID_INPUT_TYPE = 1 + + # User returning values from their models that we don't support. + INVALID_OUTPUT_TYPE = 2 + + # Generated IR does not conform to Export IR Specification. + VIOLATION_OF_SPEC = 3 + + # User's code contains types and functionalities we don't support. + NOT_SUPPORTED = 4 + + # User's code didn't provide necessary details for us to successfully trace and export. + # For example, we use a lot of decorators and ask users to annotate their model. + MISSING_PROPERTY = 5 + + # User is using an API without proper initialization step. + UNINITIALIZED = 6 + + +def internal_assert(pred: bool, assert_msg: str) -> None: + """ + This is exir's custom assert method. It internally just throws InternalError. + Note that the sole purpose is to throw our own error while maintaining similar syntax + as python assert. + """ + + if not pred: + raise InternalError(assert_msg) + + +class InternalError(Exception): + """ + Raised when an internal invariance is violated in EXIR stack. + Should hint users to report a bug to dev and expose the original + error message. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class ExportError(Exception): + """ + This type of exception is raised for errors that are directly caused by the user + code. In general, user errors happen during model authoring, tracing, using our public + facing APIs, and writing graph passes. + """ + + def __init__(self, error_code: ExportErrorType, message: str) -> None: + prefix = f"[{error_code}]: " + super().__init__(prefix + message) diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py b/.venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/aoti_schema.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/aoti_schema.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b480a9b8a44b55de0ce9ad19e0b9f2e10dec176b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/aoti_schema.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b1472179d0f2ae3e2725190f979e0d4441cda77 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db1b961d34050493098db0874bd87e35d45a35ee Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f1a46dad61ab58c8facfa555568ed2f77ce5790 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0abc64c07a5198bb9f39918552deca523bc4e77 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc131857ed1d25d734bce65ed9c8acad8c38ffb2614c7fcf51f2cbfebac196a1 +size 164473 diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..362ee9c8281b4e5f71108216b1694c970d457709 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/aoti_schema.py b/.venv/lib/python3.11/site-packages/torch/_export/serde/aoti_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..17d5ceda0ef078136248256d202737da2b44b519 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/serde/aoti_schema.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from typing import List + +from torch._export.serde.schema import Node + + +@dataclass +class ExternKernelNode: + name: str + node: Node + + +@dataclass +class ExternKernelNodes: + nodes: List[ExternKernelNode] diff --git a/.venv/lib/python3.11/site-packages/torch/_export/serde/dynamic_shapes.py b/.venv/lib/python3.11/site-packages/torch/_export/serde/dynamic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..f24822d9b07d68cecfcb207b6b0c1b6829c3d3d8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/_export/serde/dynamic_shapes.py @@ -0,0 +1,321 @@ +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch._dynamo.exc import UserError, UserErrorType +from torch.export.dynamic_shapes import ( + _check_dynamic_shapes, + _DerivedDim, + _Dim, + _DimHint, + _tree_map_with_path, + Dim, +) +from torch.utils._pytree import tree_map + +from .serialize import _dataclass_to_dict + + +@dataclasses.dataclass +class RootDim: + """ + This represents a _Dim object. + """ + + min: int + max: Union[int, None] + derived: List[str] + + +@dataclasses.dataclass +class DynamicShapesSpec: + """ + This stores a dynamic_shapes spec for de/serialization. + """ + + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None] + dims: Dict[str, RootDim] + + +def _postprocess_serialized_shapes( + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + dims: Dict[str, Dict[str, Union[int, List[str], None]]], + to_dict: Optional[bool] = False, +) -> Union[DynamicShapesSpec, Dict[str, Any]]: + """ + Sorts dims and dumps to dictionary format. + """ + from torch.utils._sympy.numbers import int_oo + + dims = { + k: RootDim( + min=v["min"], # type: ignore[arg-type] + max=None if v["max"] is int_oo else v["max"], # type: ignore[arg-type] + derived=sorted(v["derived"]), # type: ignore[arg-type] + ) + for k, v in sorted(dims.items()) + } + spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims) + if to_dict: + return _dataclass_to_dict(spec) + else: + return spec + + +def _dump_dynamic_shapes( + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + to_dict: Optional[bool] = False, +) -> Union[DynamicShapesSpec, Dict[str, Any]]: + """ + Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec. + Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims". + Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones). + + dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export(): + - Each tensor input is represented with a list of values, non-tensor inputs with None. + - dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings. + - static dimensions are represented with ints. + + dims: A dictionary mapping each symbol name to the min/max range and derived dim names. + + For example: + ``` + dx = Dim("dx", min=4, max=16) + dy = dx + 1 + + inputs = ( + [ + torch.randn(4, 4), + torch.randn(5, 4), + ], + torch.randn(4), + torch.randn(4, 4), + "hello", + ) + dynamic_shapes = { + "a": [ + (dx, 4), + (dy, 4), + ], + "b": (Dim.STATIC,), + "c": None, + "d": None, + } + out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True) + ``` + would generate the following output: + ``` + { + 'dynamic_shapes': ( + [ + ['dx', 4], + ['dx + 1', 4], + ], + ['_DimHint.STATIC'], + ['_DimHint.STATIC', '_DimHint.STATIC'], + None, + ), + 'dims': { + 'dx': { + 'min': 4, + 'max': 16, + 'derived': ['dx + 1'], + }, + }, + } + ``` + """ + dims: Dict[str, Dict[str, Any]] = {} + + def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def] + """ + Helps standardize the dynamic_shapes tree structure we serialize, + returning lists for each tensor shape, handling tensor-level Nones. + """ + if not isinstance(tensor, torch.Tensor): + return None + if shape is None: + return [Dim.STATIC] * len(tensor.shape) # type: ignore[attr-defined] + + out = [] + if isinstance(shape, dict): + for i, s in enumerate(tensor.shape): + out.append(s if shape.get(i) is None else shape.get(i)) + else: + assert isinstance(shape, (tuple, list)) + for i, s in enumerate(tensor.shape): + out.append(s if shape[i] is None else shape[i]) + return out + + def _track_dim_from_dims( + val: Union[None, int, _DimHint, _Dim] + ) -> Union[None, int, str]: + """ + Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec. + """ + if val is None or isinstance(val, int): # non-tensor input or static + return val + if isinstance(val, _DimHint): # store enum as string + return val.__class__.__name__ + "." + val.name + + assert isinstance(val, _Dim) + + # track root dim + root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined] + if root.__name__ not in dims: + dims[root.__name__] = { + "min": root.min, + "max": root.max, + "derived": set(), + } + + # track derived dims + if isinstance(val, _DerivedDim): + dims[root.__name__]["derived"].add(val.__name__) + + return val.__name__ + + if dynamic_shapes is None: + return {"dynamic_shapes": None, "dims": {}} + + # convert to tuple of specs, for each arg/kwarg + kwargs = kwargs or {} + if isinstance(dynamic_shapes, dict): + dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment] + dynamic_shapes = tuple(dynamic_shapes) + combined_args = tuple(args) + tuple(kwargs.values()) + + # run same check when we're processing shapes for export - is this too lazy? + _check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes) # type: ignore[arg-type] + + tree_shapes = _tree_map_with_path( + _standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs" + ) + serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes) + return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict) + + +def _load_dynamic_shapes( + spec: Union[DynamicShapesSpec, Dict[str, Any]], + from_dict: Optional[bool] = False, +) -> Union[Dict[str, Any], Tuple[Any], List[Any], None]: + """ + Utility function for dynamic shapes serialization. + Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export(). + """ + import sympy + + from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence + + if from_dict: + if not isinstance(spec, dict): + raise UserError( + UserErrorType.INVALID_INPUT, + f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}", + ) + if sorted(spec.keys()) != ["dims", "dynamic_shapes"]: + raise UserError( + UserErrorType.INVALID_INPUT, + "With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, " + f"instead found {spec.keys()}", + ) + dims = {} + for k, v in spec["dims"].items(): + if not isinstance(k, str): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}", + ) + if sorted(v.keys()) != ["derived", "max", "min"]: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, " + f"instead found {v.keys()}", + ) + if not isinstance(v["min"], int): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}", + ) + if not isinstance(v["max"], int) or v["max"] is None: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}", + ) + if not isinstance(v["derived"], list) or any( + not isinstance(d, str) for d in v["derived"] + ): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, " + f"got {k}: {v['derived']}", + ) + dims[k] = RootDim(**v) + dynamic_shapes = spec["dynamic_shapes"] + else: + if not isinstance(spec, DynamicShapesSpec): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}", + ) + dims = spec.dims + dynamic_shapes = spec.dynamic_shapes + + if dynamic_shapes is None: + return None + + dim_cache = {} + for name, info in dims.items(): + symbol = sympy.sympify(name) + if not isinstance(symbol, sympy.Symbol): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be symbols, got {name}", + ) + dim_cache[name] = Dim(name, min=info.min, max=info.max) # cache root dim + for _expr in info.derived: + expr = sympy.sympify(_expr) + if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions in to have {name} as the only free symbol, got {expr}", + ) + if not _is_supported_equivalence(expr): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions to be linear expressions, got {expr}", + ) + modulus, remainder = sympy.polys.polytools.div(expr, symbol) + ddim = dim_cache[name] + if modulus != 1: + ddim = int(modulus) * ddim + if remainder != 0: + ddim = ddim + int(remainder) + dim_cache[_expr] = ddim # cache derived dims + + def deserialize_shape( + val: Union[None, int, str] + ) -> Union[None, int, _Dim, _DimHint]: + if val is None or isinstance(val, int): + return val + elif val == "_DimHint.AUTO": + return _DimHint.AUTO + elif val == "_DimHint.STATIC": + return _DimHint.STATIC + if not isinstance(val, str): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, " + f" or derived expressions, got {val}", + ) + if val not in dim_cache: + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, " + f"got {val} which is not in {dims.keys()}", + ) + return dim_cache[val] + + return tree_map(deserialize_shape, dynamic_shapes) diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59b3ef847f8d55334c981edc60f3f719ded21a35 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/closure.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/closure.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17b816bc4895eb59e61a95ba5f4b46500339c305 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/closure.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81624f6bd6334bf80a0526492c1248665787c194 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38a8dcafe836d85dad1ffb8152b5d78ce01cb88d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90bb3d139ac4c6a979fe680cf9faead1678bd5e0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb3ed8bb2b67362eb00bf2ba39824cbf2492c707 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e09a47537e36b424e71e07160f89450b7e7b3376 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c42d4ff756c05293039cdd9f3b1a5d3684898925 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbd162f1b42d58c6e0fc2b5950b4633350bb1067 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd04cdd09d7fa1a877e64d114db6e037849d7ba1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/__init__.py @@ -0,0 +1,89 @@ +r''' +FX is a toolkit for developers to use to transform ``nn.Module`` +instances. FX consists of three main components: a **symbolic tracer,** +an **intermediate representation**, and **Python code generation**. A +demonstration of these components in action: + +:: + + import torch + # Simple module for demonstration + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return self.linear(x + self.param).clamp(min=0.0, max=1.0) + + module = MyModule() + + from torch.fx import symbolic_trace + # Symbolic tracing frontend - captures the semantics of the module + symbolic_traced : torch.fx.GraphModule = symbolic_trace(module) + + # High-level intermediate representation (IR) - Graph representation + print(symbolic_traced.graph) + """ + graph(): + %x : [num_users=1] = placeholder[target=x] + %param : [num_users=1] = get_attr[target=param] + %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {}) + %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {}) + %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0}) + return clamp + """ + + # Code generation - valid Python code + print(symbolic_traced.code) + """ + def forward(self, x): + param = self.param + add = x + param; x = param = None + linear = self.linear(add); add = None + clamp = linear.clamp(min = 0.0, max = 1.0); linear = None + return clamp + """ + +The **symbolic tracer** performs "symbolic execution" of the Python +code. It feeds fake values, called Proxies, through the code. Operations +on theses Proxies are recorded. More information about symbolic tracing +can be found in the :func:`symbolic_trace` and :class:`Tracer` +documentation. + +The **intermediate representation** is the container for the operations +that were recorded during symbolic tracing. It consists of a list of +Nodes that represent function inputs, callsites (to functions, methods, +or :class:`torch.nn.Module` instances), and return values. More information +about the IR can be found in the documentation for :class:`Graph`. The +IR is the format on which transformations are applied. + +**Python code generation** is what makes FX a Python-to-Python (or +Module-to-Module) transformation toolkit. For each Graph IR, we can +create valid Python code matching the Graph's semantics. This +functionality is wrapped up in :class:`GraphModule`, which is a +:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a +``forward`` method generated from the Graph. + +Taken together, this pipeline of components (symbolic tracing -> +intermediate representation -> transforms -> Python code generation) +constitutes the Python-to-Python transformation pipeline of FX. In +addition, these components can be used separately. For example, +symbolic tracing can be used in isolation to capture a form of +the code for analysis (and not transformation) purposes. Code +generation can be used for programmatically generating models, for +example from a config file. There are many uses for FX! + +Several example transformations can be found at the +`examples `__ +repository. +''' + +from .graph_module import GraphModule +from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta +from .graph import Graph, CodeGen +from .node import Node, map_arg, has_side_effect +from .proxy import Proxy +from .interpreter import Interpreter as Interpreter, Transformer as Transformer +from .subgraph_rewriter import replace_pattern diff --git a/.venv/lib/python3.11/site-packages/torch/fx/__init__.pyi b/.venv/lib/python3.11/site-packages/torch/fx/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..0a263dfc5071ddee675ec517f2cbac13b51ce9e8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/__init__.pyi @@ -0,0 +1,15 @@ +from torch.fx._symbolic_trace import ( + symbolic_trace as symbolic_trace, + Tracer as Tracer, + wrap as wrap, +) +from torch.fx.graph import Graph as Graph +from torch.fx.graph_module import GraphModule as GraphModule +from torch.fx.interpreter import Interpreter as Interpreter, Transformer as Transformer +from torch.fx.node import ( + has_side_effect as has_side_effect, + map_arg as map_arg, + Node as Node, +) +from torch.fx.proxy import Proxy as Proxy +from torch.fx.subgraph_rewriter import replace_pattern as replace_pattern diff --git a/.venv/lib/python3.11/site-packages/torch/fx/_compatibility.py b/.venv/lib/python3.11/site-packages/torch/fx/_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..27c1e600036df46ced25e722ebf493c2131beda0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/_compatibility.py @@ -0,0 +1,36 @@ +from typing import Any, Dict, Callable, TypeVar +import textwrap + +_BACK_COMPAT_OBJECTS : Dict[Any, None] = {} +_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {} + +_T = TypeVar("_T") + +def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]: + if is_backward_compatible: + + def mark_back_compat(fn: _T) -> _T: + docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring += """ +.. note:: + Backwards-compatibility for this API is guaranteed. +""" + fn.__doc__ = docstring + _BACK_COMPAT_OBJECTS.setdefault(fn) + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_back_compat + else: + + def mark_not_back_compat(fn: _T) -> _T: + docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring += """ +.. warning:: + This API is experimental and is *NOT* backward-compatible. +""" + fn.__doc__ = docstring + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_not_back_compat diff --git a/.venv/lib/python3.11/site-packages/torch/fx/_lazy_graph_module.py b/.venv/lib/python3.11/site-packages/torch/fx/_lazy_graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..2a14fce3782e9a33a0d2396f2d57fe24d7730ef2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/_lazy_graph_module.py @@ -0,0 +1,185 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +from torch.fx import GraphModule +from torch.fx.graph_module import ( + _format_import_block, + reduce_graph_module, + reduce_package_graph_module, +) +from torch.package import PackageExporter, sys_importer + +from ._compatibility import compatibility + + +_use_lazy_graph_module_flag = False +_force_skip_lazy_graph_module_flag = False + + +@compatibility(is_backward_compatible=False) +@contextmanager +def _force_skip_lazy_graph_module(): + """ + Skip using lazy graph module disregarding the setting of _use_lazy_graph_module. + Use to skip _LazyGraphModule when testing inductor torchscript related backend. + + torch.jit.script a _LazyGraphModule results in following error: + https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69 + """ + try: + global _force_skip_lazy_graph_module_flag + prior = _force_skip_lazy_graph_module_flag + _force_skip_lazy_graph_module_flag = True + yield + finally: + _force_skip_lazy_graph_module_flag = prior + + +@compatibility(is_backward_compatible=False) +@contextmanager +def _use_lazy_graph_module(should_use: bool): + try: + global _use_lazy_graph_module_flag + prior = _use_lazy_graph_module_flag + _use_lazy_graph_module_flag = ( + should_use and not _force_skip_lazy_graph_module_flag + ) + yield + finally: + _use_lazy_graph_module_flag = prior + + +@compatibility(is_backward_compatible=False) +def _get_graph_module_cls(): + return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule + + +def _make_graph_module(*args, graph_module_cls=None, **kwargs): + if graph_module_cls is None: + graph_module_cls = _get_graph_module_cls() + + return graph_module_cls(*args, **kwargs) + + +@compatibility(is_backward_compatible=False) +class _LazyGraphModule(GraphModule): + """ + The main difference between _LazyGraphModule and GraphModule is how recompile happens. + GraphModule will do a 'recompile' call to generate python code and the forward method when it's + constructed. Later on if the graph get updated, recompile method can be called again to refresh + the saved python code and forward method. + + However in some cases especially in inductor, the recompilation can be a waste since we never + check the python code for the graph module or call its forward method. A few more concreate + examples regarding pattern matching fx passes in inductor: + 1. some passes will update the graph to be compiled and then call recompile on the GraphModule. + 2. some passes will trace small pattern function to search it in the graph being compiled and + replace the match with the traced graph of a replacement function. The pattern graph and + replacement graph are quite small but there are large amount of them. Doing GraphModule.recompile + for them in GraphModule.__init__ is also a waste of time. + + However simply skip calling GraphModule.recompile in these scenarios is also dangeruous. + People may want to check the python code or call the GraphModule's forward method for debugging purposes. + + The way _LazyGraphModule solves it is, we override the recompile method to just mark the + need for recompilation but does not do the actual recompilation. Later on if people really + access the compiled python code or call the GraphModule's forward method, we do the real + recompilation. + """ + + @classmethod + def from_graphmodule(cls, gm: GraphModule): + if isinstance(gm, _LazyGraphModule): + return gm + else: + return _LazyGraphModule(gm, gm.graph) + + @staticmethod + def force_recompile(gm): + """ + Sometimes we need force a recompile as a workaround + - we want to do the real recompilation before symbolic_trace to avoid error: + https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 + """ + if isinstance(gm, _LazyGraphModule): + gm.real_recompile() + + def real_recompile(self): + if self._needs_recompile(): + self._real_recompile() + + @classmethod + def _needs_recompile(cls): + return cls.forward is cls._lazy_forward + + def _lazy_forward(self, *args, **kwargs): + # Call self.real_recompile() rather than self._real_recompile() here. + # The _lazy_forward method may be saved and call repeatedly. + # Calling self.real_recompile can make sure we skip recompilation if + # we have already done so. + self.real_recompile() + assert not self._needs_recompile() + + # call `__call__` rather than 'forward' since recompilation may + # install a wrapper for `__call__` to provide a customized error + # message. + return self(*args, **kwargs) + + forward = _lazy_forward + + # TODO: we shold handle __reduce_deploy__ the same way as __reduce_package__, + # or __reduce__ by calling _real_recompile. But I don't find a good way + # to test __reduce_deploy__ out. Also it's very unlikely that LazyGraphModule + # will be used in torch::deploy. So it's skipped for now. + + def __reduce_package__(self, exporter: PackageExporter): + """ + Follow GraphModule.__reduce__ but call 'self._real_recompile' rather + than 'self.recompile' since for a _LazyGraphModule, self.recompile just + mark the need of recompilation and does not return the PythonCode object. + """ + python_code = self._real_recompile() + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + generated_module_name = f"fx-generated._{exporter.get_unique_id()}" + import_block = _format_import_block(python_code.globals, exporter.importer) + module_code = import_block + self.code + exporter.save_source_string(generated_module_name, module_code) + return ( + reduce_package_graph_module, + (dict_without_graph, generated_module_name), + ) + + def __reduce__(self): + """ + Follow GraphModule.__reduce__ but call 'self._real_recompile' rather + than 'self.recompile' since for a _LazyGraphModule, self.recompile just + mark the need of recompilation and does not return the PythonCode object. + """ + python_code = self._real_recompile() + dict_without_graph = self.__dict__.copy() + import_block = _format_import_block(python_code.globals, sys_importer) + del dict_without_graph["_graph"] + return (reduce_graph_module, (dict_without_graph, import_block)) + + def _real_recompile(self): + return super().recompile() + + @classmethod + def recompile(cls): + cls.forward = cls._lazy_forward + + @property + def code(self) -> str: + self.real_recompile() + return super().code + + def __str__(self) -> str: + """ + str(GraphModule) will access the _code attribute. Make sure recompile + happens so _code attribute is available. + """ + self.real_recompile() + return super().__str__() diff --git a/.venv/lib/python3.11/site-packages/torch/fx/_pytree.py b/.venv/lib/python3.11/site-packages/torch/fx/_pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..2ccbfdf048c943e811cd7d9fc092a81d2c30bcea --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/_pytree.py @@ -0,0 +1,103 @@ +# mypy: allow-untyped-defs +from collections import namedtuple +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type + +import torch.return_types +from torch.utils._pytree import PyTree, TreeSpec + + +FlattenFuncSpec = Callable[[PyTree, TreeSpec], List] +FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool] + +SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {} +SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {} + + +def register_pytree_flatten_spec( + cls: Type[Any], + flatten_fn_spec: FlattenFuncSpec, + flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None, +) -> None: + SUPPORTED_NODES[cls] = flatten_fn_spec + SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec + + +def tree_flatten_spec( + pytree: PyTree, + spec: TreeSpec, + exact_structural_match=False, +) -> List[Any]: + if spec.is_leaf(): + return [pytree] + if spec.type not in SUPPORTED_NODES: + raise RuntimeError( + f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with " + "torch.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make " + "sure that any custom pytrees have been registered before loading it.", + ) + flatten_fn_spec = SUPPORTED_NODES[spec.type] + child_pytrees = flatten_fn_spec(pytree, spec) + if exact_structural_match: + flatten_fn_exact_match_spec = SUPPORTED_NODES_EXACT_MATCH[spec.type] + if flatten_fn_exact_match_spec and not flatten_fn_exact_match_spec( + pytree, + spec, + ): + raise RuntimeError(f"Cannot flatten pytree {pytree}, given spec: {spec}") + result = [] + for child, child_spec in zip(child_pytrees, spec.children_specs): + flat = tree_flatten_spec(child, child_spec, exact_structural_match) + result += flat + return result + + +def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]: + return [d[k] for k in spec.context] + + +def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match) +register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match) +register_pytree_flatten_spec( + tuple, + _tuple_flatten_spec, + _tuple_flatten_spec_exact_match, +) +for return_type in torch.return_types.all_return_types: + register_pytree_flatten_spec( + return_type, + _tuple_flatten_spec, + _tuple_flatten_spec_exact_match, + ) +register_pytree_flatten_spec( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten_spec, + _namedtuple_flatten_spec_exact_match, +) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py b/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..6693863386513c8fd7bce53b1b3f2cbf8b0bb815 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py @@ -0,0 +1,1290 @@ +# mypy: allow-untyped-defs +import builtins +import copy +import contextlib +import functools +import inspect +import math +import os +import warnings +import collections +from itertools import chain +from types import CodeType, FunctionType, ModuleType +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Set, + Tuple, + Type, + Union, +) + +import torch +import torch.utils._pytree as pytree +from torch._C import ScriptObject # type: ignore[attr-defined] +from torch._library.fake_class_registry import FakeScriptObject + +from ._compatibility import compatibility +from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph +from .graph_module import GraphModule +from ._lazy_graph_module import _make_graph_module +from .node import Argument, base_types, map_aggregate +from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager + +HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS + +# These need to run in global scope to handle nested calls correctly +_orig_module_call: Callable = torch.nn.Module.__call__ +_orig_module_getattr: Callable = torch.nn.Module.__getattr__ + +_proxyable_classes: Dict[Type, None] = {} + +_is_fx_tracing_flag = False + + +def is_fx_tracing(): + return _is_fx_tracing_flag + +@compatibility(is_backward_compatible=True) +class ProxyableClassMeta(type): + """ + ProxyableClassMeta allows you to make construction of a given Python class + symbolically traceable. For example:: + + import torch + import torch.fx + + class TensorPair(metaclass=torch.fx.ProxyableClassMeta): + def __init__(self, left, right): + self.left, self.right = left, right + + def add(self, other): + l = self.left + other.left + r = self.right + other.right + return TensorPair(l, r) + + def mul(self, other): + l = self.left * other.left + r = self.right * other.right + return TensorPair(l, r) + + def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): + s = x.add(TensorPair(y, y)) + return s.mul(x) + + x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) + y = torch.randn(5, 3) + ref_out = use_tensor_pair_ctor(x, y) + + traced = torch.fx.symbolic_trace(use_tensor_pair_ctor) + print(traced.code) + ''' + def forward(self, x : __main___TensorPair, y : torch.Tensor): + tensor_pair = __main___TensorPair(y, y); y = None + add = x.add(tensor_pair); tensor_pair = None + mul = add.mul(x); add = x = None + return mul + ''' + + From this example, we can see that construction of a class (``TensorPair``) + defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic + tracing. + """ + + def __init__(cls, name, bases, attrs): + _proxyable_classes.setdefault(cls) + super().__init__(name, bases, attrs) + + def __call__(cls, *args, **kwargs): + instance = cls.__new__(cls) # type: ignore[call-overload] + + if not is_fx_tracing(): + cls.__init__(instance, *args, **kwargs) # type: ignore[misc] + return instance + + found_proxies = [] + + def check_proxy(a): + if isinstance(a, Proxy): + found_proxies.append(a) + + map_aggregate(args, check_proxy) + map_aggregate(kwargs, check_proxy) + + if len(found_proxies) != 0: + tracer = found_proxies[0].tracer + return tracer.create_proxy("call_function", cls, args, kwargs) + else: + cls.__init__(instance, *args, **kwargs) # type: ignore[misc] + return instance + + +def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: + co = fn.__code__ + co_flags = co.co_flags & ~HAS_VARSTUFF + co_args: tuple + if hasattr(co, "co_qualname"): + # Python-3.11+ code signature + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_qualname, # type: ignore[attr-defined] + co.co_firstlineno, + co.co_lnotab, + co.co_exceptiontable, # type: ignore[attr-defined] + co.co_freevars, + co.co_cellvars, + ) + elif hasattr(co, "co_posonlyargcount"): + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + else: + co_args = ( + nargs, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + new_code = CodeType(*co_args) # type: ignore[arg-type] + return FunctionType( + new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__ + ) + + # we need to insert placeholder nodes for *args and **kwargs + # we can't call this function normally, otherwise it would try to unpack them + # instead, let's make python think that args and kwargs are normal variables + + +@compatibility(is_backward_compatible=False) +class PHBase: + """ + Object representing an input placeholder to `concrete_args` + """ + + def __repr__(self): + return "PH" + + +PH = PHBase() + + +@compatibility(is_backward_compatible=False) +class PHWithMeta(PHBase): + """ + Object representing an input placeholder to `concrete_args` + """ + def __init__(self, ph_key: Optional[str] = None): + super().__init__() + + # Provide a hey for user to identify placeholder node during analysis + self.ph_key = ph_key + + +def _transfer_attrs(fr, to): + for attr_name in dir(fr): + attr_val = getattr(fr, attr_name) + if ( + not callable(attr_val) + and not attr_name.startswith("__") + and not hasattr(to, attr_name) + ): + setattr(to, attr_name, attr_val) + + +@compatibility(is_backward_compatible=True) +class Tracer(TracerBase): + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `math`s path from the + # build environment (e.g. ` None: + # This method's signature is overridden by the first line of this class' + # docstring. If this method's signature is modified, the signature that + # overrides it also should be modified accordingly. + + """ + Construct a Tracer object. + + Args: + + autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, + Python modules whose functions should be wrapped automatically + without needing to use fx.wrap(). Backward-compatibility for + this parameter is guaranteed. + + autowrap_functions (Tuple[Callable, ...]): defaults to `()`, + Python functions that should be wrapped automatically without + needing to use fx.wrap(). Backward compatibility for this + parameter is guaranteed. + + param_shapes_constant (bool): When this flag is set, calls to shape, + size and a few other shape like attributes of a module's parameter + will be evaluated directly, rather than returning a new Proxy value + for an attribute access. Backward compatibility for this parameter + is guaranteed. + """ + + super().__init__() + + # Functions we will eagerly wrap when we see them while tracing + # this captures both `math.sqrt()` and `from math import sqrt` automatically + self._autowrap_function_ids: Set[int] = { + id(value) + for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) + if not name.startswith("_") and callable(value) + } + self._autowrap_function_ids.update({id(f) for f in autowrap_functions}) + + # Python modules to apply autowrap to at the start, in addition to + # modules we see while tracing + self._autowrap_search: List[ModuleType] = list(autowrap_modules) + self.param_shapes_constant = param_shapes_constant + + self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None + self.root_module_name: str = "" + # Maps the containing module's name to the operator name + self.scope = Scope("", None) + # Records the module call stack + self.module_stack = collections.OrderedDict() + # Mapping of node name to module scope + self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} + + _qualname_counter: Dict[str, int] = collections.defaultdict(int) + + @compatibility(is_backward_compatible=True) + def get_fresh_qualname(self, prefix: str) -> str: + """ + Gets a fresh name for a prefix and returns it. This function ensures + that it will not clash with an existing attribute on the graph. + """ + # The idea here is that if the module doesn't have this prefix at all we + # should reset the counter to start from the beginning + # It's a ... little bit hacky (doesn't cover all cases) but the precise + # naming of the prefixes isn't a correctness issue, just a niceness + # issue + qualname = f"{prefix}0" + if not hasattr(self.root, qualname): + self._qualname_counter[prefix] = 0 + return qualname + + i = self._qualname_counter[prefix] + while True: + qualname = f"{prefix}{i}" + i += 1 + if not hasattr(self.root, qualname): + break + self._qualname_counter[prefix] = i + + return qualname + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> "Argument": + """ + A method to specify the behavior of tracing when preparing values to + be used as arguments to nodes in the ``Graph``. + + By default, the behavior includes: + + #. Iterate through collection types (e.g. tuple, list, dict) and recursively + call ``create_args`` on the elements. + #. Given a Proxy object, return a reference to the underlying IR ``Node`` + #. Given a non-Proxy Tensor object, emit IR for various cases: + + * For a Parameter, emit a ``get_attr`` node referring to that Parameter + * For a non-Parameter Tensor, store the Tensor away in a special + attribute referring to that attribute. + + This method can be overridden to support more types. + + Args: + + a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. + + + Returns: + + The value ``a`` converted into the appropriate ``Argument`` + """ + # The base tracer is used to construct Graphs when there is no associated + # module hierarchy, so it can never create parameter references. + # The default tracer adds the ability to refer to parameters when + # tracing modules. + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node("get_attr", n, (), {}) + raise NameError("parameter is not a member of this module") + elif isinstance(a, torch.Tensor): + for n_, p_ in self.root.named_buffers(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + elif isinstance(a, torch.nn.Module): + for n_, p_ in self.root.named_modules(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + # For NamedTuple instances that appear literally as args, we emit + # a node to construct the NamedTuple and use that Node as the argument. + if isinstance(a, tuple) and hasattr(a, "_fields"): + args = tuple(self.create_arg(elem) for elem in a) + return self.create_node("call_function", a.__class__, args, {}) + + # Tensors do not have a reliable string repr() from which they can be + # constructed (and we probably don't want to rely on that, either), so + # for any constant Tensor values we encounter, first search for if they + # are an attribute of some module in the module hierarchy. If so, emit + # a get_attr to retrieve that tensor. Otherwise, we'll store away the + # tensor value into a special attribute on the Module s.t. we can + # retrieve it with a get_attr. + if isinstance(a, (torch.Tensor, ScriptObject, FakeScriptObject)): + qualname: Optional[str] = self.tensor_attrs.get(a) + + # Tensor was not found in the Module hierarchy, stow it away in a + # special attribute and set the qualname to refer to that + if not qualname: + base_name = "_tensor_constant" if isinstance(a, torch.Tensor) else "_torchbind_obj" + qualname = self.get_fresh_qualname(base_name) + assert isinstance(qualname, str) + self.tensor_attrs[a] = qualname + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + + if type(a) in _proxyable_classes: + # This is an instance of a proxyable class for which we did not + # witness its construction. Intern this as a constant attribute + + # TODO: binary search + qualname = self.get_fresh_qualname(f"_{a.__class__.__name__}_constant_") + assert isinstance(qualname, str) + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + + return super().create_arg(a) + + @compatibility(is_backward_compatible=True) + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + A method to specify whether a given ``nn.Module`` is a "leaf" module. + + Leaf modules are the atomic units that appear in + the IR, referenced by ``call_module`` calls. By default, + Modules in the PyTorch standard library namespace (torch.nn) + are leaf modules. All other modules are traced through and + their constituent ops are recorded, unless specified otherwise + via this parameter. + + Args: + + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. For example, + if you have a module hierarchy where submodule ``foo`` contains + submodule ``bar``, which contains submodule ``baz``, that module will + appear with the qualified name ``foo.bar.baz`` here. + """ + return ( + (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) + and not isinstance(m, torch.nn.Sequential) + ) + + @compatibility(is_backward_compatible=True) + def path_of_module(self, mod: torch.nn.Module) -> str: + """ + Helper method to find the qualified name of ``mod`` in the Module hierarchy + of ``root``. For example, if ``root`` has a submodule named ``foo``, which has + a submodule named ``bar``, passing ``bar`` into this function will return + the string "foo.bar". + + Args: + + mod (str): The ``Module`` to retrieve the qualified name for. + """ + # Prefer the O(1) algorithm + if self.submodule_paths: + path = self.submodule_paths.get(mod) + if path is None: + raise NameError("module is not installed as a submodule") + assert isinstance(path, str) + return path + # O(N^2) fallback in the case that we didn't store the submodule + # paths. + else: + for n, p in self.root.named_modules(): + if mod is p: + return n + raise NameError("module is not installed as a submodule") + + @compatibility(is_backward_compatible=True) + def call_module( + self, + m: torch.nn.Module, + forward: Callable[..., Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> Any: + """ + Method that specifies the behavior of this ``Tracer`` when it encounters + a call to an ``nn.Module`` instance. + + By default, the behavior is to check if the called module is a leaf module + via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to + ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through + the operations in its ``forward`` function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a ``call_module`` + node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever + value was returned from the ``Module`` invocation. + """ + module_qualified_name = self.path_of_module(m) + with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: + # module_stack is an ordered dict so writing then deleting the + # entry is equivalent to push/pop on a list + self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type) + if not self.is_leaf_module(m, module_qualified_name): + ret_val = forward(*args, **kwargs) + else: + ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs) + key, _ = self.module_stack.popitem(last=True) + assert key == _scope.module_path, f" Unexpected key {key}" + + return ret_val + + @compatibility(is_backward_compatible=False) + def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): + """ + Method that specifies the behavior of this ``Tracer`` when we call getattr + on a call to an ``nn.Module`` instance. + + By default, the behavior is to return a proxy value for the attribute. It + also stores the proxy value in the ``parameter_proxy_cache``, so that future + calls will reuse the proxy rather than creating a new one. + + This method can be overridden to --for example-- not return proxies when + querying parameters. + + Args: + + attr (str): The name of the attribute being queried + attr_val (Any): The value of the attribute + parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies + + Return: + + The return value from the getattr call. + """ + def maybe_get_proxy_for_attr( + attr_val, collection_to_search, parameter_proxy_cache + ): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if ( + "proxy_factory_fn" + in inspect.signature(self.create_proxy).parameters + ): + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ParameterProxy( + self, node, n, attr_val + ) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + return attr_val + + # This method will be refactored + @compatibility(is_backward_compatible=False) + def create_args_for_root(self, root_fn, is_module, concrete_args=None): + """ + Create ``placeholder`` nodes corresponding to the signature of the ``root`` + Module. This method introspects root's signature and emits those + nodes accordingly, also supporting ``*args`` and ``**kwargs``. + """ + # In some cases, a function or method has been decorated with a wrapper + # defined via ``functools.wraps``. In this case, the outer code object + # will likely not contain the actual parameters we care about, so unwrap + # the function to get to the innermost callable. + fn_for_analysis = inspect.unwrap(root_fn) + co = fn_for_analysis.__code__ + total_args = co.co_argcount + co.co_kwonlyargcount + orig_args = list(co.co_varnames) + names_iter = iter(co.co_varnames) + args: List[Any] = [] + skip_arg_idx = 0 + if is_module: + if total_args == 0: + raise RuntimeError( + "``self`` argument cannot be part of *args expansion!" + ) + skip_arg_idx = 1 + next(names_iter) # skip self + args.append(self.root) + + sig = inspect.signature(fn_for_analysis) + + + # This covers the very specific case where we are passing in flat + # concrete_args as a tuple, but our traced fn takes (*args, **kwargs). + # In this case, just take the concrete_args and pass them through. + name_idx = 0 + if isinstance(concrete_args, tuple) and \ + len(concrete_args) > 0 and \ + (co.co_flags & HAS_VARSTUFF) and \ + total_args == 1: + for concrete_arg in concrete_args: + out = self.create_proxy("placeholder", f"input_{name_idx}", (), {}) + if isinstance(concrete_arg, PHBase): + if concrete_arg != PH: + # Transfer attrs in the case where you're using a placeholder other + # than the singleton PH (PH has no attributes to transfer). + # Proxies were created out of the placeholders. + # Transfer any metadata (put on the placeholders in the form of + # attributes set by the user) from the placeholder to the + # underlying nodes (the proxy is unwrapped by the user, but + # the metadata should hold). + _transfer_attrs(fr=concrete_arg, to=out.node) + args.append(out) + name_idx += 1 + return root_fn, args + + arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] + if isinstance(concrete_args, tuple): + if len(arg_names) != len(concrete_args): + raise RuntimeError( + f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" + ) + concrete_args = dict(zip(arg_names, concrete_args)) + + def proxy_placeholder(name): + return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis) + + args.extend(proxy_placeholder(names) for names in arg_names) + + if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: + # TODO: type annotations for *args and **kwargs + if co.co_flags & inspect.CO_VARARGS: + args.append(proxy_placeholder("*" + next(names_iter))) + if co.co_flags & inspect.CO_VARKEYWORDS: + args.append(proxy_placeholder("**" + next(names_iter))) + root_fn = _patch_function(root_fn, len(args)) + + flat_args, in_spec = pytree.tree_flatten(tuple(args)) + if not all(child.is_leaf() for child in in_spec.children_specs): + # In the case that we have pytree-flattened inputs in + # `concrete_args`, generate a flattening wrapper around the + # original root function and return that. + self.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo(orig_args[:total_args], in_spec, None) + ) + + def flatten_fn(*args): + tree_args = pytree.tree_unflatten(list(args), in_spec) + tree_out = root_fn(*tree_args) + out_args, out_spec = pytree.tree_flatten(tree_out) + assert isinstance(self.graph._codegen, _PyTreeCodeGen) + self.graph._codegen.pytree_info = ( + self.graph._codegen.pytree_info._replace(out_spec=out_spec) + ) + return out_args + + return flatten_fn, flat_args + return root_fn, args + + @compatibility(is_backward_compatible=True) + def trace( + self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, + ) -> Graph: + """ + Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` + can either be an ``nn.Module`` instance or a Python callable. + + Note that after this call, ``self.root`` may be different from the ``root`` passed + in here. For example, when a free function is passed to ``trace()``, we will + create an ``nn.Module`` instance to use as the root and add embedded constants + to. + + + Args: + + root (Union[Module, Callable]): Either a ``Module`` or a function to be + traced through. Backwards-compatibility for this parameter is + guaranteed. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should + not be treated as Proxies. This parameter is experimental and + its backwards-compatibility is *NOT* guaranteed. + + Returns: + + A ``Graph`` representing the semantics of the passed-in ``root``. + """ + global _is_fx_tracing_flag + old_is_fx_tracing_flag = _is_fx_tracing_flag + _is_fx_tracing_flag = True + try: + if isinstance(root, torch.nn.Module): + + # do real recompilation for _LazyGraphModule before retracing since the trace + # method can not trace the _lazy_forward method. Got error: + # https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 + # without this. + from torch.fx._lazy_graph_module import _LazyGraphModule + _LazyGraphModule.force_recompile(root) + + self.root = root + + assert hasattr( + type(root), self.traced_func_name + ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" + + fn = getattr(type(root), self.traced_func_name) + self.root_module_name = root._get_name() + self.submodule_paths = {mod: name for name, mod in root.named_modules()} + else: + self.root = torch.nn.Module() + fn = root + + tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None) + self.graph = Graph(tracer_cls=tracer_cls) + if hasattr(fn, '__code__'): + code = fn.__code__ + self.graph._co_fields = { + 'co_name': code.co_name, + 'co_filename': code.co_filename, + 'co_firstlineno': code.co_firstlineno, + } + + # When we encounter a Tensor value that's not a parameter, we look if it + # is some other attribute on the model. Construct a dict mapping Tensor + # values to the qualified name here for efficiency. This is used downstream + # in create_arg + self.tensor_attrs: Dict[ + Union[ + torch.Tensor, + ScriptObject, + FakeScriptObject + ], str + ] = {} + + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, ScriptObject, FakeScriptObject)): + self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args = self.create_args_for_root( + fn, isinstance(root, torch.nn.Module), concrete_args + ) + + parameter_proxy_cache: Dict[ + str, Proxy + ] = {} # Reduce number of get_attr calls + + # Method dispatch on parameters is not recorded unless it's directly used. + # Thus, we need to insert a proxy when __getattr__ requests a parameter. + @functools.wraps(_orig_module_getattr) + def module_getattr_wrapper(mod, attr): + attr_val = _orig_module_getattr(mod, attr) + return self.getattr(attr, attr_val, parameter_proxy_cache) + + @functools.wraps(_orig_module_call) + def module_call_wrapper(mod, *args, **kwargs): + def forward(*args, **kwargs): + return _orig_module_call(mod, *args, **kwargs) + + _autowrap_check( + patcher, # type: ignore[has-type] + getattr(getattr(mod, "forward", mod), "__globals__", {}), + self._autowrap_function_ids, + ) + return self.call_module(mod, forward, args, kwargs) + + with _new_patcher() as patcher: + # allow duplicate patches to support the case of nested calls + patcher.patch_method( + torch.nn.Module, + "__getattr__", + module_getattr_wrapper, + deduplicate=False, + ) + patcher.patch_method( + torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False + ) + _patch_wrapped_functions(patcher) + _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) + for module in self._autowrap_search: + _autowrap_check( + patcher, module.__dict__, self._autowrap_function_ids + ) + self.create_node( + "output", + "output", + (self.create_arg(fn(*args)),), + {}, + type_expr=fn.__annotations__.get("return", None), + ) + + self.submodule_paths = None + finally: + _is_fx_tracing_flag = old_is_fx_tracing_flag + return self.graph + + def __deepcopy__(self, memo): + # _autowrap_search contains modules, which cannot be deepcopied. + new_tracer = Tracer.__new__(Tracer) + + for k, v in self.__dict__.items(): + if k in {'_autowrap_search'}: + new_obj = copy.copy(v) + else: + new_obj = copy.deepcopy(v, memo) + + new_tracer.__dict__[k] = new_obj + + return new_tracer + + def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis): + if concrete_args is not None and name in concrete_args: + cnt = 0 + + def replace_ph(x): + nonlocal cnt + cnt += 1 + param = sig.parameters[name] + default = ( + () + if param.default is inspect.Parameter.empty + else (param.default,) + ) + out = self.create_proxy( + "placeholder", f"{name}_{str(cnt)}", default, {} + ) + if isinstance(x, PHBase): + if x != PH: + # Transfer attrs in the case where you're using a placeholder other + # than the singleton PH (PH has no attributes to transfer). + # Proxies were created out of the placeholders. + # Transfer any metadata (put on the placeholders in the form of + # attributes set by the user) from the placeholder to the + # underlying nodes (the proxy is unwrapped by the user, but + # the metadata should hold). + _transfer_attrs(fr=x, to=out.node) + + return out + # Union[int, bool] == bool in Python <= 3.6 + if ( + type(x) == bool + or type(x) in base_types + and type(x) != torch.Tensor + ): + torch._assert( + out == x, + f"{name} has been specialized to have value {x} but got another value", + ) + elif x is None: + args = ( + out, + f"{name} has been specialized to have value None but got another value", + ) + self.create_proxy("call_function", _assert_is_none, args, {}) + else: + warnings.warn( + f"Was not able to add assertion to guarantee correct input {name} to " + f"specialized function. It is up to the user to make sure that your inputs match the " + f"inputs you specialized the function with." + ) + + return x + + return pytree.tree_map(replace_ph, concrete_args[name]) + if name[0] == "*": + default = () + else: + param = sig.parameters[name] + default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment] + return self.create_proxy( + "placeholder", + name, + default, + {}, + type_expr=fn_for_analysis.__annotations__.get(name, None) + ) + + +# Dictionary of (id(globals dict), function name) => globals_dict to patch for +# the purposes of the wrap() API. +# We key by the globals dict id and function name to ensure we're wrapping a given +# function only once. +_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {} + +# List of methods on classes to wrap (class type, function name) +# this currently only works for Tensor.* methods that aren't traced properly +_wrapped_methods_to_patch: List[Tuple[type, str]] = [] + +if os.environ.get("FX_PATCH_GETITEM") == "1": + # This change is needed to trace models like PositionalEmbedding from BERT: + # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py + # but causes issues in quantization documented here: + # https://github.com/pytorch/pytorch/issues/50710 + # once that is fixed we can make this the default behavior. + _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) + + +def _find_proxy(*objects_to_search): + """ + Recursively search a data structure for a Proxy() and return it, + return None if not found. + """ + proxy = None + + def find_proxy(x): + nonlocal proxy + if isinstance(x, Proxy): + proxy = x + + map_aggregate(objects_to_search, find_proxy) + return proxy + + +def _create_wrapped_func(orig_fn): + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Given an closed-over ``orig_function`` to invoke, search the args and kwargs for + a Proxy object. If there is one, emit a ``call_function`` node to preserve the + call to this leaf function directly. Otherwise, just return the results of + this function call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return_proxy = proxy.tracer.create_proxy( + "call_function", orig_fn, args, kwargs + ) + return_proxy.node.meta["is_wrapped"] = True + return return_proxy + return orig_fn(*args, **kwargs) + + return wrapped + + +def _create_wrapped_method(cls, name): + orig_fn = getattr(cls, name) + + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Search the args and kwargs for a Proxy object. If there is one, + emit a ``call_method`` node to preserve the call to this method + directly. Otherwise, just return the results of this function + call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return proxy.tracer.create_proxy("call_method", name, args, kwargs) + return orig_fn(*args, **kwargs) + + return wrapped + + +class _PatchedFn(NamedTuple): + frame_dict: Any + fn_name: str + orig_fn: Any + new_fn: Any + + def revert(self): + raise NotImplementedError + + def patch(self): + raise NotImplementedError + + +class _PatchedFnSetItem(_PatchedFn): + def revert(self): + self.frame_dict[self.fn_name] = self.orig_fn + + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + +class _PatchedFnDel(_PatchedFn): + def revert(self): + del self.frame_dict[self.fn_name] + + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + + +class _PatchedFnSetAttr(_PatchedFn): + def revert(self): + setattr(self.frame_dict, self.fn_name, self.orig_fn) + + def patch(self): + setattr(self.frame_dict, self.fn_name, self.new_fn) + +class _Patcher: + def __init__(self) -> None: + super().__init__() + self.patches_made: List[_PatchedFn] = [] + self.visited: Set[int] = set() + + def patch( + self, + frame_dict: Dict[str, Any], + name: str, + new_fn: Callable, + deduplicate: bool = True, + ): + """ + Replace frame_dict[name] with new_fn until we exit the context manager. + """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + if name not in frame_dict and hasattr(builtins, name): + self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn)) + self.patches_made[-1].patch() + elif getattr(frame_dict[name], "__fx_already_patched", False): + return # already patched, no need to do it again + else: + self.patches_made.append( + _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn) + ) + self.patches_made[-1].patch() + + def patch_method( + self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True + ): + """ + Replace object_or_dict.name with new_fn until we exit the context manager. + """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + orig_fn = getattr(cls, name) + if getattr(orig_fn, "__fx_already_patched", False): + return # already patched, no need to do it again + self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn)) + self.patches_made[-1].patch() + + def visit_once(self, thing: Any): + """Return True on the first call to with thing, otherwise false""" + idx = id(thing) + if idx in self.visited: + return False + self.visited.add(idx) + return True + + def revert_all_patches(self): + """ + Remove all the stored patcheds. It doesn't modify patches_made. + """ + for patch in self.patches_made: + patch.revert() + return self.patches_made + + def reapply_all_patches(self): + """ + Patch all the stored patcheds. It doesn't modify patches_made. + """ + for patch in self.patches_made: + patch.patch() + return self.patches_made + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Undo all the changes made via self.patch() and self.patch_method() + """ + while self.patches_made: + # unpatch in reverse order to handle duplicates correctly + self.patches_made.pop().revert() + self.visited.clear() + + +CURRENT_PATCHER: Optional[_Patcher] = None + +@contextlib.contextmanager +def _new_patcher(): + global CURRENT_PATCHER + prior_patcher = CURRENT_PATCHER + try: + CURRENT_PATCHER = _Patcher() + yield CURRENT_PATCHER + finally: + # Clear all the patches made by when using current patcher. + assert CURRENT_PATCHER is not None + CURRENT_PATCHER.revert_all_patches() + CURRENT_PATCHER = prior_patcher + + +@contextlib.contextmanager +def _maybe_revert_all_patches(): + current_patcher = CURRENT_PATCHER + patches_made = None + patches_removed = None + try: + if current_patcher is not None: + patches_removed = current_patcher.revert_all_patches() + yield + finally: + if current_patcher is not None: + patches_made = current_patcher.reapply_all_patches() + assert patches_made == patches_removed, "CURRENT_PATCHER was changed during a revert_all_patches" + +def _patch_wrapped_functions(patcher: _Patcher): + """ + Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap + the listed global functions in the `_create_wrapped_func` wrapper. + """ + for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items(): + if name not in frame_dict and hasattr(builtins, name): + orig_fn = getattr(builtins, name) + else: + orig_fn = frame_dict[name] + patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) + + for cls, name in _wrapped_methods_to_patch: + patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) + + +def _autowrap_check( + patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int] +): + """ + Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. + This method searches a scope for them and patches them if found. + """ + if patcher.visit_once(frame_dict): + for name, value in frame_dict.items(): + if ( + not name.startswith("_") + and callable(value) + and id(value) in function_ids + ): + patcher.patch(frame_dict, name, _create_wrapped_func(value)) + + +@compatibility(is_backward_compatible=True) +def wrap(fn_or_name: Union[str, Callable]): + """ + This function can be called at module-level scope to register fn_or_name as a "leaf function". + A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being + traced through:: + + # foo/bar/baz.py + def my_custom_function(x, y): + return x * x + y * y + + torch.fx.wrap('my_custom_function') + + def fn_to_be_traced(x, y): + # When symbolic tracing, the below call to my_custom_function will be inserted into + # the graph rather than tracing it. + return my_custom_function(x, y) + + This function can also equivalently be used as a decorator:: + + # foo/bar/baz.py + @torch.fx.wrap + def my_custom_function(x, y): + return x * x + y * y + + A wrapped function can be thought of a "leaf function", analogous to the concept of + "leaf modules", that is, they are functions that are left as calls in the FX trace + rather than traced through. + + Args: + + fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the + graph when it's called + """ + if not callable(fn_or_name) and not isinstance(fn_or_name, str): + raise RuntimeError( + "Unsupported type for global function! Must be either a callable or " + "string name" + ) + + if callable(fn_or_name): + assert not isinstance(fn_or_name, str) # to make mypy happy + fn_name = fn_or_name.__name__ + else: + assert isinstance( + fn_or_name, str + ), "fn_or_name must be a global function or string name" + fn_name = fn_or_name + + currentframe = inspect.currentframe() + assert currentframe is not None + f = currentframe.f_back + assert f is not None + if f.f_code.co_name != "": + raise NotImplementedError("wrap must be called at the top level of a module") + + # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search + # semantics would be slightly different, but would add support `from x import wrapped_function` + _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals + return fn_or_name + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None, +) -> GraphModule: + """ + Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. + + ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures. + + For example:: + + def f(a, b): + if b == True: + return a + else: + return a*2 + + FX can typically not trace through this due to the presence of control + flow. However, we can use `concrete_args` to specialize on the value of + `b` to trace through this:: + + f = fx.symbolic_trace(f, concrete_args={'b': False}) + assert f(3, False) == 6 + + Note that although you can still pass in different values of `b`, they will be ignored. + + We can also use `concrete_args` to eliminate data-structure handling from + our function. This will use pytrees to flatten your input. To avoid + overspecializing, pass in `fx.PH` for values that shouldn't be + specialized. For example:: + + def f(x): + out = 0 + for v in x.values(): + out += v + return out + f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) + assert f({'a': 1, 'b': 2, 'c': 4}) == 7 + + + Args: + root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted + into a Graph representation. + concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized + + Returns: + GraphModule: a Module created from the recorded operations from ``root``. + """ + tracer = Tracer() + graph = tracer.trace(root, concrete_args) + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) + return _make_graph_module(tracer.root, graph, name) + + +@wrap +def _assert_is_none(value, msg): + assert value is None, msg diff --git a/.venv/lib/python3.11/site-packages/torch/fx/_utils.py b/.venv/lib/python3.11/site-packages/torch/fx/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd3780fe0bb36cc47f2669a72a2b9cdc0bdcbf6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/_utils.py @@ -0,0 +1,63 @@ +# mypy: allow-untyped-defs +import sys +from typing import Dict, Optional + +import torch +from torch._logging import LazyString + + +def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): + """ + Returns a LazyString that formats the graph code. + """ + + def format_name(): + if maybe_id is not None: + return f"{name} {maybe_id}" + else: + return name + + if "print_output" not in kwargs: + kwargs["print_output"] = False + + if "colored" in kwargs and not sys.stdout.isatty(): + kwargs["colored"] = False + + return LazyString( + lambda: _format_graph_code( + f"===== {format_name()} =====\n", + gm.forward.__code__.co_filename, + gm.print_readable(**kwargs), + ) + ) + + +def _format_graph_code(name, filename, graph_str): + """ + Returns a string that formats the graph code. + """ + return f"TRACED GRAPH\n {name} {filename} {graph_str}\n" + + +def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]: + """ + Returns the nn_module_stack of the first call_function node. + """ + for node in graph.nodes: + if node.op == "call_function" and "nn_module_stack" in node.meta: + return node.meta["nn_module_stack"] + return None + + +def get_node_context(node, num_nodes=2) -> str: + """ + Returns a string of the last num_nodes nodes in the graph. + """ + node_contexts = [] + cur = node + for i in range(num_nodes): + node_contexts.append(cur.format_node()) + if cur.op == "root": + break + cur = cur.prev + return "\n".join(node_contexts[::-1]) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/annotate.py b/.venv/lib/python3.11/site-packages/torch/fx/annotate.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b5b5f2d376115fe25542b3b1260dcd6aaf1aaf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/annotate.py @@ -0,0 +1,32 @@ +# mypy: allow-untyped-defs +from torch.fx.proxy import Proxy +from ._compatibility import compatibility + +@compatibility(is_backward_compatible=False) +def annotate(val, type): + """ + Annotates a Proxy object with a given type. + + This function annotates a val with a given type if a type of the val is a torch.fx.Proxy object + Args: + val (object): An object to be annotated if its type is torch.fx.Proxy. + type (object): A type to be assigned to a given proxy object as val. + Returns: + The given val. + Raises: + RuntimeError: If a val already has a type in its node. + """ + if isinstance(val, Proxy): + if val.node.type: + raise RuntimeError(f"Tried to annotate a value that already had a type on it!" + f" Existing type is {val.node.type} " + f"and new type is {type}. " + f"This could happen if you tried to annotate a function parameter " + f"value (in which case you should use the type slot " + f"on the function signature) or you called " + f"annotate on the same value twice") + else: + val.node.type = type + return val + else: + return val diff --git a/.venv/lib/python3.11/site-packages/torch/fx/config.py b/.venv/lib/python3.11/site-packages/torch/fx/config.py new file mode 100644 index 0000000000000000000000000000000000000000..da5120d6edf180f7fbbe88ac342b4d0e4b383e50 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/config.py @@ -0,0 +1,6 @@ +# Whether to disable showing progress on compilation passes +# Need to add a new config otherwise wil get a circular import if dynamo config is imported here +disable_progress = True + +# If True this also shows the node names in each pass, for small models this is great but larger models it's quite noisy +verbose_progress = False diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d9080f7d160ea72cd6cbad11f566f6f04afbc9c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f46fbb546ada861a8675f65b289013290693e2b7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4b51349627112ac75048d8d02eec2b62b5e3d22 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37455017201dbd2370e01229e7c10ef22b033fcf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/debug.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/debug.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65ca078f5106f3eaf28ad7565d44a0967769e365 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/debug.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeb96dced824b58b20e398000808a0e16192a93d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e2bc31c6926e91435df275e103a0232844f92b2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37963e90932bacedd53065d681f02c3ac1201045 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..403e6276d83302944ba28d16ac0e48dc344a2d4e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0069aba111231d84fcbf4a4d421c5879ec51cd19 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6feddf3339daf32fb621e2ea85506cf7978e6c3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54753df204d946c15a57070965465455bbad84d9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..700377250e56405fd6ab2ce8939a01d07c1931d0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b22242e791979706ba5275cc5449c3ad985bef9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3302d1886bfc9da0731604ea0db9c68d6d33218 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3afdec74e46ad29e256ee579db825b4dab2d85f9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f769d6d6461cf66b06925a7949006494d66988e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbee9fedd3a4da52423ab6b146303d936d9909e0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/_backward_state.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/_backward_state.py new file mode 100644 index 0000000000000000000000000000000000000000..9c742431857c33af22dbc1ad73b5bdfcf6124b9c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/_backward_state.py @@ -0,0 +1,27 @@ +import torch.fx + + +class BackwardState: + """ + BackwardState is used to pass Python hooks from the forwards pass + into the backwards pass in Dynamo+Compiled Autograd. + + It is created by TorchDynamo and has special handling there. + Dynamo will pass an empty BackwardState to the forwards, then populate + members on it (via setattr) only after the forwards graph is finished. + Later on, in CompileAutograd we will inline and add the needed guards + on the BackwardState. + + BackwardState is identified and has special handling in AOTAutograd. + During AOTAutograd: + 1) BackwardState is an input to the forwards graph + 2) It must only be used in the backwards + 3) It will be empty in the forwards + 4) In the forwards we add a wrapper to save it + 5) In the backwards it becomes an input + 6) There can only be one per graph + + BackwardState requires CompiledAutograd. + """ + + proxy: torch.fx.Proxy diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/_config.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f901919fb45c654af2198192a1cf3157aad36bf9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/_config.py @@ -0,0 +1,88 @@ +import os +import sys +from typing import Optional + + +# [@compile_ignored: debug] Uses z3 for validating the guard optimizations transformations. +translation_validation = ( + os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1" +) +# Timeout (in milliseconds) for z3 finding a solution. +# [@compile_ignored: debug] +translation_validation_timeout = int( + os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000") +) +# Disables bisection for translation validation. +# +# Translation validation bisection is enabled by default, if translation validation +# is also enabled. This should help finding guard simplification issues. However, +# since validation uses Z3 for bisecting, it might take a lot of time. +# +# Set this configuration option so as to avoid bisecting. +# [@compile_ignored: debug] +translation_validation_no_bisect = ( + os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1" +) +# Checks whether replaying ShapeEnv events on a freshly constructed one yields +# the a ShapeEnv with the same state. This should be used only in testing. +check_shape_env_recorded_events = False + +# TODO: Perhaps consider allowing unions for the configs below (so you can hit +# multiple reps at the same time) + +# Give extended debug information if the string representation of a guard +# matches this. For example, set this to "Ne(s0, 10)" and whenever we issue +# this guard, we will generate full Python and C++ backtrace +# [@compile_ignored: debug] +extended_debug_guard_added = os.environ.get( + "TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None +) + +# Give extended debug information when a particular symbol is allocated. For +# example, set this to "u2" and whenever we create this symbol, we will +# generate full Python and C++ backtrace +# [@compile_ignored: debug] +extended_debug_create_symbol = os.environ.get( + "TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None +) + +# Give extended debug information (C++ backtrace) for all extended debug +# settings as well as errors. The C++ backtrace is slow and very spammy so we +# don't include it by default even when you're requesting extended debug. +# [@compile_ignored: debug] +extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != "" + +# Give extended debug information (line of code) when a torch function +# is called during export. This is useful for showing progress and detecting +# where export might be stuck. Currently only works for strict=False. +# [@compile_ignored: debug] +extended_debug_current_loc = ( + os.environ.get("TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC", "0") == "1" +) + +# [@compile_ignored: debug] Show a warning for every specialization +print_specializations = False + +# wraps (un)equalities with 'Not' class after recording the correct expression +# in the FX graph. This should incorrectly construct the divisible and replacement +# lists, and incorrectly issue guards. +inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False + +# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly +validate_shape_env_version_key = False + +# If we produce more than this many guards on a symbol, force the symbol to +# get specialized and bail out if this many guards mention this particular +# symbol. This may be slightly more aggressive than the true number of guards +# issued (as we test if we've hit the limit on-the-fly, whereas we may +# do further simplifications at final guard issuance time that make guards +# irrelevant.) +symbol_guard_limit_before_specialize: Optional[int] = None + +# This flag changes whether we should use the same symbolic variable to represent input sizes that are the same. +use_duck_shape = True + +from torch.utils._config_module import install_config_module + + +install_config_module(sys.modules[__name__]) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/accelerator_partitioner.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/accelerator_partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..9b347762dedba00611fa579b5654cd9718d782b6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/accelerator_partitioner.py @@ -0,0 +1,1079 @@ +# mypy: allow-untyped-defs +import operator +from collections import deque +from typing import Dict, List, Set, NamedTuple, Tuple, Deque + +import torch +from torch.fx.passes.graph_manipulation import get_size_of_all_nodes +from torch.fx.experimental.partitioner_utils import ( + Partition, + Device, + PartitionerConfig, + get_partition_to_latency_mapping, + get_latency_of_partitioned_graph, + NodeLatency, + get_extra_size_of, + PartitionMode, +) +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node, map_arg +from torch.fx.passes.split_module import split_module + + +class DAGNode: + """DAGNode class maintains useful information for a partition (submodule), + and its input submodules and output submodules. + """ + + def __init__( + self, + submodule_node: Node, + input_nodes: List[Node], + output_nodes: List[Node], + logical_device_ids: List[int], + size_bytes: int, + ) -> None: + self.submodule_node: Node = submodule_node + self.input_nodes: List[Node] = input_nodes + self.output_nodes: List[Node] = output_nodes + self.logical_device_ids: List[int] = logical_device_ids + self.size_bytes = size_bytes + + def __str__(self) -> str: + return str(self.submodule_node) + + +class DAG: + """DAG class contains all the DAG nodes""" + + def __init__(self) -> None: + self.nodes: List[DAGNode] = [] + + def create_node( + self, + submodule_node: Node, + input_nodes: List[Node], + output_nodes: List[Node], + logical_devices: List[int], + size_bytes: int, + ) -> None: + node = DAGNode( + submodule_node, input_nodes, output_nodes, logical_devices, size_bytes + ) + self.nodes.append(node) + + +class PartitionResult(NamedTuple): + """NameTuple used for returning DAG and a new fx module""" + + dag: DAG + module_with_submodules: GraphModule + + +"""Followings are some helper functions for partition manipulation""" + + +def reset_partition_device(partitions): + for partition in partitions: + partition.logical_device_ids = [] + + +def combine_two_partitions( + partition_0: Partition, partition_1: Partition, partitions: List[Partition] +) -> None: + """Given a list of partitions and its two partitions, + combine these two partitions into a new one appending to the partitions + and remove the previous two partitions from the list of partitions + """ + partition = Partition(len(partitions)) + partition.nodes = partition_0.nodes.union(partition_1.nodes) + partition.recalculate_mem_size() + partitions.append(partition) + partitions.remove(partition_0) + partitions.remove(partition_1) + reorganize_partitions(partitions) + return + + +def set_parents_and_children(partitions: List[Partition]) -> None: + """Given a list of partitions, mark parents and children for each partition""" + # Go through all nodes in a partition. + # If a node's user is in other partition, + # then the other partition is this partition's children. + # This partition is the other partition's parent + for partition in partitions: + partition.children = set() + partition.parents = set() + for partition in partitions: + for node in partition.nodes: + # For each node in the current partition, find its users + users = node.users + for n in users: + # Find which the partition the user node belongs to. + # Note that if the node itself is also belongs to that partition, + # that partition is not the child of the current partition + for p in partitions: + if p != partition and n in p.nodes and node not in p.nodes: + partition.children.add(p) + p.parents.add(partition) + return + + +def reorganize_partitions(partitions: List[Partition]) -> None: + """Given a list of partitions, reorganize partition id, + its parents and its children for each partition + """ + # Rearrange partition ids + for i, partition in enumerate(partitions): + partition.partition_id = i + set_parents_and_children(partitions) + return + + +def get_bfs_level_partition(partitions: List[Partition]) -> None: + """Given a list of partitions, + mark the bfs level for each partition + """ + current_level: Set[Partition] = set() + visited: Set[Partition] = set() + for partition in partitions: + # If a partition has no parent, it should be in root level + if len(partition.parents) == 0: + current_level.add(partition) + next_level: Set[Partition] = set() + level = 0 + # bfs + while current_level: + partition = current_level.pop() + partition.bfs_level = level + visited.add(partition) + children = partition.children + for child in children: + if child not in next_level: + next_level.add(child) + if not current_level: + current_level = next_level.copy() + next_level = set() + level += 1 + return + + +def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]: + """Given a list of partitions,return node to partition mapping""" + node_to_partition: Dict[Node, int] = {} + for partition in partitions: + for node in partition.nodes: + node_to_partition[node] = partition.partition_id + return node_to_partition + + +def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]: + """Get a mapping from device logical ID to Device object.""" + logical_id_to_device: Dict[int, Device] = {} + for d in devices: + logical_id_to_device[d.logical_id] = d + return logical_id_to_device + + +def get_device_partition_stats( + partitions: List[Partition], devices: List[Device] +) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]: + """Given a list of partitions and a list of devices, returns: + 1. A mapping from device to partitions on it; + 2. A mapping from device to its remaining memory size; + 3. A list of partitions that do not have a device. + """ + # logical id to device + logical_id_to_device = get_logical_id_to_device(devices) + # Track partitions on device + device_to_partitions: Dict[Device, List[Partition]] = {} + # Track device's left mem size + device_to_left_mem_bytes: Dict[Device, int] = {} + for d in devices: + device_to_partitions[d] = [] + device_to_left_mem_bytes[d] = d.available_mem_bytes + + # Deal with the partitions that already have a device + # and also collect all partitions without a device (no_device_partitions) + no_device_partitions = [] + for partition in partitions: + if partition.logical_device_ids != []: + for logical_id in partition.logical_device_ids: + device = logical_id_to_device[logical_id] + device_to_partitions[device].append(partition) + device_to_left_mem_bytes[device] -= partition.used_mem_bytes + else: + no_device_partitions.append(partition) + + return ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) + + +def get_device_to_partitions_mapping( + partitions: List[Partition], devices: List[Device] +): + """Given a list of partitions and a list of devices, + map each partition into a device. + """ + + def calculate_extra_mem_bytes_needed_for( + partition: Partition, partitions: List[Partition] + ): + all_nodes: Set[Node] = set() + for p in partitions: + all_nodes = all_nodes.union(p.nodes) + if len(all_nodes) == 0: + return partition.used_mem_bytes + all_nodes = all_nodes.union(partition.nodes) + extra_size_needed = 0 + for node in partition.nodes: + extra_size_needed += get_extra_size_of(node, all_nodes) + return extra_size_needed + + def find_device_for(partition: Partition): + """Given a partition, find a logical device for the partition + The algorithm is to put the partition on the device + that has just enough mem left for that partition. + device_to_left_mem_bytes is a dictionary between device and its left mem size + sorted by its left mem size + """ + for d in device_to_left_mem_bytes: + extra_size_needed = calculate_extra_mem_bytes_needed_for( + partition, device_to_partitions[d] + ) + if extra_size_needed < device_to_left_mem_bytes[d]: + device_to_partitions[d].append(partition) + partition.logical_device_ids.append(d.logical_id) + device_to_left_mem_bytes[d] -= extra_size_needed + return True + return False + + ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) = get_device_partition_stats(partitions, devices) + + # Find devices for all the partitions without a device + found_device = True + for partition in no_device_partitions: + device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1))) + found_device = find_device_for(partition) + if not found_device: + break + return found_device + + +def check_dependency(partition): + """Given a partition,check if there is a circular dependency on + this partition using bfs + """ + visited: Set[Partition] = {partition} + queue: Deque[Partition] = deque([partition]) + while queue: + p = queue.popleft() + for child in p.children: + if child == partition: + return True + else: + if child not in visited: + visited.add(child) + queue.append(child) + return False + + +class Partitioner: + """A fx module may not fit into one device. + Partitioner class helps partition one fx module into submodules (partitions), + so that the submodules can be executed crossing different accelerators. + The main function of this class is self.partition_graph. + It partitions the fx module based on the scheme specified in partition_config + A DAG structure is returned + along with a new fx module with submodule nodes. + """ + + def __init__(self) -> None: + self.partitions: List[Partition] = [] + self.node_to_partition: Dict[Node, int] = {} + self.devices: List[Device] = [] + + def partition_graph( + self, + fx_module: GraphModule, + torch_module: torch.nn.Module, + partitioner_config: PartitionerConfig, + ) -> PartitionResult: + """Given the fx module, torch module and partitioner_config, + find the partitions, do the partitions, + and then return a DAG and a new fx module with submodule nodes (partitions) + """ + self.graph_module = fx_module + self.torch_module = torch_module + self.devices = partitioner_config.devices + if len(self.devices) == 0: + raise RuntimeError("No devices") + # Tag the size in bytes to all nodes in the graph_module. + get_size_of_all_nodes(self.graph_module) + # Check if there are op nodes in the fx module + nodes = self.graph_module.graph.nodes + if all(node.op in {"placeholder", "get_attr", "output"} for node in nodes): + raise RuntimeError("No Partition since no operations in the module") + # Calculate total size of the fx module + total_size_of_graph = 0 + for node in nodes: + if node.op == "output": + break + total_size_of_graph += node.size_bytes.total_size + # Find the device with the max mem size + device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes) + # AOT based partition + if partitioner_config.mode == PartitionMode.aot_based: + self.aot_based_partition( + partitioner_config.node_to_partition_mapping, + partitioner_config.partition_to_logical_device_mapping, + ) + # Single partition if the whole module can be fit into one device + elif total_size_of_graph <= device_with_max_mem.available_mem_bytes: + self.find_single_partition( + total_size_of_graph, logical_device_id=device_with_max_mem.logical_id + ) + elif total_size_of_graph > sum(d.available_mem_bytes for d in self.devices): + raise RuntimeError("Devices have no enough memory for the module") + else: + # Sparse nn based partition + if partitioner_config.mode == PartitionMode.sparse_nn: + available_mem_bytes = self.devices[0].available_mem_bytes + if not all( + device.available_mem_bytes == available_mem_bytes + for device in self.devices + ): + raise RuntimeError("All devices must have same memory size!") + # sparse_nn_partition only support same memory size + # TODO: add different size support for sparse_nn_partition + self.sparse_nn_partition(available_mem_bytes) + # Cost aware partition + elif partitioner_config.mode == PartitionMode.cost_aware: + self.cost_aware_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping, + ) + # KL based partition + elif partitioner_config.mode == PartitionMode.kl_based: + self.kl_based_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping, + ) + else: + self.size_based_partition() + + # Saturate host if possible. + if partitioner_config.saturate_host: + self.saturate_host() + + # Partition the graph module based on the partition assignment. + module_with_submodules = self.do_partition() + + # The DAG contains DAGNodes with info of each partition's input nodes, output nodes + # and how partitions are connected. + dag = self.dump_dag(module_with_submodules) + ret = PartitionResult(dag, module_with_submodules) + return ret + + def find_single_partition( + self, total_size_of_graph, logical_device_id: int = 0 + ) -> None: + """Fit the whole fx module into one device""" + partition_0 = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op == "output": + # Skip the output node, but there can + # be nodes after the output in certain cases. + continue + partition_0.nodes.add(node) + partition_0.used_mem_bytes = total_size_of_graph + partition_0.logical_device_ids = [logical_device_id] + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def size_based_partition(self) -> None: + """This method is to partition the fx module based on memory size. + It uses greedy approach. The result may not be the best. + The basic idea is: + Step 1: + Find a device which has enough memory to fit the current node, create a empty partition + with the size of that device. + Then keep adding the following nodes into the partition until the partition is full. + Step 2: + Repeat Step 1 until no device left + Step 3: + If some nodes are left, create a partition for each left node (single node partition). + and then try to map those partitions into logical devices with enough mem left. + """ + + def find_device_based_on_size(node) -> Device: + """Given a node, this function is to find a logical device + that could fit the node. + """ + mem_size_needed = get_extra_size_of(node, set()) + device = Device("", -1, -1) + for d in self.devices: + if ( + d not in occupied_devices + and d.available_mem_bytes >= mem_size_needed + ): + device = d + break + if device.available_mem_bytes < 0: + raise RuntimeError(str(node) + "is too large to fit any device") + occupied_devices.append(device) + return device + + # Track partition and its left mem size + partition_to_left_mem_bytes: Dict[Partition, int] = {} + # Track all the devices that have been used + occupied_devices: List[Device] = [] + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {"call_module", "call_method", "call_function"}: + # Check if there are devices left + if len(self.partitions) <= len(self.devices): + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + # Check if the current partition is the very first partition + if partition.used_mem_bytes == 0: + # Find a device to fit the first node, return available mem size + device = find_device_based_on_size(node) + occupied_devices.append(device) + # Update partition and its left mem size + partition_to_left_mem_bytes[ + partition + ] = device.available_mem_bytes + # Update available mem for the current partition + partition.logical_device_ids.append(device.logical_id) + else: + # The current partition is not the first partition + # Check if the current node can fit into current partition + if ( + partition_to_left_mem_bytes[partition] + < total_size_of_input_nodes + ): + # Check if no device is left + if len(self.partitions) == len(self.devices): + # No device is left + # Put the previous partitions into a list (non_single_node_partitions) + non_single_node_partitions = self.partitions[:] + # Create the first single node partition for the current node + self.create_single_node_partition(node) + continue + # Some devices are still left + # Create a new partition with a mem size that is enough for the current node + device = find_device_based_on_size(node) + partition = self.create_partition() + total_size_of_input_nodes = get_extra_size_of( + node, partition.nodes + ) + partition_to_left_mem_bytes[ + partition + ] = device.available_mem_bytes + partition.logical_device_ids.append(device.logical_id) + partition.add_node(node) + partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes + # Create single node partitions if no device is left + else: + self.create_single_node_partition(node) + reorganize_partitions(self.partitions) + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + # Mapping all partitions into device + found_partition_to_device_mapping = get_device_to_partitions_mapping( + self.partitions, self.devices + ) + if not found_partition_to_device_mapping: + raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping") + return + + def saturate_host(self) -> None: + """Saturate host by assigning replicates to unused devices with enough memory. + It uses a greedy approach to find a next available set of devices to place all split + partitions: For each used device, it searches for an idle device with minimal memory + size that can hold all the partition located on that device; If the search is successful + for all used devices, it then assigns the new devices' logical ID to the corresponding + partition. + """ + ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) = get_device_partition_stats(self.partitions, self.devices) + + assert ( + len(no_device_partitions) == 0 + ), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" + + # Devices that hold partitions + used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0] + # Track replicates of the assigned devices + replicated_device_to_used_device: Dict[Device, Device] = {} + + while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len( + self.devices + ): + # Success flag for this round + success = True + # Devices that have not been assigned + idle_devices = [ + d + for d in self.devices + if d not in used_devices and d not in replicated_device_to_used_device + ] + # Temporary mapping from replicated device to original device + temp_replicate_mapping = {} + + # Find a new device to replicate all partitions on an used device + for used_device in used_devices: + # Idle devices that have enough memory + available_devices = [ + d + for d in idle_devices + if d.available_mem_bytes + >= used_device.available_mem_bytes + - device_to_left_mem_bytes[used_device] + ] + if len(available_devices) == 0: + success = False + break + new_device = min(available_devices, key=lambda d: d.available_mem_bytes) + idle_devices.remove(new_device) + temp_replicate_mapping[new_device] = used_device + + if not success: + break + replicated_device_to_used_device.update(temp_replicate_mapping) + + # Update logical device IDs assigned to the partitions + for ( + replicate_device, + original_device, + ) in replicated_device_to_used_device.items(): + logical_id = replicate_device.logical_id + for partition in device_to_partitions[original_device]: + partition.logical_device_ids.append(logical_id) + for p in self.partitions: + print(p.logical_device_ids) + + def do_partition(self) -> GraphModule: + """Return a new fx module with submodule nodes (partitions).""" + module_with_submodules = split_module( + self.graph_module, + self.torch_module, + lambda node: self.node_to_partition[node], + ) + return module_with_submodules + + def dump_dag(self, module_with_submodules: GraphModule) -> DAG: + """Return the dag structure and the new fx module with submodules.""" + dag = DAG() + for node in module_with_submodules.graph.nodes: + if node.op == "output": + break + if node.op in {"placeholder", "get_attr"}: + continue + if node.target == operator.__getitem__: + continue + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # When a node has two or more output nodes, + # it outputs its result to 'getitem' nodes. + # Those 'getitem' nodes are the output node for this node. + # Otherwise, the output node is this node itself. + if len(node.users) > 1: + output_nodes = list(node.users) + else: + output_nodes = [node] + partition_id = int(node.name.rsplit("_", 1)[-1]) + device_ids = self.partitions[partition_id].logical_device_ids + size_bytes = self.partitions[partition_id].used_mem_bytes + dag.create_node( + node, list(input_nodes), output_nodes, device_ids, size_bytes + ) + return dag + + def create_partition(self) -> Partition: + """Create a partition and append it to self.partitions.""" + partition_id = len(self.partitions) + partition = Partition(partition_id) + self.partitions.append(partition) + return partition + + def create_single_node_partition(self, node): + """Create a partition for a single node""" + partition = self.create_partition() + partition.add_node(node) + return + + def sparse_nn_partition(self, available_mem_bytes: int) -> None: + """This method partition a sparse nn module. + It is size based partition but different from size_based_partition, + it only works when all the devices have same memory size (available_mem_bytes). + In the future, devices with different mem sizes will be supported like size_based_partition. + It first traverse all the nodes and do the partitions based on the same memory size. + If the current partition has no enough memory left for a new op node + (call_module, call_method, call_function), a new partition is created. + When crossing the boundary between non-embedding nodes and embedding nodes, + a new partition is created regardlessly. + For example, if the current node is a non-embedding node but the next node is an + embedding node, a new partition is created for the next node. + After the partition, the partitions are combined as much as possible. + The rule is that a non-embedding partition only + combines with another non-embedding one. + So as the embedding partitions. + """ + + def combine_partitions_based_on_size( + partitions: List[Partition], available_mem_bytes: int + ) -> None: + """Combining small partitions together to keep as less partitions as possible. + Here is an example of the algorithm to do this: + Assume some partitions, we first sort them based on partition used memory size. + [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)] + The available memory is 10. + step 1: self.find_partition_to_combine_based_on_size() + First, mark bfs level for each partition + Second, look the smallest partition, partition_4: 10 - 1 = 9 + It means any partition has a used memory equal or less than 9 could combine this partition + We go from the largest and selection partition_0. + Check the bfs level for two partitions, if the level difference is less than 2, + it can be combined. + step 2: repeat step 1 until no partitions can be combined + """ + find_combination = True + while find_combination: + # Sort partitions based on memory size + sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes) + # Mark bfs level + get_bfs_level_partition(self.partitions) + find_combination, partitions = find_partition_to_combine_based_on_size( + sorted_partitions, available_mem_bytes, partitions + ) + return + + def calculate_mem_bytes_needed(p1, p2): + """Given two partitions, calculate how many mem bytes + are needed if two partitions are combined + """ + nodes = p1.nodes.union(p2.nodes) + mem_bytes_needed = 0 + for node in nodes: + mem_bytes_needed += get_extra_size_of(node, nodes) + return mem_bytes_needed + + def find_partition_to_combine_based_on_size( + sorted_partitions: List[Partition], + available_mem_bytes: int, + partitions: List[Partition], + ) -> Tuple[bool, List[Partition]]: + """step 1 in combine_partition_based_on_size()""" + find_combination = False + smallest_partition = sorted_partitions.pop(0) + for p in sorted_partitions[::-1]: + if abs(smallest_partition.bfs_level - p.bfs_level) <= 1: + # Calculate how many bytes needed if combined + mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition) + if mem_bytes_needed <= available_mem_bytes: + combine_two_partitions(p, smallest_partition, self.partitions) + partitions.remove(smallest_partition) + partitions.remove(p) + partitions.append(self.partitions[-1]) + find_combination = True + break + return find_combination, partitions + + def reset_partition_in_sparse_nn(partition, new_partition=True): + """If crossing the boundary between non-embedding nodes and + embedding nodes, create a new partition + """ + if in_embedding_region: + embedding_partitions.append(partition) + else: + non_embedding_partitions.append(partition) + if new_partition: + partition = self.create_partition() + partition.left_mem_bytes = available_mem_bytes + return partition + return None + + def is_embedding_node(node: Node) -> bool: + """Check if a node is an embedding node""" + if node.op == "call_module": + submodule = self.graph_module + for atom in str(node.target).split("."): + if not hasattr(submodule, atom): + raise RuntimeError( + f"Module {submodule} has no attribute {atom}" + ) + submodule = getattr(submodule, atom) + if "Embedding" in str(submodule): + return True + return False + + # Track embedding partitions and non-embedding partitions separately + embedding_partitions: List[Partition] = [] + non_embedding_partitions: List[Partition] = [] + # A Flag to check the boundary + in_embedding_region: bool = False + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {"call_module", "call_method", "call_function"}: + # Check if crossing the boundary between embedding nodes and non embedding nodes + if is_embedding_node(node) != in_embedding_region: + # Crossing the boundary + # Check if the current partition is an empty partition + if partition.used_mem_bytes != 0: + # The current partition isn't an empty partition. Create a new one. + partition = reset_partition_in_sparse_nn(partition) + in_embedding_region = not in_embedding_region + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if ( + total_size_of_input_nodes + partition.used_mem_bytes + > available_mem_bytes + ): + partition = reset_partition_in_sparse_nn(partition) + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if total_size_of_input_nodes > available_mem_bytes: + raise RuntimeError( + node.target + "is too large to fit into a device" + ) + partition.add_node(node) + reset_partition_in_sparse_nn(partition, new_partition=False) + # Set parents and children for partitions + set_parents_and_children(self.partitions) + # Combining non-embedding partitions + combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes) + # Combining embedding partitions + combine_partitions_based_on_size(embedding_partitions, available_mem_bytes) + total_size_of_non_embedding_partitions = 0 + for partition in non_embedding_partitions: + total_size_of_non_embedding_partitions += partition.used_mem_bytes + # Check if devices are enough for all partitions + if len(embedding_partitions) > len(self.devices): + msg = ( + "Need " + + str(len(embedding_partitions)) + + " devices, but only " + + str(len(self.devices)) + + " provided" + ) + raise RuntimeError(msg) + occupied_devices = [] + for i, partition in enumerate(embedding_partitions): + # Check if all non-embedding partitions can fit into embedding partition devices + if ( + total_size_of_non_embedding_partitions + partition.used_mem_bytes + > available_mem_bytes + ): + raise RuntimeError( + "partition_" + + str(partition.partition_id) + + "(embedding partition) and non embedding partitions can not fit into one device" + ) + else: + # Add logical device to the partition + partition.logical_device_ids = [self.devices[i].logical_id] + occupied_devices.append(self.devices[i].logical_id) + # Add logical devices to the non_embedding_partitions + for partition in non_embedding_partitions: + partition.logical_device_ids = occupied_devices + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def cost_aware_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: Dict[Node, NodeLatency], + ) -> None: + """This method is to partition the fx module based on the cost. + The cost is the total latency of running the whole fx module. + In partitioner_utils.py, the cost model is built. + The cost aware partition algorithm is: + #1. At every beginning, each node is a partition. + Then we map all the partitions to the devices + and calculate the cost + #2. Then try to pre-combine any two of the partitions if the two + partitions can be combined. + (the bfs level is less than 2 or two partitions are connected and + can find partition to device mapping) + See if any partition pair could reduce the current cost. + Choose the pair that shows the minimum cost and then combine them + #3. Repeat #2 until the cost cannot be reduced. + """ + + def try_combining_partitions(p0_index, p1_index, partitions) -> float: + """Given two partitions and a list of partitions, combine these two partitions + and see what is the cost of the modified partition list + """ + p0 = partitions[p0_index] + p1 = partitions[p1_index] + """If two partitions' bfs level are less than 2 or two partitions are connected to each other, + then they can be combined + """ + if ( + (abs(p0.bfs_level - p1.bfs_level) <= 1) + or (p0 in p1.parents) + or p0 in (p1.children) + ): + combine_two_partitions(p0, p1, partitions) + # Check if a circular dependency exists after combining + if check_dependency(partitions[-1]): + return float("inf") + # Check if the modified partition list can be mapped to devices after combination + reset_partition_device(partitions) + found_deivce = get_device_to_partitions_mapping( + partitions, self.devices + ) + if not found_deivce: + return float("inf") + # Calculate the new cost + partition_to_latency_mapping = get_partition_to_latency_mapping( + partitions, node_to_latency_mapping + ) + cost = get_latency_of_partitioned_graph( + partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + return cost + # If two partition can not be combined, the cost is inf + return float("inf") + + def search_combination( + transfer_rate_bytes_per_sec, node_to_latency_mapping + ) -> bool: + """Given transfer rate between partitions and each node's latency, + find two partitions to combine so the cost of the partitions can + be reduced. + The algorithm is : + 1. Go through all the partition pairs and see + if any pair of partitions can be combined. + 2. Calculate the cost after the combination. + 3. Select the minimum cost and combine its corresponding partition pair. + """ + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + if len(self.partitions) == 1: + return False + partition_pair: List[int] = [] + for i in range(len(self.partitions) - 1): + for j in range(i + 1, len(self.partitions)): + # Try to combine the partition pair + # and see the new cost after combination + new_cost = try_combining_partitions(i, j, self.partitions[:]) + if new_cost <= cost: + partition_pair = [i, j] + cost = new_cost + reorganize_partitions(self.partitions) + # If a partition pair is found, combine them + if len(partition_pair) != 0: + p0 = self.partitions[partition_pair[0]] + p1 = self.partitions[partition_pair[1]] + combine_two_partitions(p0, p1, self.partitions) + get_bfs_level_partition(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return len(partition_pair) != 0 + + for node in self.graph_module.graph.nodes: + if node.op not in {"placeholder", "get_attr", "output"}: + self.create_single_node_partition(node) + # Set up parent partitions and children partitions for each partition + set_parents_and_children(self.partitions) + # Get bfs level for each partition + get_bfs_level_partition(self.partitions) + find_combination = True + while find_combination: + # Search for a pair partition to generate the minimum new cost, + # then combine them + find_combination = search_combination( + transfer_rate_bytes_per_sec, node_to_latency_mapping + ) + # Make sure all partitions are set up correctly + reorganize_partitions(self.partitions) + # Set up node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def kl_based_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: Dict[Node, NodeLatency], + ) -> None: + """This function is a cost aware partition based + on Kernighan-Lin algorithm. + First, the graph is partitioned using size_based_partition. + Then, each node is swapped with any other node in a different + partition, and at the same time, the cost is estimated after + the swapping. + For example, we have nodes n0, n1, n2, n3 and n4. + Using size_based_partition, n0 and n1 are in Partition p0. + n2, n3 and n4 in Partition p1. The current cost is estimated. + We first tried using n0 to swap with n2 from the other partition. + Then we see that swapping n0 and n2 shows a lower cost + than the current cost and it is the minimum among other pairs like + (n0, None)(This means moving n0 to Partition without swapping other nodes), + (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost + as the current cost. + Then We repeat this process for all the other nodes until all swapping pairs + are tried. + """ + + def swap_nodes(n0, n1, p0, p1): + # Either n0 or n1 could be None + # That means we simply move the node + # to another partition + if n0 is not None: + p0.remove_node(n0) + p1.add_node(n0) + if n1 is not None: + p0.add_node(n1) + p1.remove_node(n1) + + def try_swap_nodes( + n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ): + cost = float("inf") + swap_nodes(n0, n1, p0, p1) + # Reorganize partitions after swapping + reorganize_partitions(self.partitions) + # Check if there is a circular dependency after swapping + if (not check_dependency(p0)) and (not check_dependency(p1)): + reset_partition_device(self.partitions) + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + # Check if all partitions can be mapped to logical devices after swapping + found_device = get_device_to_partitions_mapping( + self.partitions, self.devices + ) + if not found_device: + cost = float("inf") + else: + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + # Swap back and reset all partitions back to original + swap_nodes(n1, n0, p0, p1) + reorganize_partitions(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return cost + + def swap_node_to_partition( + node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ): + """This function helps to swap one node from partition p0 + with all the nodes in another partition p1 + """ + p1_nodes = list(p1.nodes) + [None] + min_cost = float("inf") + node_pair: List[Node] = [] + for n1 in p1_nodes: + # Ignore the node if it is not a op node + if n1 is not None and n1.op in {"placeholder", "get_attr"}: + continue + # Try swapping node in p0 with n1 in p1 + cost = try_swap_nodes( + node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ) + if cost < min_cost: + node_pair = [node, n1] + min_cost = cost + return cost, node_pair # type: ignore[possibly-undefined] + + # First use size_base_partition + self.size_based_partition() + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + # Calculate the cost of the partitions + cost = get_latency_of_partitioned_graph( + self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec + ) + # Keep tracking the node pair that shows the better cost + node_pair: List[Node] = [] + # Keep tracking the partition pair of node pair + partition_pair: List[Partition] = [] + # Collect all the op nodes from the graph + op_nodes = [] + for n in self.graph_module.graph.nodes: + if n.op not in {"placeholder", "get_attr", "output"}: + op_nodes.append(n) + for node in op_nodes: + # Find which partition the current node belongs + p0_index = self.node_to_partition[node] + p0 = self.partitions[p0_index] + # Go through all the other partitions to swap + # with other nodes from those partitions + for p1_index, _ in enumerate(self.partitions): + if p0_index != p1_index: + p1 = self.partitions[p1_index] + new_cost, new_node_pair = swap_node_to_partition( + node, + p0, + p1, + node_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + # Update the cost + # Track the swapped node pair and their partitions + if new_cost < cost: + cost = new_cost + node_pair = new_node_pair + partition_pair = [p0, p1] + # Do the swapping after trying all the nodes from a partition + if len(node_pair) != 0: + swap_nodes( + node_pair[0], node_pair[1], partition_pair[0], partition_pair[1] + ) + reorganize_partitions(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + reorganize_partitions(self.partitions) + # Mapping the device to the partition + get_device_to_partitions_mapping(self.partitions, self.devices) + return + + def aot_based_partition( + self, node_to_partition_mapping, partition_to_logical_device_mapping + ): + """This function helps to rebuild the partitions given the nodes and its + corresponding partition id + """ + partition_id_to_partition_mapping: Dict[int, Partition] = {} + self.node_to_partition = node_to_partition_mapping + for node in self.node_to_partition: + partition_id = self.node_to_partition[node] + # If the requested partition has not been created, create the partition + if partition_id not in partition_id_to_partition_mapping: + partition = Partition(partition_id) + self.partitions.append(partition) + partition_id_to_partition_mapping[partition_id] = partition + partition.logical_device_ids = partition_to_logical_device_mapping[ + partition_id + ] + else: + partition = partition_id_to_partition_mapping[ + self.node_to_partition[node] + ] + # Add the current node into the partition + partition.add_node(node) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/const_fold.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/const_fold.py new file mode 100644 index 0000000000000000000000000000000000000000..9f395c0ad7bdc2d86f1a24f95facf7dcbbc27a34 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/const_fold.py @@ -0,0 +1,292 @@ +# mypy: allow-untyped-defs +import re +from typing import Callable, Dict, Optional, Set, Union + +import torch.fx +from torch.fx.node import map_arg +from torch.fx.passes.split_module import split_module + + +__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs'] + +class FoldedGraphModule(torch.fx.GraphModule): + """ + FoldedGraphModule is a GraphModule which also contains another + `const_subgraph_module` representing a subgraph which has all const attr + inputs and which can be run once before running the main standard + `graph`. The `const_output_names` are the ordered list names of attrs which + represent what each respective output from the const_subgraph should be set + on which attrs. + """ + + def __init__( + self, + root: torch.nn.Module, + graph: torch.fx.Graph, + const_subgraph: Optional[torch.fx.Graph] = None, + fx_const_folded_attrs_name: Optional[str] = None, + device_for_folded_attrs: str = "cuda", + ): + super().__init__(root, graph) + self.const_subgraph_module = ( + None + if const_subgraph is None + else torch.fx.GraphModule(root, const_subgraph) + ) + self.has_folding_been_run = False + self.fx_const_folded_attrs_name = fx_const_folded_attrs_name + self.device_for_folded_attrs = device_for_folded_attrs + + def __call__(self, *args, **kwargs): + if not self.has_folding_been_run: + self.run_folding() + return super().__call__(*args) + + def run_folding(self): + # If there's no const subgraph module or attr output names to use, return + # early as there is no const folding to perform. + if ( + self.const_subgraph_module is None + or self.fx_const_folded_attrs_name is None + ): + return + + assert not self.has_folding_been_run + self.has_folding_been_run = True + + # Actually run const folding subgraph. Note that single attr const fold + # subgraphs output a single Tensor while multiple outputs are returned as + # Tuple[Tensor,]. + folded_attrs = self.const_subgraph_module() + + def _create_param(i): + return torch.nn.Parameter( + i.detach().clone() + if not isinstance(i, int) + else torch.Tensor([i]).to(device=self.device_for_folded_attrs), + requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False, + ) + + params = ( + torch.nn.ParameterList([_create_param(i) for i in folded_attrs]) + if isinstance(folded_attrs, tuple) + else _create_param(folded_attrs) + ) + setattr(self, self.fx_const_folded_attrs_name, params) + + +def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str): + """ + Given `gm` and some graph module which is called with target name `inline_mod_name`, + this helper will inline all of the nodes from that called graph module into `gm`. + """ + # Fetch the inner graph module that we want to inline inside `gm`. + inline_mod = dict(gm.named_modules())[inline_mod_name] + assert isinstance(inline_mod, torch.fx.GraphModule) + call_mod_node_to_replace = None + for node in gm.graph.nodes: + if node.op == "call_module" and node.target == inline_mod_name: + call_mod_node_to_replace = node + break + assert call_mod_node_to_replace is not None + + # Now actually do the swap. Note that we have to keep track of new nodes that are + # copied into `gm` -- we do this via replacement_mapping. + call_mod_args = call_mod_node_to_replace.args + replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {} + ph_count = 0 + + def replacement_fn(node): + new_node = replacement_mapping[node] + new_node.meta = node.meta.copy() + return new_node + + for inline_node in inline_mod.graph.nodes: + if inline_node.op == "placeholder": + replacement_mapping[inline_node] = call_mod_args[ph_count] + ph_count += 1 + continue + + if inline_node.op == "output": + outputs = inline_node.args[0] + output_replacements = map_arg(outputs, replacement_fn) + call_mod_node_to_replace.replace_all_uses_with(output_replacements) + continue + + with gm.graph.inserting_before(call_mod_node_to_replace): + new_node = gm.graph.node_copy(inline_node, replacement_fn) + replacement_mapping[inline_node] = new_node + + gm.graph.eliminate_dead_code() + + +def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str: + """ + Make sure the name is unique (in a module) and can represents an attr. + """ + # Delete all characters that are illegal in a Python identifier. + name = re.sub("[^0-9a-zA-Z_]+", "_", name) + if name[0].isdigit(): + name = f"_{name}" + # Now make sure it is in fact unique to the module by incrementing suffix value. + while hasattr(mod_traced, name): + match = re.match(r"(.*)_(\d+)$", name) + if match is None: + name = name + "_1" + else: + base, num = match.group(1, 2) + name = f"{base}_{int(num) + 1}" + + return name + + +def split_const_subgraphs( + module: Union[torch.nn.Module, torch.fx.GraphModule], + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, + device_for_folded_attrs: str = "cpu", +) -> FoldedGraphModule: + """ + Looks through `module` for any nodes that have all constant attribute inputs + and separates them out into their own constant subgraph, and returns a + FoldedGraphModule which runs that constant subgraph on the first run to set + attributes on the module prior to running the non-constant portion of the + graph. + """ + if not isinstance(module, torch.fx.GraphModule): + mod_traced = torch.fx.symbolic_trace(module) + else: + mod_traced = module + + # Build up a list of const_nodes, defined as nodes that are themselves + # get_attrs, or have all get_attr or other constant node inputs. + const_nodes: Set[torch.fx.Node] = set() + found_const_folding = False + for node in mod_traced.graph.nodes: + # Skip over placeholders/outputs because they can't be const folded and + # we don't want to add tags to them. + if node.op in {"placeholder", "output"}: + continue + + # If the node itself is constant, or all of its inputs are constant, + # then tag it as constant. + if node.op != "get_attr" and not set(node.all_input_nodes).issubset( + const_nodes + ): + continue + + # If provided skip folding function says to skip, then skip. + if skip_folding_node_fn and skip_folding_node_fn(node): + continue + + # Skip folding side-effectful functions + if node.is_impure(): + continue + + # Must be a constant foldable node at this point. + const_nodes.add(node) + if node.op != "get_attr": + found_const_folding = True + + # If we did not find any const folding then return early without a const fold subgraph. + if not found_const_folding: + return FoldedGraphModule(mod_traced, mod_traced.graph) + + # Partition the module into two: submod_0 for constant folding subgraph, and + # submod_1 for the rest. + def mod_partition(node: torch.fx.Node): + return 0 if node in const_nodes else 1 + + split = split_module(mod_traced, module, mod_partition) + + const_mod_name, non_const_mod_name = "submod_0", "submod_1" + # Safely get submod_1 in case there are no non-const nodes + const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None) + + # The module that a call_module node refers to gets copied to submodules during split. + # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to + # attach inlined modules to `split` as it's the owning module now. + for node in non_const_gm.graph.nodes if non_const_gm else []: + if node.op == "call_module": + setattr(split, node.target, getattr(non_const_gm, node.target)) + for node in const_gm.graph.nodes: + if node.op == "call_module": + setattr(split, node.target, getattr(const_gm, node.target)) + + # split_module currently does not use get_attrs for attrs. Instead it passes + # them in as args from the parent module, which used get_attrs. Here we set + # them as get_attrs inside const_gm, allowing for running folding without + # somehow a priori knowing the attrs that should be passed as args. We can + # unconditionally do this for all placeholders because we know all + # placeholders to const_gm must be constants accessible via get_attr. + call_const_gm_args = None + for node in split.graph.nodes: + if node.op == "call_module": + if node.target == const_mod_name: + call_const_gm_args = node.args + break + assert call_const_gm_args is not None + + # Here we do the actual replacement of placeholders to get_attrs. Note that here we + # set the const_gm.graph into a new root_const_gm with split as the root module, + # because we are fetching attributes directly from the root module, instead of + # fetching them from const_gm. Example: The const_gm must have some format like: + # graph(): + # %inp : [num_users=1] = placeholder[target=const_inp] + # %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {}) + # return add + # We replace that with the following, which does not have any placeholders: + # graph(): + # %inp_1 : [num_users=1] = get_attr[target=const_inp] + # %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {}) + # return add + root_const_gm = torch.fx.GraphModule(split, const_gm.graph) + for node in root_const_gm.graph.nodes: + if node.op == "output": + multiple_outputs = isinstance(node.args[0], tuple) + continue + if node.op != "placeholder": + continue + in_node = next(n for n in call_const_gm_args if n.name == node.target) + assert in_node.op == "get_attr" + with root_const_gm.graph.inserting_before(node): + new_node = root_const_gm.graph.get_attr(in_node.target) + new_node.meta = node.meta.copy() + node.replace_all_uses_with(new_node) + root_const_gm.graph.erase_node(node) + assert "multiple_outputs" in locals() + + # Now find the call to const_gm inside split, and replace it with a getattr to the + # folded tensor(s) that result from constant folding. Note that we don't need to + # worry about whether this is one or more tensors because the original graph + # correctly uses getitem to extract individual tensors if there are multiple folded. + fx_const_folded_attrs_name = get_unique_attr_name_in_module( + mod_traced, "_FX_CONST_FOLDED_ATTRS" + ) + setattr( + split, + fx_const_folded_attrs_name, + torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(), # type: ignore[possibly-undefined] + ) + for node in split.graph.nodes: + if node.op == "call_module" and node.target == const_mod_name: + with node.graph.inserting_before(node): + folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name) + folded_attrs.meta = node.meta.copy() + node.replace_all_uses_with(folded_attrs) + break + + # Finally, inline the non-constant submod (if it exists) into the split submod. + # This is so that the original caller who may have passed in a graph module will + # get back out a graph module whose graph is traced to the same granularity. + if hasattr(split, non_const_mod_name): + _inline_module(split, non_const_mod_name) + + split.graph.eliminate_dead_code() + + return FoldedGraphModule( + split, + split.graph, + root_const_gm.graph, + fx_const_folded_attrs_name, + device_for_folded_attrs, + ) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/debug.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..d3c482319f2ef290ef3e1e64d86d521e9d9324ea --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/debug.py @@ -0,0 +1,32 @@ +# mypy: allow-untyped-defs +import torch.fx as fx + +def set_trace(gm: fx.GraphModule) -> fx.GraphModule: + """ + Sets a breakpoint in `gm`'s generated python code. It drops into pdb when + `gm` gets run. + + Args: + gm: graph module to insert breakpoint. It is then recompiled for it to + take effect. + + Returns: + the `gm` with breakpoint inserted. + """ + def insert_pdb(body): + return ["import pdb; pdb.set_trace()\n", *body] + + with gm.graph.on_generate_code( + make_transformer=lambda cur_transform: ( + # new code transformer to register + lambda body: ( + insert_pdb( + cur_transform(body) if cur_transform + else body + ) + ) + ) + ): + gm.recompile() + + return gm diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/graph_gradual_typechecker.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/graph_gradual_typechecker.py new file mode 100644 index 0000000000000000000000000000000000000000..fb49795a06fac79d50b58a90a47d36bbff724b9e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/graph_gradual_typechecker.py @@ -0,0 +1,916 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +from functools import reduce +import torch +import operator +from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise +from typing import Callable, Dict +from torch.fx.node import Target, Node +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d +from torch.fx.experimental.refinement_types import Equality +import itertools + +from torch.fx.experimental.unification import Var # type: ignore[attr-defined] + +import sympy + +_INFERENCE_RULES: Dict[Target, Callable] = {} +_REFINEMENT_RULES: Dict[Target, Callable] = {} +_RULES: Dict[Target, Callable] = {} + + +def expand_to_tensor_dim(t, n): + """ + Expand a type to the desired tensor dimension if possible + Raise an error otherwise. + - t is the given type + - n is a number of dimensions to expand to + """ + if t == Dyn: + dims = [Dyn] * n + return TensorType(tuple(dims)) + elif isinstance(t, TensorType): + if len(t.__args__) != n: + raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}') + return t + else: + raise TypeError(f'Cannot match the type {t}') + + +def broadcast_types(t1, t2): + """ + Applies broadcasting to both given types such that they + become consistent with eachother and returns two new + resulting types + """ + + # if either type is Dyn, do nothing since the types are already consistent + if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): + return t1, t2 + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + s1 = len(t1.__args__) + s2 = len(t2.__args__) + + new_t1 = list(t1.__args__) + new_t2 = list(t2.__args__) + + # We make the types the same length which is the first requirement + # for consistency + if s1 > s2: + for i in range(s1 - s2): + new_t2.insert(0, 1) + + elif s2 > s1: + for i in range(s2 - s1): + new_t1.insert(0, 1) + + # we replace occurrences of "1" with each tensor with + # the corresponding type from the other tensor + for i, (x, y) in enumerate(zip(new_t1, new_t2)): + if x == 1: + new_t1[i] = y + elif y == 1: + new_t2[i] = x + + # at this point our tensors should be consistent + # and we can apply the element-wise operation and find the right dimension + # for the output of the operation + (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2)) + return (t1, t2) + else: + raise TypeError(f'Cannot broadcast types {t1} and {t2}') + +def register_inference_rule(call_target): + def register(fn): + if call_target in _INFERENCE_RULES: + raise RuntimeError(f'Inference rule already registered for {call_target}!') + _INFERENCE_RULES[call_target] = fn + return fn + return register + +def register_refinement_rule(call_target): + def register(fn): + if call_target in _REFINEMENT_RULES: + raise RuntimeError(f'Refinement rule already registered for {call_target}!') + _REFINEMENT_RULES[call_target] = fn + return fn + return register + +def register_algebraic_expressions_inference_rule(call_target): + def register(fn): + if call_target in _RULES: + raise RuntimeError(f'Rule already registered for {call_target}!') + _RULES[call_target] = fn + return fn + return register + +@register_inference_rule(torch.add) +@register_inference_rule(operator.add) +def add_inference_rule(n: Node): + """ + Apply the addition inference rule. This includes: + - scalar addition + - broadcasting semantics + + Note that we always return the least precise type between + the operands (after applying broadcasting) to be the final type of the operation + + Note that we do not modify the operand types themselves after applying broadcasting + to them. We only use them to calculate the final type + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + t1 = n.args[0].type + t2 = n.args[1].type + + # handle scalar addition + if t1 == int and isinstance(t2, TensorType): + n.type = t2 + return n.type + + # handle scalar addition + elif t2 == int and isinstance(t1, TensorType): + n.type = t1 + return n.type + + # we bring the new types to the point where + # we can check for consistency + # any inconsistency would not have been caused + # by broadcasting at this point + (new_t1, new_t2) = broadcast_types(t1, t2) + + if new_t1 != t1 or new_t2 != t2: + n.meta['broadcast'] = True + n.meta[str(n.args[0])] = new_t1 + n.meta[str(n.args[1])] = new_t2 + + else: + n.meta['broadcast'] = False + + new_t1 = t1 if not n.meta['broadcast'] else new_t1 + new_t2 = t2 if not n.meta['broadcast'] else new_t2 + + # we check for consistency between the new types + if is_consistent(new_t1, new_t2): + # we return the less precise type because + # broadcasting may have happened + # for operands with shape [1,2,Dyn] and [1,2,1] + # we have to assign the node [1,2,Dyn] + if is_more_precise(new_t1, new_t2): + n.type = new_t2 + else: + n.type = new_t1 + return n.type + else: + raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.' + f' Types should match ') + +@register_inference_rule(getattr) +def get_attr_inference_rule(n: Node, traced): + """ + The current getattr rule only handles the shape attribute + Can be extended to other attributes + The most representitive type we have is "Dyn" but the system + can be extended with more types, such as a type to represent shapes + """ + attr_node = n.args[0] + attr_name = n.args[1] + + if attr_name == "shape": + n.type = Dyn + else: + raise TypeError("Not yet implemented") + + # TODO. We leave it like this till we add a type to represent tensor sizes + return n.type + +@register_inference_rule(torch.transpose) +def transpose_inference_rule(n: Node): + """ + We check that dimensions for the transpose operations + are within range of the tensor type of the node + """ + if n.target == torch.transpose: + assert isinstance(n.args[0], Node) + t = n.args[0].type + + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], int) + dim1, dim2 = n.args[1], n.args[2] + + if t == Dyn: + n.type = Dyn + return n.type + + elif isinstance(t, TensorType): + if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__): + new_type = list(t.__args__) + new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1] + final = TensorType(new_type) + n.type = get_greatest_upper_bound(n.type, final) + return n.type + else: + raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') + else: + raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') + + +@register_inference_rule(torch.reshape) +def reshape_inference_rule(n: Node): + """ + Without dynamism, the rule checks that the + product of the elements of the argument tensor + type is equal to the product of the elements + of the required shape. We gradualize this rule + by adding a case to handle fully dynamic input + as well as input where some of the tensor dimensions + are unknown. In this case we check for divisibility + """ + assert isinstance(n.args[0], Node) + t1 = n.args[0].type + + assert isinstance(n.args[1], list) + t2 = n.args[1] + t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) + + # if we do not know the original tensor dimension, + # we return the required dimension + if t1 == Dyn: + n.type = t2_type + return t2_type + + # if any of the dimensions are unknown, + # we check for divisibility + elif isinstance(t1, TensorType): + assert isinstance(t1, TensorType) + a = [e if e != Dyn else 1 for e in t1.__args__] + p1 = reduce(operator.mul, a) + p2 = reduce(operator.mul, t2) + if p1 % p2 == 0 or p2 % p1 == 0: + n.type = t2_type + return t2_type + else: + raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') + else: + raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') + +@register_inference_rule(BatchNorm2d) +def bn2d_inference_rule(n: Node, module_instance): + """ + Given a BatchNorm2D instance and a node check the following conditions: + - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4) + - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') + - t is consistent with t' + - x_2 is consistent with the module's num_features + - x_2' is consistent with the module's num_features + output type: the more precise type of t and t' + """ + assert isinstance(n.args[0], Node) + n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) + arg_type = n.args[0].type + n.type = expand_to_tensor_dim(n.type, 4) + + # we check the conditions on the incoming argument + # and any existing annotation + # we also check for consistency between both annotations + if is_consistent(arg_type.__args__[1], module_instance.num_features) and \ + is_consistent(n.type.__args__[1], module_instance.num_features) and \ + is_consistent(arg_type, n.type): + + # we choose the more precise type + # to be the node type + # so if an incoming argument has more type information + # we set this node's type to be the argument type + n.type = get_greatest_upper_bound(arg_type, n.type) + return n.type + else: + raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}') + + +def calculate_out_dimension(d_in, module_instance, index): + """ + For calculating h_in and w_out according to the conv2D documentation + """ + padding = (module_instance.padding, module_instance.padding) \ + if isinstance(module_instance.padding, int) else module_instance.padding + kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \ + if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size + stride = (module_instance.stride, module_instance.stride) \ + if isinstance(module_instance.stride, int) else module_instance.stride + dilation = (module_instance.dilation, module_instance.dilation) \ + if isinstance(module_instance.dilation, int) else module_instance.dilation + + DIMENSION_TYPES = (int, sympy.Symbol) + + if d_in == Dyn: + return Dyn + + elif isinstance(d_in, DIMENSION_TYPES): + n = d_in + 2 * padding[index] - \ + dilation[index] * \ + (kernel_size[index] - 1) - 1 + + return (n // stride[0]) + 1 + + else: + raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}') + + +def get_greatest_upper_bound(type1, type2): + """ + Get the most precise type that's consistent with the given types + """ + if type1 == Dyn: + return type2 + elif type2 == Dyn: + return type1 + elif isinstance(type1, TensorType) and isinstance(type2, TensorType): + if not is_consistent(type1, type2): + raise TypeError(f'Inconsistent types {type1}, {type2}') + gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)] + return TensorType(tuple(gub)) + + +@register_inference_rule(Conv2d) +def conv2d_inference_rule(n: Node, module_instance): + """ + Given a Conv2D instance and a node check the following conditions: + - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W) + - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') + - x_2 is consistent with the module's in_channels + - let o = (x_1, out_channels, H_out, W_out) + then the output is the greatest upper bound of o and the existing node type t'. + """ + assert isinstance(n.args[0], Node) + n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) + arg_type = n.args[0].type + curr_node_type = expand_to_tensor_dim(n.type, 4) + + if is_consistent(arg_type.__args__[1], module_instance.in_channels): + w_in = arg_type.__args__[3] + h_in = arg_type.__args__[2] + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out)) + gub = get_greatest_upper_bound(new_type, curr_node_type) + n.type = gub + return n.type + else: + raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}') + + +@register_inference_rule(torch.nn.ReLU) +def relu_inference_rule(n: Node, module_instance): + """ + Input and output shapes should be equal. + """ + assert isinstance(n.args[0], Node) + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + + if isinstance(n.args[0].type, TensorType): + n.type = get_greatest_upper_bound(n.args[0].type, n.type) + return n.type + + +def maxpool2d_check(typ, module_instance): + """ + Applies the maxpool2d shape information to the input + this affects the last two dimensions + """ + new_type_list = list(typ.__args__) + if len(new_type_list) == 4 or len(new_type_list) == 3: + w_in = new_type_list[-1] + h_in = new_type_list[-2] + + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + + new_type_list[-1] = w_out + new_type_list[-2] = h_out + return TensorType(tuple(new_type_list)) + + else: + raise TypeError(f'Wrong size {typ} for {module_instance}') + + +@register_inference_rule(torch.nn.MaxPool2d) +def maxpool2d_inference_rule(n: Node, module_instance): + """ + Given a MaxPool2D instance and a node check the following conditions: + - Input size matches size 3 or 4 + - Current node type is consistent with the output type we will calculate + - Input size matches output size and the last two dimensions of the output + are w_out and h_out. The remaining dimensions are the same as the input + - Our final result is the greatest upper bound of the output we calculate + and the current node type. + """ + assert isinstance(n.args[0], Node) + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output = maxpool2d_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(output, n.type) + return n.type + + + +def linear_check(tensor_type, module_instance): + """ + Checks that an input tensor type satisfies the conditions for linear operation + and returns the output type based on in and out features given by module_instance + """ + if len(tensor_type.__args__) >= 2: + if is_consistent(module_instance.in_features, tensor_type.__args__[-1]): + new_type_args = list(tensor_type.__args__) + new_type_args[-1] = module_instance.out_features + return TensorType(tuple(new_type_args)) + else: + raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}') + else: + raise TypeError(f'Type {tensor_type} must have rank 2 or more.') + + +@register_inference_rule(torch.nn.Linear) +def linear_inference_rule(n: Node, module_instance): + """ + Applies the shape information to the input then gets the greatest upper bound + of the resulting type and the existing type + """ + assert isinstance(n.args[0], Node) + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output_type = linear_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(output_type, n.type) + return n.type + + +def adaptiveavgpool2d_check(tensor_type, module_instance): + output_size = module_instance.output_size + if isinstance(output_size, int): + output_size = [output_size, output_size] + elif isinstance(output_size, tuple): + output_size = list(output_size) + if output_size[0] is None: + output_size[0] = output_size[1] + if output_size[1] is None: + output_size[1] = output_size[0] + + new_type_list = list(tensor_type.__args__) + + if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3: + new_type_list[-1] = output_size[1] + new_type_list[-2] = output_size[0] + + return TensorType(tuple(new_type_list)) + + else: + raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}') + +@register_inference_rule(torch.nn.AdaptiveAvgPool2d) +def adaptiveavgpool2d_inference_rule(n: Node, module_instance): + """ + The input and output sizes should be the same except for the last + two dimensions taken from the input, which represent width and height + """ + assert isinstance(n.args[0], Node) + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(n.type, output_type) + return n.type + +def flatten_check(tensor_type, start_dim, end_dim): + l = len(tensor_type.__args__) + + start_dim = l if start_dim == -1 else abs(start_dim) + end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 + + if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim: + my_args = list(tensor_type.__args__) + lhs = my_args[0:start_dim] + rhs = my_args[end_dim:] + mid = my_args[start_dim:end_dim] + if Dyn in mid: + mid = [Dyn] + else: + mid = [reduce(operator.mul, my_args[start_dim:end_dim])] + new_type_list = lhs + mid + rhs + return TensorType(tuple(new_type_list)) + else: + raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}') + +@register_inference_rule(torch.flatten) +def flatten_inference_rule(n: Node): + """ + Applies the flatten shape information to the input then gets the + greatest upper bound of the resulting type and the existing type + """ + assert isinstance(n.args[0], Node) + + # set the default start and end dims + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + + if isinstance(n.args[0].type, TensorType): + output_type = flatten_check(n.args[0].type, start_dim, end_dim) + n.type = get_greatest_upper_bound(output_type , n.type) + + return n.type + +class GraphTypeChecker: + def __init__(self, env, traced): + self.env = env + self.traced = traced + + def type_check(self): + """ + A gradual type checker for graphs + Effect: every node's field type will be + populated with a type after type-checking is done + """ + graph = self.traced.graph + + # type check every node with gradual type rules + # if any node does not type check return false + for n in graph.nodes: + self.type_check_node(n) + return True + + def type_check_node(self, n: Node): + """ + Type check a given fx node. + Current operations: + - Reshape + - Transpose + - Add + - Relu + - conv2d + - batchnorm2d + - flatten + - maxpool2d + - adaptiveavgpool2d + - linear + """ + if n.type is None: + n.type = Dyn + + if n.op == 'placeholder': + return n.type + + elif n.op == 'get_attr': + t = get_parameter(self.traced, n.target) # type: ignore[arg-type] + if isinstance(t.data, torch.Tensor): + n.type = TensorType(t.data.shape) + return n.type + + elif n.op == 'call_function': + if n.target == getattr: + assert getattr in _INFERENCE_RULES + return _INFERENCE_RULES[n.target](n, self.traced) + + elif n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target](n) + else: + raise RuntimeError(f'No inference rule registered for target {n.target}!') + + elif n.op == 'call_module': + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _INFERENCE_RULES: + return _INFERENCE_RULES[type(module_instance)](n, module_instance) + else: + raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') + + elif n.op == 'output': + def get_node_type(a): + return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + raise NotImplementedError(f"Method {n.op} not yet implemented") + + +@register_refinement_rule(Conv2d) +def conv_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + res = [Equality(arg_type.__args__[0], n.type.__args__[0])] + return res + + +@register_refinement_rule(torch.nn.Linear) +def linear_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + res = [Equality(arg_type.__args__[0], n.type.__args__[0])] + return res + +@register_refinement_rule(BatchNorm2d) +@register_refinement_rule(torch.nn.ReLU) +def all_eq(n: Node): + """ + For operations where the input shape is equal to the output shape + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + args1 = arg_type.__args__ + args2 = n.type.__args__ + res = [Equality(args1[i], args2[i]) for i in range(len(args1))] + return res + + +@register_refinement_rule(torch.nn.AdaptiveAvgPool2d) +@register_refinement_rule(torch.nn.MaxPool2d) +def first_two_eq(n: Node): + """ + For operations where the first two dimensions of the input and output shape + are equal + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + args1 = arg_type.__args__ + args2 = n.type.__args__ + res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])] + return res + + +@register_refinement_rule(torch.add) +@register_refinement_rule(operator.add) +def element_wise_eq(n: Node): + """ + For element-wise operations and handles broadcasting. + Note that after applying broadcasting to the arguments + we are able to determine if certain dimensions have not been broadcast + if they are symbolicallu equal. + + in this case, we can establish equality between those dimensions and the + corresponding output dimensions. + + Note that it takes two iterations for this result. One iteration to establish + equality between certain dimensions of the operands (requiring the whole solver + including unification) and another iteration to establish equality between the operands + and the resulting type, requiring another round of constraint generation and unificaiton. + """ + res = [] + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + arg_type1 = n.args[0].type + arg_type2 = n.args[1].type + if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType): + args1, args2 = broadcast_types(arg_type1, arg_type2) + # by this point, we know that args1 and args2 are the same size. + a1 = args1.__args__ + a2 = args2.__args__ + a3 = n.type.__args__ + + # we would be here in the second iteration where we establish equality + # between operand type dimensions and the resulting type dimensions + r = [] + for x, y, z in zip(a1, a2, a3): + if x == y: + r.append(Equality(x, z)) + res = r + return res + + +@register_refinement_rule(torch.flatten) +def flatten_refinement_rule(n: Node): + """ + Generates equality constraints between the dimensions of the input and output + that will not be involved in the flatten operation + """ + assert isinstance(n.args[0], Node) + + eq_const = [] + + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType): + l = len(n.type.__args__) + arg_type = n.args[0].type + start_dim = l if start_dim == -1 else start_dim + end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 + + for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]): + eq_const.append(Equality(t1, t2)) + + for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]): + eq_const.append(Equality(t1, t2)) + return eq_const + + +@register_algebraic_expressions_inference_rule(Conv2d) +def conv_rule(n: Node, module_instance): + """ + Represents the outout in terms of an algrbraic expression w.r.t + the input when possible + """ + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + w_in = arg_type.__args__[3] + h_in = arg_type.__args__[2] + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out)) + n.type = new_type + return new_type + +class Refine: + """ + Symbolic shape inference. + Generates constraints over type variables. + Currently all constraints are equality constraints. + """ + def __init__(self, traced): + self.constraints = [] + self.traced = traced + self.symbol_iter = itertools.count(start=0, step=1) + + def refine(self): + """ + Generates constraints for + every node in the graph based on + the operation. + """ + graph = self.traced.graph + for n in graph.nodes: + self.refine_node(n) + return True + + def symbolic_relations(self): + """ + Infers algebraic relations + """ + graph = self.traced.graph + for n in graph.nodes: + self.infer_symbolic_relations(n) + return True + + def replace_dyn_with_fresh_var(self, typ): + """ + Replace all unknown types with fresh type variables. + """ + if typ == Dyn: + new_symbol = Var(next(self.symbol_iter)) + return new_symbol + elif isinstance(typ, TensorType): + new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.replace_dyn_with_fresh_var(t) for t in typ] + elif isinstance(typ, tuple): + return (self.replace_dyn_with_fresh_var(t) for t in typ) + else: + return typ + + + def convert_to_sympy_symbols(self, typ): + """ + Replace all unknown types with fresh type variables. + """ + if isinstance(typ, Var): + return sympy.symbols(str(typ)) + elif isinstance(typ, TensorType): + new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.convert_to_sympy_symbols(t) for t in typ] + elif isinstance(typ, tuple): + return (self.convert_to_sympy_symbols(t) for t in typ) + else: + return typ + + def refine_node(self, n: Node): + """ + Returns a list of equality constraints for + call_module and call_function nodes. + Models the relation between input and output dimensions + using constraints in case they are both tensors. + All operations used in resnet50 are defined. + """ + if n.type is None: + n.type = Dyn + + n.type = self.replace_dyn_with_fresh_var(n.type) + + if n.op == 'call_function': + if n.target in _REFINEMENT_RULES: + self.constraints += _REFINEMENT_RULES[n.target](n) + else: + pass + + if n.op == 'call_module': + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _REFINEMENT_RULES: + self.constraints += _REFINEMENT_RULES[type(module_instance)](n) + else: + pass + + if n.op == 'output': + def get_node_type(a): + return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + pass + + def infer_symbolic_relations(self, n: Node): + n.type = self.convert_to_sympy_symbols(n.type) + if n.op == 'call_function': + if n.target in _RULES: + return _RULES[n.target](n) + else: + pass + + if n.op == 'call_module': + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _RULES: + return _RULES[type(module_instance)](n, module_instance) + else: + pass + + if n.op == 'output': + def get_node_type(a): + return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + pass + +def get_parameter(traced, target: str): + """ + Returns the parameter given by ``target`` if it exists, + otherwise throws an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the Parameter + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + + mod: torch.nn.Module = traced.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`") + + param: torch.nn.Parameter = getattr(mod, param_name) + + return param diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/merge_matmul.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/merge_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..c1a634b2602a06bcadbd039815e32c9ae226cf65 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/merge_matmul.py @@ -0,0 +1,172 @@ +# mypy: allow-untyped-defs +import torch + +from torch.fx.node import Node +from torch.fx._symbolic_trace import symbolic_trace +from torch.fx.passes.tools_common import legalize_graph +import itertools +import operator + +from typing import Dict, List, Tuple + + +def split_result_tensors( + result: torch.Tensor, inputs: List[torch.Tensor] +) -> Tuple[torch.Tensor, ...]: + """ + A free function for use in the merge_matmul graph transformation below that + splits the output from a merged matmul into the individual results for each + input tensor. + + Arguments: + result: The merged matmul result tensor. + inputs: The list of inputs that were merged into one for the matmul. + + Returns: + List of matmul results for each input tensor. + """ + # When fx tracer is running, x.shape[0] will be torch.fx.Attribute but we + # need an int even when tracing + if isinstance(result, torch.fx.Proxy): + splits = [0] * len(inputs) + else: + splits = [x.shape[0] for x in inputs] + + return torch.split(result, splits) + + +def may_depend_on(a: Node, b: Node, search_depth: int = 6): + """ + Determine if one node depends on another in a torch.fx.Graph. + + Arguments: + a: The node that may have a dependency on b. + b: The node that a may have a dependency on. + search_depth: In the case of an indirect dependency, this function + searches upto this many nodes away in search of a + data dependency. If none is found, the function + makes the conservative assumption that there is a + dependency. + + Returns: + True if a may depend on b, False if it definitely does not. + """ + # Equivalence is defined as dependence. + if a == b: + return True + + # If a has no inputs, it cannot depend on b. + if len(a.all_input_nodes) == 0: + return False + + # If the search depth has been exhausted and no conclusion has been + # reached, assume that there is a data dependency. + if search_depth == 0: + return True + + # Recursively check all inputs of a. + for inp in a.all_input_nodes: + if may_depend_on(inp, b, search_depth - 1): + return True + + return False + + +def are_nodes_independent(nodes: List[Node]): + """ + Check if all of the given nodes are pairwise-data independent. + + Arguments: + nodes: The nodes to check for data dependencies. + + Returns: + True if any pair in nodes has a data dependency. + """ + # For each pair in nodes: + for i, j in itertools.combinations(nodes, 2): + if may_depend_on(i, j) or may_depend_on(j, i): + return False + + return True + + +def merge_matmul(in_mod: torch.nn.Module): + """ + A graph transformation that merges matrix multiplication operations that share the same right-hand + side operand into one large matrix multiplication. + ____ _________ _________ + ---- | | | | M| A * C | + M| A | T| B | * K| C | = |---------| + ---- , | | | | T| B * C | + K ---- --------- --------- + K R R + """ + gm = symbolic_trace(in_mod) + + rhs_users: Dict[Node, List[Node]] = {} + lhs_users: Dict[Node, List[Node]] = {} + + # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to + # the matmul of which they are the LHS/RHS. + for node in gm.graph.nodes: + if node.op != "call_function" or node.target is not torch.matmul: + continue + + lhs, rhs = node.args + + # TODO: Properly handle aliasing caused by get_attr. For now, + # use the attribute name as the operand if the node is a + # get_attr. + lhs = lhs.target if lhs.op == "get_attr" else lhs + rhs = rhs.target if rhs.op == "get_attr" else rhs + + lhs_users.setdefault(lhs, []).append(node) + rhs_users.setdefault(rhs, []).append(node) + + for rhs, mms in rhs_users.items(): + # There must be at least matmuls for a merge to make sense. + if len(mms) < 2: + continue + + # All matmuls must not depend on each other directly or indirectly + # in order for the merge to be possible. + if not are_nodes_independent(mms): + continue + + lhs_vals = [mm.args[0] for mm in mms] + + # Merge the matmul. + # Collect a list of LHS operands and the single RHS operand. + lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals] + rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs + + # Concatenate all the LHS operands. + merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {}) + + # Multiply the concatenated LHS operands with the one RHS. This will produce + # the same results as all the individual matmuls involving rhs in the original graph, + # but they will all be concatenated together. + merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {}) + + # Split the result of the merged matmul using the shapes of the LHS operands + # to ascertain how large each chunk should be. + merge_mm_split = gm.graph.call_function( + split_result_tensors, (merge_mm, lhs), {} + ) + merge_mm_res = [ + gm.graph.call_function(operator.getitem, (merge_mm_split, out), {}) + for out in range(len(lhs)) + ] + + # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul. + for old, new in zip(mms, merge_mm_res): + old.replace_all_uses_with(new) + gm.graph.erase_node(old) + + # All of the new nodes created above were inserted at the end, so we need to sort + # the nodes topologically to make sure all definitions precede uses. + legalize_graph(gm) + + gm.recompile() + gm.graph.lint() + return gm diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/meta_tracer.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/meta_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd2c9b11ab76b15ae310a6b4ea9e89cf26cf3670 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/meta_tracer.py @@ -0,0 +1,269 @@ +# mypy: allow-untyped-defs +import torch +import torch.fx +import warnings +import functools +import builtins + +from typing import Any, Callable, Dict, Optional, Union + +def embedding_override(self, input): + return torch.empty(*input.shape, self.weight.shape[-1], device='meta') + + +def nn_layernorm_override(self, input): + return input + + +def torch_relu_override(x): + return x + + +def torch_nn_relu_override(self, x): + return x + + +def functional_relu_override(x, inplace=False): + assert not inplace, 'dont support inplace functional.relu for metatensor analysis' + return x + + +def torch_where_override(condition, x, y): + # torch.where returns the broadcasted tensor of condition, x, and y, + # so hack it by using addition + return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta') + + +def torch_abs_override(input, *, out=None): + assert out is None, 'Dont support in-place abs for MetaTensor analysis' + return input + +manual_meta_overrides : Dict[Callable, Callable] = { + torch.nn.Embedding: embedding_override, + torch.nn.LayerNorm: nn_layernorm_override, + torch.relu: torch_relu_override, + torch.nn.functional.relu: functional_relu_override, + torch.nn.ReLU: torch_nn_relu_override, + torch.where: torch_where_override, + torch.abs: torch_abs_override, +} + +def gen_constructor_wrapper(target): + @functools.wraps(target) + def wrapper(*args, **kwargs): + proxy = None + + def check_has_proxy(v): + if isinstance(v, torch.fx.Proxy): + nonlocal proxy + proxy = v + torch.fx.node.map_aggregate(args, check_has_proxy) + torch.fx.node.map_aggregate(kwargs, check_has_proxy) + + if proxy is not None: + return proxy.tracer.create_proxy('call_function', target, args, kwargs) + else: + return target(*args, **kwargs) + return wrapper, target + +class MetaProxy(torch.fx.Proxy): + def install_tensor_meta(self, tensor_meta): + self._tensor_meta = tensor_meta + + def size(self, dim=None): + if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + return self._tensor_meta.size(*[dim] if dim else []) + return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {}) + + def dim(self): + if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + return self._tensor_meta.dim() + return self.tracer.create_proxy('call_method', 'dim', (self,), {}) + + @property + def shape(self): + if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + return self._tensor_meta.shape + return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {}) + + @property + def dtype(self): + if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + return self._tensor_meta.dtype + return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {}) + + @property + def device(self): + # Hack so we can track when devices are used. During meta-tensor propagation, + # replace these values with a constant 'meta' + return MetaDeviceAttribute(self, 'device') + + def __getattr__(self, k): + if k == '_tensor_meta': + return self.__getattribute__(k) + # note: not added to the graph yet, if this is a method call + # we peephole optimize to the method invocation + return MetaAttribute(self, k) + +class MetaAttribute(MetaProxy): + def __init__(self, root, attr: str): + + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + +class MetaDeviceAttribute(MetaAttribute): + pass + +def proxys_to_metas(v): + if isinstance(v, MetaDeviceAttribute): + return 'meta' + if isinstance(v, torch.fx.Proxy): + assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}' + assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta' + return v._tensor_meta + return v + +class MetaTracer(torch.fx.Tracer): + allow_insert_stateless_mods : bool = True + + _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye'] + + def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None): + rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) + + if kind == 'placeholder' and target in self.meta_args: + rv.install_tensor_meta(self.meta_args[target]) + return rv + + if target in self.orig_fns: + # NOTE: tensor constructors in PyTorch define the `device` argument as + # *kwargs-only*. That is why this works. If you add methods to + # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, + # this will break and you will likely see issues where we cannot infer + # the size of the output. + if 'device' in kwargs: + kwargs['device'] = 'meta' + + try: + args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas) + kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas) + + if kind == 'call_function': + meta_target = manual_meta_overrides.get(target, target) + meta_out = meta_target(*args_metas, **kwargs_metas) + elif kind == 'call_method': + meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) # type: ignore[index] + elif kind == 'call_module': + assert hasattr(self, 'orig_forward') + self._disable_module_getattr = True + try: + mod = self.root.get_submodule(target) + mod_type = type(mod) + if mod_type in manual_meta_overrides: + meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) # type: ignore[misc, arg-type] + else: + meta_out = self.orig_forward(*args_metas, **kwargs_metas) + finally: + self._disable_module_getattr = False + elif kind == 'get_attr': + self._disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split('.') + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + assert isinstance(attr_itr, torch.Tensor) + meta_out = attr_itr.to(device='meta') + finally: + self._disable_module_getattr = False + else: + return rv + + # TODO + assert isinstance(rv, torch.fx.Proxy), 'Dont support composite output yet' + rv.install_tensor_meta(meta_out) + except Exception as e: + warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}') + + return rv + + def getattr(self, attr, attr_val, parameter_proxy_cache): + if getattr(self, '_disable_module_getattr', False): + return attr_val + else: + return super().getattr(attr, attr_val, parameter_proxy_cache) + + def call_module(self, m, forward, args, kwargs): + self.orig_forward = forward + return super().call_module(m, forward, args, kwargs) + + def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str: + """ + Helper method which tries to insert a module that was not declared as submodule. + """ + idx = 0 + mod_name = mod.__class__.__name__.lower() + path = f"{mod_name}_{idx}" + while hasattr(self.root, path): + path = f"{mod_name}_{idx}" + idx += 1 + + self.root.add_module(path, mod) + return path + + def path_of_module(self, mod: torch.nn.Module) -> str: + try: + return super().path_of_module(mod) + except NameError as e: + if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0: + path = self._insert_module_as_submodule(mod) + self.prev_module = path + return path + raise + + def proxy(self, node): + return MetaProxy(node, self) + + def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override] + assert isinstance(meta_args, dict) + self.meta_args = meta_args + + self.patched_torch_methods = { + target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH + } + self.orig_fns = set() + + for name, (wrapper, orig) in self.patched_torch_methods.items(): + setattr(torch, name, wrapper) + self.orig_fns.add(orig) + + try: + graph = super().trace(root, concrete_args) + graph._tracer_extras = {'meta_args': meta_args} + return graph + finally: + for name, (_, orig) in self.patched_torch_methods.items(): + setattr(torch, name, orig) + + +def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], + meta_args : Optional[Dict[str, torch.Tensor]] = None, + concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule: + tracer = MetaTracer() + graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type] + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + gm = torch.fx.GraphModule(tracer.root, graph, name) + return gm diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/normalize.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..30b076a72bee2e4bcb01bd25ed4f7d73b757edb5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/normalize.py @@ -0,0 +1,163 @@ +# mypy: allow-untyped-defs +import operator +from typing import Any, Callable, Dict, Tuple, Optional + +import torch +import torch.fx +import torch.fx as fx +from torch.fx import Transformer, Proxy +from torch.fx.node import Argument, Target, Node, map_aggregate +from torch.fx.operator_schemas import ( + normalize_module, + normalize_function, + create_type_hint, +) + +from .schema_type_annotation import AnnotateTypesWithSchema + + +class NormalizeArgs(Transformer): + """ + Normalize arguments to Python targets. This means that + `args/kwargs` will be matched up to the module/functional's + signature and rewritten to exclusively kwargs in positional order + if `normalize_to_only_use_kwargs` is true. Also populates default + values. Does not support positional-only parameters or varargs + parameters (*args, **kwargs). + + If the nodes have 'type' metadata, it will use it to disambiguate + overloads. Otherwise, it will throw an error. + + Example usage: + m = torchvision.models.resnet18() + traced = torch.fx.symbolic_trace(m) + traced = NormalizeArgs(traced).transform() + """ + + def __init__( + self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True + ): + super().__init__(module) + self.node_map: Dict[Proxy, Node] = {} + self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs + + def run_node(self, n: Node) -> Any: + args, kwargs = self.fetch_args_kwargs_from_env(n) + + def get_type(arg): + if isinstance(arg, fx.Node): + return n.meta["type"] if "type" in n.meta else None + return type(arg) + + arg_types = map_aggregate(n.args, get_type) + assert isinstance(arg_types, tuple) + arg_types = tuple([create_type_hint(i) for i in arg_types]) + kwarg_types = {k: get_type(v) for k, v in kwargs.items()} + if n.op == "call_function": + out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types) + else: + out = super().run_node(n) + if n.op != "output": + self.node_map[out] = n + out.node.meta = n.meta + out.node.type = n.type + return out + + def call_function( + self, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], + arg_types: Optional[Tuple[Any, ...]] = None, + kwarg_types: Optional[Dict[str, Any]] = None, + ): + assert callable(target) + new_args_and_kwargs = normalize_function( + target, + args, # type: ignore[arg-type] + kwargs, + arg_types, # type: ignore[arg-type] + kwarg_types, + self.normalize_to_only_use_kwargs, + ) + if new_args_and_kwargs: + new_args, new_kwargs = new_args_and_kwargs + return self.tracer.create_proxy( + "call_function", target, new_args, new_kwargs + ) + else: + return super().call_function(target, args, kwargs) + + def call_module( + self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ): + assert isinstance(target, str) + new_args_and_kwargs = normalize_module( + self.module, + target, + args, # type: ignore[arg-type] + kwargs, + self.normalize_to_only_use_kwargs, + ) + if new_args_and_kwargs: + new_args, new_kwargs = new_args_and_kwargs + return super().call_module(target, new_args, new_kwargs) + else: + return super().call_module(target, args, kwargs) + + +class NormalizeOperators(AnnotateTypesWithSchema): + """ + Normalize callsites that are different ways of "spelling" the same + invocation into a single, canonical call. Currently supports: + + 1. Normalize operators (e.g. operator.add) to the `torch` ops they + ultimately invoke (e.g. torch.add) when it is possible to statically + reason that + + Example usage: + + m = torchvision.models.resnet18() + + traced = torch.fx.symbolic_trace(m) + + traced = NormalizeOperators(traced).transform() + """ + + binary_magic_method_remap: Dict[ + Callable[[Any, Any], Any], Callable[[Any, Any], Any] + ] = { + torch.add: operator.add, + torch.mul: operator.mul, + torch.sub: operator.sub, + torch.div: operator.truediv, + torch.floor_divide: operator.floordiv, + torch.remainder: operator.mod, + torch.eq: operator.eq, + torch.ne: operator.ne, + torch.lt: operator.lt, + torch.le: operator.le, + torch.gt: operator.gt, + torch.ge: operator.ge, + } + + def call_function( + self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ): + # Normalize operators according to the magic methods implemented on tensors here: + # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950 + + assert callable(target) + + if target in self.binary_magic_method_remap: + if len(args) != 2: + return super().call_function(target, args, kwargs) + lhs, rhs = args + + return super().call_function( + target=self.binary_magic_method_remap[target], + args=(lhs, rhs), + kwargs={}, + ) + + return super().call_function(target, args, kwargs) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/optimization.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..8362c0cb88ac167c4c7eb5409128489b9c3d75e9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/optimization.py @@ -0,0 +1,409 @@ +# mypy: allow-untyped-defs +import torch.fx as fx +from torch.fx.node import Argument, Target +from torch.nn.utils.fusion import fuse_conv_bn_eval +from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.fx.passes.shape_prop import ShapeProp +import copy +from collections import defaultdict +import torch.utils.mkldnn as th_mkldnn +import operator +import time +import logging +from enum import Enum + +def _parent_name(target : str) -> Tuple[str, str]: + """ + Splits a qualname into parent path and last atom. + For example, `foo.bar.baz` -> (`foo.bar`, `baz`) + """ + *parent, name = target.rsplit('.', 1) + return parent[0] if parent else '', name + +# Works for length 2 patterns with 2 modules +def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]): + if len(node.args) == 0: + return False + nodes: Tuple[Any, fx.Node] = (node.args[0], node) + for expected_type, current_node in zip(pattern, nodes): + if not isinstance(current_node, fx.Node): + return False + if current_node.op != 'call_module': + return False + if not isinstance(current_node.target, str): + return False + if current_node.target not in modules: + return False + if type(modules[current_node.target]) is not expected_type: + return False + return True + + +def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): + assert isinstance(node.target, str) + parent_name, name = _parent_name(node.target) + modules[node.target] = new_module + setattr(modules[parent_name], name, new_module) + +def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module: + """ + Fuses convolution/BN layers for inference purposes. Will deepcopy your + model by default, but can modify the model inplace as well. + """ + patterns = [(nn.Conv1d, nn.BatchNorm1d), + (nn.Conv2d, nn.BatchNorm2d), + (nn.Conv3d, nn.BatchNorm3d)] + if not inplace: + model = copy.deepcopy(model) + if not no_trace or not isinstance(model, torch.fx.GraphModule): + fx_model = fx.symbolic_trace(model) + else: + fx_model = model + modules = dict(fx_model.named_modules()) + new_graph = copy.deepcopy(fx_model.graph) + + for pattern in patterns: + for node in new_graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: # Output of conv is used by other nodes + continue + conv = modules[node.args[0].target] + bn = modules[node.target] + if not bn.track_running_stats: + continue + fused_conv = fuse_conv_bn_eval(conv, bn) + replace_node_module(node.args[0], modules, fused_conv) + node.replace_all_uses_with(node.args[0]) + new_graph.erase_node(node) + return fx.GraphModule(fx_model, new_graph) + +def remove_dropout(model: nn.Module) -> nn.Module: + """ + Removes all dropout layers from the module. + """ + fx_model = fx.symbolic_trace(model) + + class DropoutRemover(torch.fx.Transformer): + def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + if isinstance(self.submodules[target], nn.Dropout): + assert len(args) == 1 + return args[0] + else: + return super().call_module(target, args, kwargs) + return DropoutRemover(fx_model).transform() + +def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]): + """ + Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. + """ + new_graph = fx.Graph() + env: Dict[fx.Node, fx.Node] = {} + for input in inputs: + new_node = new_graph.placeholder(input.name) + env[input] = new_node + for node in nodes: + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + new_graph.output([env[output] for output in outputs]) + new_graph.lint() + return fx.GraphModule(orig_module, new_graph) + +mkldnn_supported = [ + nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, + torch.relu, torch.transpose, torch.sigmoid, + F.relu, F.avg_pool2d, F.adaptive_avg_pool2d +] +# These are operators that may not be convertible into MKLDNN ops (e.g. the +# args are scalar values). Thus, we only include them in the subgraph if their +# arguments are already in MKLDNN. +# TODO: Determine whether this can be removed after type inference. +mkldnn_supported_unknown = [operator.add, operator.mul] +mkldnn_map = { + nn.Conv2d: th_mkldnn.MkldnnConv2d, + nn.Linear: th_mkldnn.MkldnnLinear, + nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a) +} + + +def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): + """ + For each node, if it's a module that can be preconverted into MKLDNN, + then we do so and create a mapping to allow us to convert from the MKLDNN + version of the module to the original. + """ + old_modules: Dict[nn.Module, nn.Module] = {} + for node in nodes: + if node.op == 'call_module': + assert isinstance(node.target, str) + cur_module = modules[node.target] + if type(cur_module) in mkldnn_map: + new_module = mkldnn_map[type(cur_module)](cur_module, torch.float) + assert isinstance(new_module, nn.Module) + old_modules[new_module] = copy.deepcopy(cur_module) + replace_node_module(node, modules, new_module) + return old_modules + +def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]): + """ + Maps each module that's been changed with `modules_to_mkldnn` back to its + original. + """ + for node in nodes: + if node.op == 'call_module': + assert (isinstance(node.target, str)) + cur_module = modules[node.target] + if cur_module in old_modules: + replace_node_module(node, modules, old_modules[cur_module]) + +class MklSubgraph: + def __init__(self, fx_graph: fx.Graph): + self.fx_graph = fx_graph + self.nodes: List[fx.Node] = [] + self.start_nodes: List[fx.Node] = [] + self.end_nodes: List[fx.Node] = [] + +def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): + """ + This generates a heuristic that can be passed into `optimize_for_inference` that + determines whether a subgraph should be run in MKL by running it with the example_inputs. + + Example usage: + heuristic = gen_mkl_autotuner(example_inputs, iters=10) + fast_model = optimization.optimize_for_inference(model, heuristic) + """ + fx_model = None + old_modules = None + + def use_mkl_heuristic(graph: MklSubgraph) -> bool: + nonlocal fx_model, old_modules + input_nodes = graph.start_nodes + if fx_model is None: + fx_model = graph.fx_graph.owning_module + old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined] + ShapeProp(fx_model).propagate(example_inputs) + sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined] + output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes]) + submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args) + + def benchmark(f): + for _ in range(warmup): + f() + begin = time.time() + for _ in range(iters): + out = f() + return time.time() - begin + + mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])]) + + reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules) + no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) + return mkl_time < no_mkl_time + return use_mkl_heuristic + +def use_mkl_length(graph: MklSubgraph) -> bool: + """ + This is a heuristic that can be passed into `optimize_for_inference` that + determines whether a subgraph should be run in MKL by checking if there + are more than 2 nodes in it + """ + return len(graph.nodes) > 2 + +class UnionFind: + def __init__(self, n): + self.parent: List[Optional[int]] = [None] * n + self.size: List[int] = [0] * n + + def make_set(self, v: int): + self.parent[v] = v + self.size[v] = 1 + + def find(self, v: int) -> int: + par = self.parent[v] + if v == par: + return v + assert par is not None + self.parent[v] = self.find(par) + return cast(int, self.parent[v]) + + def join(self, a: int, b: int): + a, b = self.find(a), self.find(b) + if a == b: + return a + if self.size[a] < self.size[b]: + a, b = b, a + self.parent[b] = a + self.size[a] += self.size[b] + +def optimize_for_inference( + model: torch.nn.Module, + pass_config: Optional[Dict[str, Any]] = None, + tracer: Type[fx.Tracer] = fx.Tracer +) -> torch.nn.Module: + """ + Performs a set of optimization passes to optimize a model for the + purposes of inference. Specifically, the passes that are run are: + 1. Conv/BN fusion + 2. Dropout removal + 3. MKL layout optimizations + + The third optimization takes a function `use_mkl_heuristic` that's used + to determine whether a subgraph should be explicitly run in MKL layout. + + Note: As FX does not currently handle aliasing, this pass currently + assumes nothing aliases. If that isn't true, use at your own risk. + """ + default_pass_config = { + "conv_bn_fuse": True, + "remove_dropout": True, + "mkldnn_layout_optimize": {'heuristic': use_mkl_length}, + } + if pass_config is None: + pass_config = {} + default_pass_config.update(pass_config) + + if default_pass_config["conv_bn_fuse"]: + model = fuse(model) + if default_pass_config["remove_dropout"]: + model = remove_dropout(model) + if default_pass_config["mkldnn_layout_optimize"] is False: + return model + if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict): + raise RuntimeError("mkldnn_layout_optimize config is not a dict") + if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]: + raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config") + use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"] + + cur_tracer = tracer() + fx_graph = cur_tracer.trace(copy.deepcopy(model)) + fx_model = fx.GraphModule(cur_tracer.root, fx_graph) + modules: Dict[str, nn.Module] = dict(model.named_modules()) + + class MklSupport(Enum): + NO = 1 + YES = 2 + UNKNOWN = 3 + + # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node. + # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node. + # However, if it's in `mkldnn_supported_unknown`, then we only treat it as + # a MKLDNN node if its inputs are MKLDNN nodes. + for node in list(fx_graph.nodes): + supports_mkldnn = MklSupport.NO + if node.op == 'call_module': + cur_module = modules[node.target] + if type(cur_module) in mkldnn_supported: + supports_mkldnn = MklSupport.YES + sample_parameter = next(cur_module.parameters(), None) + if sample_parameter is not None: + assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules" + assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules" + elif node.op == 'call_function': + if node.target in mkldnn_supported: + supports_mkldnn = MklSupport.YES + elif node.target in mkldnn_supported_unknown: + supports_mkldnn = MklSupport.UNKNOWN + + if supports_mkldnn != MklSupport.NO: + if supports_mkldnn == MklSupport.UNKNOWN: + if not any(arg.target == 'to_dense' for arg in node.args): + continue + with fx_graph.inserting_before(node): + mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, ))) + + node.args = cast(Tuple[fx.node.Argument], mkldnn_args) + + with fx_graph.inserting_after(node): + dense_x = fx_graph.create_node('call_method', 'to_dense', (node,)) + node.replace_all_uses_with(dense_x) + dense_x.args = (node,) + + # Does pre-conversion of all modules into MKLDNN (when possible) + old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules) + fx_graph.old_modules = old_modules # type: ignore[attr-defined] + + # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b + for node in fx_graph.nodes: + if node.op == 'call_method' and node.target == 'to_dense': + prv_node = node.args[0] + users = list(node.users) + for user in users: + if user.op == 'call_method' and user.target == 'to_mkldnn': + user.replace_all_uses_with(prv_node) + fx_graph.erase_node(user) + if len(node.users) == 0: + fx_graph.erase_node(node) + + + num_nodes = len(fx_graph.nodes) + uf = UnionFind(num_nodes) + + def get_color(n): + if hasattr(n, 'color'): # Current node is part of a MKL subgraph + return uf.find(n.color) + if hasattr(n, 'start_color'): # Current node is input to MKL subgraph + return uf.find(n.start_color) + return None + + + # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists + # of input nodes (which are only `to_mkldnn` calls), output nodes + # (`to_dense` calls), and intermediate nodes, which are run entirely on + # MKLDNN layout tensors. + # + # Specifically, this code does a flood fill on a directed acyclic graph + # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes). + # If every node only had one input, this would be sufficient. However, in + # the case that a node has multiple inputs coming from different start + # nodes (i.e. colors), we need to join these 2 colors into 1. That's done + # using a Disjoint Set Union. + for cur_idx, node in enumerate(fx_graph.nodes): + if node.op == 'call_method' and node.target == 'to_mkldnn': + node.start_color = cur_idx + uf.make_set(cur_idx) + elif node.op == 'call_method' and node.target == 'to_dense': + assert get_color(node.args[0]) is not None + node.end_color = get_color(node.args[0]) + else: + cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None] + + if len(cur_colors) == 0: + continue + assert not any(i is None for i in cur_colors) + cur_colors = sorted(cur_colors) + node.color = cur_colors[0] + for other_color in cur_colors[1:]: + uf.join(cur_colors[0], other_color) + + + mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) + for node in fx_graph.nodes: + if hasattr(node, 'color'): + mkldnn_graphs[uf.find(node.color)].nodes.append(node) + if hasattr(node, 'start_color'): + mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node) + if hasattr(node, 'end_color'): + mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node) + + + # Now that we have all the subgraphs, we need to decide which MKLDNN + # subgraphs we actually want to keep in MKLDNN. + for graph in mkldnn_graphs.values(): + if not use_mkl_heuristic(graph): + for node in graph.start_nodes + graph.end_nodes: + prv = node.args[0] + node.replace_all_uses_with(prv) + fx_graph.erase_node(node) + reset_modules(graph.nodes, modules, old_modules) + + mkldnn_conversions = 0 + for node in fx_graph.nodes: + if node.target == 'to_mkldnn' or node.target == 'to_dense': + mkldnn_conversions += 1 + + logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions) + fx_graph.lint() + result = fx.GraphModule(model, fx_graph) + return result diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/partitioner_utils.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/partitioner_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..796c65a4302284b94ffd3bf9bd0bf9dfcab496f3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/partitioner_utils.py @@ -0,0 +1,318 @@ +# mypy: allow-untyped-defs +from enum import Enum +from typing import NamedTuple, Dict, List, Set + +from torch.fx.node import Node, map_arg + + +class Partition: + """Partition class contains all the information about an individual partition. + It also provides necessary methods for manipulation the partition. + """ + + def __init__(self, partition_id: int) -> None: + self.nodes: Set[Node] = set() + self.partition_id = partition_id + self.parents: Set[Partition] = set() + self.children: Set[Partition] = set() + self.bfs_level: int = -1 + self.used_mem_bytes: int = 0 + self.logical_device_ids: List[int] = [] + + def __str__(self): + return str(self.partition_id) + + def recalculate_mem_size(self): + self.used_mem_bytes = 0 + for node in self.nodes: + self.used_mem_bytes += get_extra_size_of(node, self.nodes) + + def add_node(self, node): + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Add current node's input nodes if they are placeholder or constants + for n in input_nodes: + if n.op in {"placeholder", "get_attr"}: + self.nodes.add(n) + self.nodes.add(node) + self.recalculate_mem_size() + + def remove_node(self, node): + # Remove a node only if the node is in the partition + if node in self.nodes: + self.nodes.remove(node) + # Collect the node's input nodes + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Check if an input node is a placeholder or get_attr, + # and this input node is not used by some other nodes in this partition, + # the remove this input node + for input_node in input_nodes: + if all( + n not in self.nodes for n in input_node.users + ) and input_node.op in {"placeholder", "get_attr"}: + self.nodes.remove(input_node) + self.recalculate_mem_size() + + +class Device(NamedTuple): + name: str + available_mem_bytes: int + logical_id: int + + +class NodeLatency(NamedTuple): + # Latency due to the memory bandwidth + mem_latency_sec: float + # Latency due to the computation + computer_latency_sec: float + + +class PartitionLatency(NamedTuple): + # Sum of all nodes' memory latency on the critical path + mem_latency_sec: float + # Sum of all nodes' compute latency on the critical path + computer_latency_sec: float + # Latency of the critical path + overall_latency_sec: float + + +class PartitionMode(Enum): + size_based = 0 + sparse_nn = 1 + cost_aware = 2 + kl_based = 3 + aot_based = 4 + + +class PartitionerConfig(NamedTuple): + devices: List[Device] + mode: PartitionMode = PartitionMode.size_based + transfer_rate_bytes_per_sec: float = 0.0 + node_to_latency_mapping: Dict[Node, NodeLatency] = {} + node_to_partition_mapping: Dict[Node, int] = {} + partition_to_logical_device_mapping: Dict[int, List[int]] = {} + # Saturate host by replicating partitions to the remaining idle devices. + saturate_host: bool = False + + +def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: + """Given a node and a set of nodes, + this function return the extra size that needed + if this node is included in this set. + """ + # Find all its input nodes + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Calculate total size of related nodes + total_size_of_input_nodes = 0 + for n in input_nodes: + # Make sure this node hasn't been in this set yet + if n not in nodes: + size_bytes = getattr(n, "size_bytes", None) + if size_bytes: + total_size_of_input_nodes += size_bytes.output_size + else: + raise RuntimeError("node has no size_bytes attr") + # Don't forget the op node itself + size_bytes = getattr(node, "size_bytes", None) + if size_bytes: + total_size_of_input_nodes += size_bytes.total_size + else: + raise RuntimeError("node has no size_bytes attr") + return total_size_of_input_nodes + + +def get_latency_of_one_partition( + partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency] +) -> PartitionLatency: + """Given a partition and its nodes' latency, return a PartitionLatency for this partition""" + + def get_top_nodes(partition: Partition) -> List[Node]: + """Given a partition, return a list of nodes on the top bfs level""" + top_nodes: List[Node] = [] + for node in partition.nodes: + # Skip placeholder and get_attr nodes + if node.op in {"placeholder", "get_attr"}: + continue + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # If a node has no input nodes in this partition, + # or its input nodes in this partition are placeholders and get_attrs + # this node is on the top bfs level in this partition + if not any( + n in partition.nodes and n.op not in {"placeholder", "get_attr"} + for n in input_nodes + ): + top_nodes.append(node) + return top_nodes + + def dfs_helper(node: Node, partition_latency) -> PartitionLatency: + """Given a top node of a partition, this function returns + the latency of the critical path in the partition + """ + node_latency = node_to_latency_mapping[node] + # Calculate the current overall latency of the partition + overall_latency_sec = partition_latency.overall_latency_sec + max( + node_latency.computer_latency_sec, node_latency.mem_latency_sec + ) + # Update the mem latency of this path + mem_latency_sec = ( + partition_latency.mem_latency_sec + node_latency.mem_latency_sec + ) + # Update the compute latency of this path + computer_latency_sec = ( + partition_latency.computer_latency_sec + node_latency.computer_latency_sec + ) + # Get all users of this node that are in this partition + users = set(node.users).intersection(partition.nodes) + if users: + max_latency = PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ) + for n in users: + # Get new partition latency recursively + new_partition_latency = dfs_helper( + n, + PartitionLatency( + mem_latency_sec, computer_latency_sec, overall_latency_sec + ), + ) + if ( + new_partition_latency.overall_latency_sec + > max_latency.overall_latency_sec + ): + max_latency = new_partition_latency + return max_latency + # If there is no user, the node is at bottom of the partition + return PartitionLatency( + mem_latency_sec, computer_latency_sec, overall_latency_sec + ) + + # Main part starts + # Get all top level nodes of this partition + top_nodes = get_top_nodes(partition) + critical_path_latency = PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ) + # Go through all top nodes and find the largest latency (critical pass latency) + for node in top_nodes: + partition_latency = dfs_helper( + node, + PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ), + ) + if ( + partition_latency.overall_latency_sec + > critical_path_latency.overall_latency_sec + ): + critical_path_latency = partition_latency + return critical_path_latency + + +def get_partition_to_latency_mapping( + partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency] +) -> Dict[Partition, PartitionLatency]: + """Given all the partitions and node_to_latency_mapping dictionary, + return a mapping dictionary of each partition to its overall latency + """ + partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {} + # Go through each partition and get its latency + for partition in partitions: + partition_latency = get_latency_of_one_partition( + partition, node_to_latency_mapping + ) + partition_to_latency_mapping[partition] = partition_latency + return partition_to_latency_mapping + + +def get_comm_latency_between( + parent_partition: Partition, + child_partition: Partition, + transfer_rate_bytes_per_sec: float, +): + """Given two partitions (parent and child), + calculate the communication latency between the two. + """ + # If two partitions are on the same device, the comm latency is 0. + if ( + parent_partition.logical_device_ids != [] + and child_partition.logical_device_ids != [] + and parent_partition.logical_device_ids == child_partition.logical_device_ids + ): + return 0.0 + # Keep tracking the communication size between parent and child + comm_size = 0 + # Keep tracking all the counted node + visited_nodes = set() + # Go through all nodes in the child partition + # If a node has input nodes from the parent partition, + # the output size of those input nodes will be counted + # and added to comm_size + for node in child_partition.nodes: + input_nodes: Dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + for n in input_nodes: + if n in parent_partition.nodes and n not in visited_nodes: + size_bytes = getattr(n, "size_bytes", None) + if size_bytes is not None: + comm_size += size_bytes.output_size + visited_nodes.add(n) + return comm_size / transfer_rate_bytes_per_sec + + +def get_latency_of_partitioned_graph( + partitions: List[Partition], + partition_to_latency_mapping: Dict[Partition, PartitionLatency], + transfer_rate_bytes_per_sec: float, +): + """Given all partitions in a graph, find the critical path among all partitions + and return its latency as the latency of the whole graph + """ + + def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float: + """This function helps to recursively get the latency of a path of partitions""" + # Update latency by adding current partition's latency + latency_so_far_sec += partition_to_latency_mapping[ + partition + ].overall_latency_sec + children = partition.children + if partition.children: + max_latency_sec = 0.0 + for child in partition.children: + # Calculate latency between + comm_latency_sec = get_comm_latency_between( + partition, child, transfer_rate_bytes_per_sec + ) + new_latency_sec = dfs_helper( + child, latency_so_far_sec + comm_latency_sec + ) + if new_latency_sec > max_latency_sec: + max_latency_sec = new_latency_sec + return max_latency_sec + return latency_so_far_sec + + def get_top_partitions(partitions: List[Partition]) -> List[Partition]: + """This function is to return all the partitions without parents + as the starting points of all the paths + """ + top_partitions = [] + for partition in partitions: + # If a partition has no parents, then it is a top partition + if len(partition.parents) == 0: + top_partitions.append(partition) + return top_partitions + + top_partitions = get_top_partitions(partitions) + critical_path_latency_sec = 0.0 + for partition in top_partitions: + latency_sec = dfs_helper(partition, 0.0) + if latency_sec > critical_path_latency_sec: + critical_path_latency_sec = latency_sec + return critical_path_latency_sec diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..5a5771d2ae8f7a2446eb58fa28973a8fbb629145 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py @@ -0,0 +1,2215 @@ +# mypy: allow-untyped-decorators +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import functools +import inspect +import logging +import operator +import traceback +import typing +import typing_extensions +import warnings +import weakref +from collections import defaultdict +from contextlib import contextmanager, ExitStack, nullcontext +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + Generator, + List, + Mapping, + Optional, + overload, + Protocol, + Sequence, + Tuple, + Type, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import Concatenate, ParamSpec, Self +from weakref import WeakKeyDictionary + +import torch +import torch._ops +import torch.fx as fx +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch import SymBool, SymInt, Tensor +from torch._dispatch.python import enable_python_dispatcher +from torch._library.fake_class_registry import FakeScriptObject +from torch._subclasses.fake_impls import fast_detach +from torch._subclasses.fake_tensor import ( + FakeTensor, + FakeTensorMode, + is_fake, + unset_fake_temporarily, +) +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx import GraphModule, Proxy, Tracer +from torch.fx.graph_module import _assign_attr +from torch.fx.node import _side_effectful_need_to_be_preserved_pre_dispatch +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.nn import Module +from torch.overrides import TorchFunctionMode +from torch.utils._python_dispatch import ( + _disable_infra_mode, + _push_mode, + _unset_infra_mode, + TorchDispatchMode, +) +from torch.utils._stats import count +from torch.utils._thunk import Thunk +from torch.utils._traceback import CapturedTraceback +from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary + +from ._backward_state import BackwardState +from .sym_node import SymNode + + +if TYPE_CHECKING: + import types + from collections.abc import MutableMapping + + import sympy + + from torch._ops import OpOverload + from torch.fx._symbolic_trace import PHBase + from torch.types import IntLikeType + +__all__ = [ + "PythonKeyTracer", + "dispatch_trace", + "make_fx", + "DecompositionInterpreter", + "py_sym_types", + "get_innermost_proxy_mode", + "get_proxy_mode", + "handle_sym_dispatch", + "maybe_enable_thunkify", + "maybe_disable_thunkify", +] + +_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"] + +_AnyScriptObject = (torch.ScriptObject, FakeScriptObject) +_AnyScriptObjectType = Union[torch.ScriptObject, FakeScriptObject] + +aten = torch.ops.aten +prim = torch.ops.prim + +log = logging.getLogger(__name__) +not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") + +CURRENT_DECOMPOSITION_TABLE: Mapping[OpOverload, Callable] = {} + +CONSTANT_NUMEL_LIMIT = 1 + +T = TypeVar("T") +U = TypeVar("U") +_P = ParamSpec("_P") +R = TypeVar("R") + +null_ctx_type = type(nullcontext) +# We currently convert all SymInt to proxies before we use them. +# This could plausibly be handled at the Dynamo level. +pytree.register_pytree_node( + torch.Size, + lambda xs: (list(xs), None), + lambda xs, _: tuple(xs), + flatten_with_keys_fn=lambda xs: ( + [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)], + None, + ), +) + + +def fake_signature(fn: Callable[_P, R], nargs: int) -> Callable[_P, R]: + """FX gets confused by varargs, de-confuse it""" + argnames = ",".join(f"arg{i}" for i in range(nargs)) + return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn}) + + +@contextmanager +def decompose( + decomposition_table: Optional[Mapping[OpOverload, Callable]] +) -> Generator[Mapping[OpOverload, Callable], None, None]: + global CURRENT_DECOMPOSITION_TABLE + old_decomposition_table = CURRENT_DECOMPOSITION_TABLE + CURRENT_DECOMPOSITION_TABLE = decomposition_table or {} + try: + yield CURRENT_DECOMPOSITION_TABLE + finally: + CURRENT_DECOMPOSITION_TABLE = old_decomposition_table + + +# ensure we cannot collide with other properties +proxy_slot = object() + + +class _NoDefault: + pass + + +no_default = _NoDefault() + +from torch.types import py_sym_types, PySymType + + +class _HasMeta(Protocol): + meta: Dict[str, PySymType] + + +def is_sym_node(node: _HasMeta) -> bool: + assert hasattr(node, "meta"), "All nodes traced with proxy_tensor should have meta" + return "val" in node.meta and isinstance(node.meta["val"], py_sym_types) + + +@overload +def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: + ... + + +@overload +def set_proxy_slot( + obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy +) -> None: + ... + + +@overload +def set_proxy_slot( + obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType +) -> None: + ... + + +def set_proxy_slot( + obj: Union[PySymType, _AnyScriptObjectType, Tensor], + tracer: _ProxyTracer, + proxy: object, +) -> None: + log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy) + if isinstance(obj, Tensor): + # We DO want to clobber proxies whenever we run an inplace operation + # on a tensor, and it affects the metadata on the proxy. + assert isinstance(proxy, _ProxyTensor) + tracer.tensor_tracker[obj] = proxy + elif isinstance(obj, (_AnyScriptObject)): + # We DO want to clobber proxies, with a similar rationale as for tensors. + assert isinstance(proxy, Proxy) + tracer.script_object_tracker[obj] = proxy + else: + # NB: Never clobber pre-existing proxy. Although the proxies + # are in principle equivalent, when we do graph partitioning + # we need there not to be spurious dependencies on tangent inputs. + # This works because primals get their SymInts set first, and + # THEN later we allocate tangent inputs. Make sure if a SymInt + # is derivable from a primal that we use that. + assert isinstance(obj, py_sym_types), type(obj) + if obj not in tracer.symnode_tracker: + tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy) + + # WAR: python test/dynamo/test_subclasses.py + # TestNestedTensor.test_basic_autograd + # + # AOTAutograd doesn't pass the "outer sizes" as an actual argument + # to make_fx, but it is made use of internally in AOTAutograd's + # call to tensor unflatten. Because the outer sizes isn't passed + # as an argument, it is therefore untracked. However, it turns + # out you luck out, because *Dynamo* will manually add the outer + # sizes as an argument so you can fix up the proxy'ness. + # + # This is probably fixed in + # https://github.com/pytorch/pytorch/pull/125941/ + import sympy + + if isinstance(obj.node.expr, sympy.Symbol): + tracer.sympy_expr_tracker[obj.node.expr] = proxy + + +def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool: + assert isinstance(obj, (Tensor, SymNode)), type(obj) + return bool(get_proxy_slot(obj, tracer, False, lambda _: True)) + + +_PySymProxyType = Thunk[Proxy] + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, +) -> _ProxyTensor: + ... + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, + default: U, +) -> Union[_ProxyTensor, U]: + ... + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, + default: U, + transform: Callable[[_ProxyTensor], R], +) -> Union[R, U]: + ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, +) -> Proxy: + ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, + default: U, +) -> Union[Proxy, U]: + ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, + default: U, + transform: Callable[[Proxy], R], +) -> Union[R, U]: + ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, +) -> _PySymProxyType: + ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, + default: T, +) -> Union[T, _PySymProxyType]: + ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, + default: U, + transform: Callable[[_PySymProxyType], R], +) -> Union[R, U]: + ... + + +# the default argument is what to return if the slot is not set. +# the transform argument is handy if you need to extract a subfield from +# the successfully looked up result (but NOT the default.) +def get_proxy_slot( + obj: Union[Tensor, _AnyScriptObjectType, PySymType], + tracer: _ProxyTracer, + default: object = no_default, + transform: Callable = lambda x: x, +) -> object: + tracker: Any + if isinstance(obj, Tensor): + tracker = tracer.tensor_tracker + elif isinstance(obj, _AnyScriptObject): + tracker = tracer.script_object_tracker + else: + assert isinstance(obj, py_sym_types), type(obj) + tracker = tracer.symnode_tracker + + if obj not in tracker: + # Last ditch + if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker: + value = tracer.sympy_expr_tracker[obj.node.expr] + else: + if isinstance(default, _NoDefault): + raise RuntimeError( + f"{obj} ({id(obj)})is not tracked with proxy for {tracer}" + ) + return default + else: + value = tracker[obj] + res = transform(value) + return res + + +def snapshot_fake(val: Tensor) -> Optional[Tensor]: + # val.detach() will also eventually call fast_detach(), + # but this saves us a full trip into __torch_dispatch__ + # (snapshot_fake is called a lot) + if isinstance(val, FakeTensor): + return fast_detach(val.fake_mode, val) + else: + return val.detach() + + +_ExtractValType = Optional[ + Union[ + PySymType, + _AnyScriptObjectType, + BackwardState, + List["_ExtractValType"], + Tuple["_ExtractValType", ...], + Dict[str, "_ExtractValType"], + Tensor, + int, + float, + bool, + ] +] + + +def extract_val(val: _ExtractValType) -> _ExtractValType: + if is_fake(val): + return snapshot_fake(val) + elif isinstance(val, py_sym_types): + return val + elif isinstance(val, _AnyScriptObject): + return val + elif isinstance(val, BackwardState): + return val + elif isinstance(val, (list, tuple)): + return val.__class__([extract_val(x) for x in val]) + elif isinstance(val, dict): + return {k: extract_val(v) for k, v in val.items()} + elif isinstance(val, Tensor): + if not val.is_sparse: + # NB: Kinda hacky, but we should try to get val as the metadata + # everywhere + # TODO: This doesn't properly track storages. A more robust + # approach would be to maintain a per-trace FakeTensorMode and + # from_real_tensor to create fake values (don't forget to + # snapshot_fake) + fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) + with fake_tensor_mode: + return torch.empty_strided( + val.shape, val.stride(), device=val.device, dtype=val.dtype + ) + else: + return None + elif isinstance(val, (int, float, bool)): + return val + elif val is None: + return None + + typing_extensions.assert_never(val) + + +@contextmanager +def _enable_thunkify( + tracer: _ProxyTracer, *, enable: bool = True +) -> Generator[None, None, None]: + """ + Enable thunkification inside the context manager. Thunkification prevents + SymNode computation from directly being traced into an FX graph; instead, + the compute is only added to the graph if it is actually used. This helps + us track SymNode compute when it is computed (since we need /something/ + to put in the tracker) even if it is unlikely to be used. + """ + old = tracer.enable_thunkify + tracer.enable_thunkify = enable + try: + yield + finally: + tracer.enable_thunkify = old + + +@contextmanager +def maybe_disable_thunkify() -> Generator[None, None, None]: + """Within a context, disable thunkification. See :func:`maybe_enable_thunkify` + for more details. This is helpful if you have a wrapper function which + you want to enable thunkification on, but in some segment on the inside (say, + the original user function), you want to disable thunkification as you know + it is not needed there. + """ + proxy_mode = get_proxy_mode() + if proxy_mode is not None: + with _enable_thunkify(proxy_mode.tracer, enable=False): + yield + else: + yield + + +@contextmanager +def maybe_enable_thunkify() -> Generator[None, None, None]: + """Within this context manager, if you are doing make_fx tracing, we will thunkify + all SymNode compute and avoid tracing it into the graph unless it is actually needed. + You should prefer to avoid using this as much as possible, as lazy evaluation of + SymNode tracing can lead to long chains of thunks which will stack overflow + if you evaluate them. However, this is currently sometimes necessary as there + are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error + due to insufficient tracing of SymNode computation. + """ + proxy_mode = get_proxy_mode() + if proxy_mode is not None: + with _enable_thunkify(proxy_mode.tracer): + yield + else: + yield + + +# Note [invariants for node meta 'val'] +# What invariants do we have for the 'val' set on the FX node? It has accurate +# metadata... but only for metadata that exists "below" all other subsystems +# (most notably autograd, but also vmap, functorch transforms, etc). This means +# you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad, +# grad_fn, _base (_base actually may be set due to recursive call to +# ADInplaceOrView, but you shouldn't rely on it.) +def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy: + proxy.node.meta["val"] = extract_val(val) + + with _enable_thunkify(proxy.tracer): # type: ignore[arg-type] + # Best effort tensor_meta setting; prefer using val! + if is_fake(val): + proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val) + elif isinstance(val, Tensor) and not val.is_sparse: + proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val) + return proxy + + +def thunkify( + tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs +) -> Thunk[R]: + """ + Delays computation of f until it's called again + Also caches the result + """ + if tracer.enable_thunkify: + return Thunk(functools.partial(f, *args, **kwargs)) + else: + r = f(*args, **kwargs) + return Thunk(lambda: r) + + +def track_tensor( + tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer +) -> None: + def try_set_proxy_slot( + outer_s: IntLikeType, + proxy_callable: Callable[Concatenate[PySymType, _P], Proxy], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> None: + assert callable(proxy_callable) + if isinstance(outer_s, SymInt): + with _enable_thunkify(tracer): + set_proxy_slot( + outer_s, + tracer, + thunkify(tracer, proxy_callable, outer_s, *args, **kwargs), + ) + + # The basic idea is that we need to associate each tensor/SymInt + # with a Proxy. How do we setup this association? We just store + # the proxy on the proxy slot of the object, keyed on the tracer + # (so that if we have multiple tracers at the same time, they + # don't clobber each other.) + for i, s in enumerate(tensor.shape): + try_set_proxy_slot( + s, + lambda x, i: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_size.int, (proxy, i), {} + ), + x, + ), + i, + ) + + if not is_sparse_any(tensor): + for i, s in enumerate(tensor.stride()): + try_set_proxy_slot( + s, + lambda x, i: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_stride.int, (proxy, i), {} + ), + x, + ), + i, + ) + + try_set_proxy_slot( + tensor.numel(), + lambda x: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_numel.default, (proxy,), {} + ), + x, + ), + ) + if not is_sparse_any(tensor): + try_set_proxy_slot( + tensor.storage_offset(), + lambda x: set_meta( + tracer.create_proxy( + "call_function", + torch.ops.aten.sym_storage_offset.default, + (proxy,), + {}, + ), + x, + ), + ) + set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant)) + + +_NestedProxys = Union[ + Proxy, Sequence["_NestedProxys"], Mapping[object, "_NestedProxys"] +] +_NestedTensors = Union[ + Tensor, Sequence["_NestedTensors"], Mapping[object, "_NestedTensors"] +] + + +def track_tensor_tree( + inner_res: T, + proxy_res: _NestedProxys, + *, + constant: Optional[_NestedTensors], + tracer: _ProxyTracer, +) -> T: + # NB: We call set_unbacked_bindings only on the *topmost* call to + # track_tensor_tree, not recursive calls. This is because there must + # be only ONE unbacked_binding proxy call, and it should be the one + # where all of the unbacked SymInts actually first come into existence. + # If you call this again on the inner proxies for the tuple projections, + # you will have multiple unbacked_bindings for the same symbol, but + # they're not going to show up anywhere. + # + # I was briefly deceived into setting unbacked bindings recursively when + # working on https://github.com/pytorch/pytorch/pull/133585 because I + # observed that some extra unbacked bindings were needed to handle some + # higher order operator code. But actually it looks like this was + # just an unrelated bug that needed to be fixed separately. + _set_unbacked_bindings(inner_res, proxy_res) + + def wrap_with_proxy( + e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors] + ) -> None: + if isinstance(e, Tensor): + assert isinstance(proxy, Proxy) + assert constant is None or isinstance(constant, Tensor) + track_tensor(e, proxy, tracer=tracer, constant=constant) + set_meta(proxy, e) + elif isinstance(e, py_sym_types): + assert isinstance(proxy, Proxy) + # NB: eagerly set meta here, so that the numbering is in order + set_meta(proxy, e) + set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy)) + elif isinstance(e, _AnyScriptObject): + assert isinstance(proxy, Proxy) + set_proxy_slot(e, tracer, proxy) + set_meta(proxy, e) + elif isinstance(e, (tuple, list)): + # example use case: allreduce_ returns ([tensor], work) + if isinstance(proxy, fx.Proxy): + set_meta(proxy, e) + + def get_constant( + c: Optional[_NestedTensors], idx: int + ) -> Optional[_NestedTensors]: + if c is None: + return None + else: + assert isinstance(c, (list, tuple)) + return c[idx] + + for idx, ee in enumerate(e): + # Use an indexer here - if proxy is a List then it will unwrap + # it. If it's a Proxy then it will proxy the getelem. + wrap_with_proxy(ee, proxy[idx], get_constant(constant, idx)) # type: ignore[index] + + elif isinstance(e, dict): + # example use case: triton_kernel_wrapper takes arguments as kwargs + + # In theory we could support const-prop when proxy-tensor-tracing + # operators that returns dicts of tensors, but we have no use case + # for it today (since the only op we currently trace that can + # return a dict is triton_kernel_wrapper_functional/mutation, + # which does not participate in const-prop) + assert constant is None + + if isinstance(proxy, fx.Proxy): + set_meta(proxy, e) + + for key, val in e.items(): + wrap_with_proxy(val, proxy[key], None) # type: ignore[index] + + elif isinstance(e, BackwardState): + assert isinstance(proxy, Proxy) + set_meta(proxy, e) + e.proxy = proxy + else: + # intentionally pass on primitives + pass + + wrap_with_proxy(inner_res, proxy_res, constant) + + return inner_res + + +@dataclass +class _ProxyTensor: + proxy: Proxy + constant: Optional[Tensor] + + +def fetch_sym_proxy( + tracer: _ProxyTracer, +) -> Callable[[PySymType], Union[bool, int, float, Proxy]]: + def inner(e: PySymType) -> Union[int, bool, float, Proxy]: + n = e.node + if n.constant is not None: + return n.constant + if e.node.expr.is_number: + if isinstance(e, SymBool): + return bool(e.node.expr) + elif isinstance(e, SymInt): + return int(e.node.expr) + return float(e.node.expr) + else: + assert isinstance(e, py_sym_types) + # NB: we REQUIRE all symints to be tracked + return get_proxy_slot(e, tracer).force() + + return inner + + +@overload +def fetch_object_proxy(tracer: _ProxyTracer, t: Tensor) -> Union[_ProxyTensor, Tensor]: + ... + + +@overload +def fetch_object_proxy( + tracer: _ProxyTracer, t: _AnyScriptObjectType +) -> Union[Proxy, _AnyScriptObjectType]: + ... + + +@overload +def fetch_object_proxy( + tracer: _ProxyTracer, t: PySymType +) -> Union[_PySymProxyType, PySymType]: + ... + + +def fetch_object_proxy( + tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType] +) -> object: + return get_proxy_slot(t, tracer, t) + + +HANDLED_TYPES = (Tensor, torch.nn.Parameter, FakeTensor) + + +def _maybe_record_pointwise_barrier( + func: object, proxy_mode: ProxyTorchDispatchMode +) -> None: + """ + Records pointwise operators in user program (non decomposed) that were output in fp16/bf16 + """ + if proxy_mode.decomp_layers or not proxy_mode.emulate_precision_casts: + return + + if ( + not isinstance(func, torch._ops.OpOverload) + or torch.Tag.pointwise not in func.tags + ): + return + + last_node = next(iter(reversed(proxy_mode.tracer.graph.nodes))) + t = last_node.meta.get("val") + if not isinstance(t, torch.Tensor) or t.dtype not in ( + torch.bfloat16, + torch.float16, + ): + return + + last_node.meta["low_precision_pointwise_barrier"] = True + + +def proxy_call( + proxy_mode: ProxyTorchDispatchMode, + func: OpOverload, + pre_dispatch: bool, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + unrecognized_types: List[Type] = [] + flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs)) + + def can_handle_tensor(x: Tensor) -> bool: + r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer) + if proxy_mode._allow_fake_constant: + r = r or type(x) in (torch._subclasses.FakeTensor,) + if not r: + unrecognized_types.append(type(x)) + return r + + # If there are any tensor subclasses, we need to handle those tensor subclasses first + # TODO: we could use types to test this + if not all(can_handle_tensor(x) for x in flat_args_kwargs if isinstance(x, Tensor)): + not_implemented_log.debug( + "ProxyTensorMode tensors without proxy had unrecognized subclasses: %s", + unrecognized_types, + ) + return NotImplemented + + r = maybe_handle_decomp(proxy_mode, func, args, kwargs) + if r is not NotImplemented: + _maybe_record_pointwise_barrier(func, proxy_mode) + return r + + # For pre-autograd tracing, we do not want to run CompositeImplicit decomps. + if not pre_dispatch and func not in [ + torch.ops.aten.size.default, + torch.ops.aten.stride.default, + torch.ops.aten.storage_offset.default, + ]: + with proxy_mode: + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r + + tracer = proxy_mode.tracer + f_flat_args_kwargs = [ + ( + fetch_object_proxy(tracer, x) + if isinstance(x, (Tensor, _AnyScriptObject)) + else x + ) + for x in flat_args_kwargs + ] + + # If there are SymInts, we also should not consider this constant. + # However, fake tensor handling of SymInts is sufficiently broken that + # I couldn't write a test for this case + all_constant = ( + not any( + t.constant is None + for t in f_flat_args_kwargs + if isinstance(t, _ProxyTensor) + ) + # TODO: maybe constant SymInts should also be allowed? Not sure if + # this can happen + and not any(isinstance(x, py_sym_types) for x in flat_args_kwargs) + ) + + if torch.Tag.data_dependent_output in func.tags: + # Check if all of the Tensor inputs are constants + if all_constant: + const_flat_args_kwargs = [ + t.constant if isinstance(t, _ProxyTensor) else t + for t in f_flat_args_kwargs + ] + const_args, const_kwargs = pytree.tree_unflatten( + const_flat_args_kwargs, spec + ) + with unset_fake_temporarily(): + return func(*const_args, **const_kwargs) + # If any of the Tensor inputs are "real" (not FakeTensor), we may + # incorrectly burn in constants by allowing this access. Raise + # an error in this case + if proxy_mode._error_on_data_dependent_ops and pytree.tree_all_only( + Tensor, lambda t: not is_fake(t), (args, kwargs) + ): + raise RuntimeError( + f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! " + "It's likely that this is caused by data-dependent control flow or similar. " + "It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' " + "in your make_fx call." + ) + + proxy_flat_args_kwargs = [ + e.proxy if isinstance(e, _ProxyTensor) else e for e in f_flat_args_kwargs + ] + proxy_flat_args_kwargs = [ + (fetch_sym_proxy(proxy_mode.tracer)(e) if isinstance(e, py_sym_types) else e) + for e in proxy_flat_args_kwargs + ] + proxy_args, proxy_kwargs = pytree.tree_unflatten(proxy_flat_args_kwargs, spec) + + # When we trace through a torch.tensor invocation, you never actually + # see a torch.ops.aten.tensor call. Instead, the way this function is + # implemented internally is that we allocate a plain tensor (this is + # *guaranteed* to be a plain tensor, we disable all modes when doing + # so), and then call at::lift_fresh on it (to give modes a chance to do + # their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed + # to be freshly allocated, so we want lift_fresh to be a no-op (directly + # returning the input argument). + # + # Here is the basic problem: when we trace this sequence of executions + # into an FX graph, what happens to this call sequence? Traditionally, + # tensor constants get interned as buffers on the FX GraphModule. But + # this is dangerous. Consider: + # + # x = torch.tensor(1) + # x.add_(2) + # + # Naively, this traces into: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh(t) + # x.add_(2) + # + # If lift_fresh returns t directly, the subsequent add_ call will + # modify the tensor constant. Really, the problem is we've violated + # the invariant the argument to lift is fresh. So what we should + # preserve the invariant by replacing lift_fresh with lift_fresh_copy: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh_copy(t) + # x.add_(2) + # + # This is what the overload modification does. + if func is torch.ops.aten.lift_fresh.default: + func = torch.ops.aten.lift_fresh_copy.default + + proxy_out = proxy_mode.tracer.create_proxy( + "call_function", + func, + proxy_args, + proxy_kwargs, + name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__), + ) + + with _enable_thunkify(proxy_mode.tracer): + out = func(*args, **kwargs) + + # In some circumstances, we will be tracing in a situation where a tensor + # is *statically* known to be a constant (currently, this only happens if + # you run torch.tensor; deterministic factory functions like torch.arange + # don't get this treatment). When the tensor in question is small, it's + # helpful to due constant propagation in case we call item() (in which + # case we can return the constant value that is known, rather than give + # an error.) The logic here tests if constant propagation is possible + # (because all of the inputs are constant). If so, we disable fake tensor + # mode (if it is on) and do true compute on the constant. + # + # It's worth highlighting that we're making a policy decision here. + # There is a potential that the tensor is actually quite large, and we + # don't actually want to run the compute. The tensor being quite large + # is one of the reasons why factory functions don't get this treatment + # (since they can be quite large; if a parameter is initialized to a + # constant value it will be!) Similarly, there is also a potential + # to run an operator that blows up the size of a small tensor; we don't + # protect against this case, but we could force, e.g., only single + # element constant computation by testing the numel of the result before + # propagating const-ness. Similarly, we don't require the constant to + # live on CPU, but we could. + any_constant = any( + t.constant is not None + for t in f_flat_args_kwargs + if isinstance(t, _ProxyTensor) + ) + + constant = None + + def tensor_numel_in_limit(t: Tensor) -> bool: + return t.numel() <= CONSTANT_NUMEL_LIMIT + + # If this is a lift, the input tensor is guaranteed to be a + # constant, so we keep a copy of the original argument along so + # we can query it if we're asked to item() it at some later point + if ( + func is torch.ops.aten.lift_fresh_copy.default + and out.numel() <= CONSTANT_NUMEL_LIMIT + ): + with unset_fake_temporarily(): + assert isinstance(args[0], (Proxy, Tensor)), type(args[0]) + constant = args[0].clone() + elif ( + torch.Tag.nondeterministic_seeded not in func.tags + and all_constant + and any_constant + and pytree.tree_all_only(Tensor, tensor_numel_in_limit, out) + ): + # NB: do NOT include factories as constants + with unset_fake_temporarily(): + const_flat_args_kwargs = [ + t.constant if isinstance(t, _ProxyTensor) else t + for t in f_flat_args_kwargs + ] + const_args, const_kwargs = pytree.tree_unflatten( + const_flat_args_kwargs, spec + ) + constant = func(*const_args, **const_kwargs) + else: + constant = None + + track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer) + _maybe_record_pointwise_barrier(func, proxy_mode) + return out + + +class _SymNodeDict: + """ + Wrapper around a dictionary that will hash SymInts with their nodes + """ + + def __init__(self) -> None: + self.sym_node_dict: Dict[PySymType, _PySymProxyType] = {} + + def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None: + self.sym_node_dict[key.node] = value + + def __getitem__(self, key: PySymType) -> _PySymProxyType: + return self.sym_node_dict[key.node] + + def __contains__(self, key: PySymType) -> bool: + return key.node in self.sym_node_dict + + def get( + self, key: PySymType, default: Optional[_PySymProxyType] = None + ) -> _PySymProxyType: + # dict.get()'s annotation doesn't accept `None` when the value type + # isn't Optional. + return self.sym_node_dict.get(key.node, default) # type: ignore[arg-type] + + def __iter__(self) -> Any: + raise NotImplementedError + + def __len__(self) -> int: + return len(self.sym_node_dict) + + +class PythonKeyTracer(Tracer): + script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] + symnode_tracker: _SymNodeDict + sympy_expr_tracker: Dict[sympy.Symbol, object] + tensor_tracker: MutableMapping[Tensor, _ProxyTensor] + torch_fn_counts: Dict[OpOverload, int] + enable_thunkify: bool = False + + def __init__(self) -> None: + super().__init__(autowrap_modules=()) # type: ignore[arg-type] + self.tensor_tracker = WeakTensorKeyDictionary() + self.symnode_tracker = _SymNodeDict() + self.script_object_tracker = WeakIdKeyDictionary( + dict=None, ref_type=_WeakHashRef + ) + self.sympy_expr_tracker = dict() + + # Stores the torch function that was called during tracing + self.torch_fn_metadata = None + # Stores the counts for every torch function called. This is to help + # distinguish between different calls to the same torch function. + self.torch_fn_counts = {} + self.enable_thunkify = False + + # In general, we don't want to make modules leaves. In principle, users of + # this tracer might want to override this in order to turn a couple specific + # modules into leaves in the traced graph. + def call_module( + self, + m: Module, + forward: Callable[..., Any], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> Any: + return forward(*args, **kwargs) + + # We don't want to turn getattr calls into proxies. So we just return the actual value. + def getattr( + self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy] + ) -> object: + return attr_val + + def create_arg(self, a: object) -> fx.node.Node: + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node("get_attr", n, (), {}) + + qualname = self.get_fresh_qualname("_param_constant") + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + elif isinstance(a, py_sym_types): + assert a.node.constant is not None + return a.node.constant + return super().create_arg(a) # type: ignore[return-value] + + @overload + def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: + ... + + @overload + def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: + ... + + @overload + def unwrap_proxy( + self, e: _AnyScriptObjectType + ) -> Union[Proxy, _AnyScriptObjectType]: + ... + + def unwrap_proxy(self, e: T) -> object: + if isinstance(e, Tensor): + return get_proxy_slot(e, self, e, lambda x: x.proxy) + elif isinstance(e, py_sym_types): + return get_proxy_slot(e, self, e, lambda e: e.force()) + elif isinstance(e, _AnyScriptObject): + return get_proxy_slot(e, self, e) + else: + return e + + +@contextmanager +def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode + + temp_elements = [] + pre_dispatch_mode = None + + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, PreDispatchTorchFunctionMode): + pre_dispatch_mode = mode + break + else: + temp_elements.append(mode) + + for mode in reversed(temp_elements): + _push_mode(mode) + + try: + yield + + finally: + if pre_dispatch_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 + + temp_elements.append(pre_dispatch_mode) + + for mode in reversed(temp_elements): + _push_mode(mode) + + +@torch._disable_dynamo +def dispatch_trace( + root: Union[Module, Callable], + tracer: Tracer, + concrete_args: Optional[Tuple[Any, ...]] = None, +) -> GraphModule: + graph = tracer.trace(root, concrete_args) # type: ignore[arg-type] + + # NB: be careful not to DCE .item() calls + def impure_pred(n: fx.Node) -> bool: + from .symbolic_shapes import is_accessor_node + + # Always defer to the built-in notion of impure + if n.is_impure(): + return True + + # Accessors always OK to DCE + if is_accessor_node(n): + return False + + # If the operator in question takes SymInt args to SymInt output, + # we assume it's pure and OK to DCE + if ( + isinstance(n.meta.get("val"), py_sym_types) + and + # NB: constant args ok + all( + isinstance(a.meta.get("val"), py_sym_types) + for a in n.args + if isinstance(a, fx.Node) + ) + ): + return False + + # No idea, just assume it's not OK + return True + + graph.eliminate_dead_code(impure_pred) + from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints + + dedupe_symints(graph) + name = root.__class__.__name__ if isinstance(root, Module) else root.__name__ + return fx._lazy_graph_module._make_graph_module(tracer.root, graph, name) + + +def wrap_key( + f: Callable[_P, R], tensors: _P.args, tracer: _ProxyTracer, pre_dispatch: bool +) -> Callable[_P, R]: + flat_tensors, tensors_spec = pytree.tree_flatten(tensors) + + @functools.wraps(f) + def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R: + flat_proxies, proxies_spec = pytree.tree_flatten(proxies) + assert len(flat_proxies) == len(flat_tensors) + with disable_proxy_modes_tracing() as m: + assert isinstance(m, ProxyTorchDispatchMode) + track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) + + def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]: + return get_proxy_slot(t, tracer, t, lambda x: x.proxy) + + out = f(*tensors) + out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out) + out = pytree.tree_map_only( + _AnyScriptObject, lambda t: get_proxy_slot(t, tracer, t, lambda x: x), out + ) + + def get_sym_proxy_slot(t: PySymType) -> Proxy: + return get_proxy_slot(t, tracer).force() + + out = pytree.tree_map_only(py_sym_types, get_sym_proxy_slot, out) + return out + + return wrapped + + +# TODO: Make downstream users of this work with OperatorBase +ORIGINAL_ATEN: Optional[object] = None + + +@contextmanager +def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]: + global ORIGINAL_ATEN + if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta(): + ORIGINAL_ATEN = func + fx_traceback.current_meta["original_aten"] = func + try: + yield + finally: + ORIGINAL_ATEN = None + fx_traceback.current_meta["original_aten"] = None + else: + yield + + +class TorchFunctionMetadataMode(TorchFunctionMode): + def __init__(self, tracer: _ProxyTracer) -> None: + self.tracer = tracer + + def __torch_function__( + self, + func: OpOverload, + types: Tuple[torch._C._TensorMeta, ...], + args: Tuple[object, ...] = (), + kwargs: Optional[Dict[str, object]] = None, + ) -> object: + kwargs = kwargs or {} + self.tracer.torch_fn_metadata = func + self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1 + return func(*args, **kwargs) + + +# This mode is **only** used for pre_dispatch tracing. +# In particular, we need to make sure that autograd/autocast API's +# that do not desugar into dispatcher operators stay in the graph. +class PreDispatchTorchFunctionMode(TorchFunctionMode): + def __init__(self, tracer: _ProxyTracer) -> None: + self.tracer = tracer + + def __torch_function__( + self, + func: OpOverload, + types: Tuple[torch._C._TensorMeta, ...], + args: Tuple[object, ...] = (), + kwargs: Optional[Dict[str, object]] = None, + ) -> object: + kwargs = kwargs or {} + if func in _side_effectful_need_to_be_preserved_pre_dispatch: + # It's for passing the export verifier which needs to verify the meta['val'] + # TODO(tmanlaibaatar): we should systematically couple it with expoert verifier, + # instead of hardcoding it here. + node = self.tracer.create_node("call_function", func, args, {}) # type: ignore[arg-type] + if func is torch._C._set_grad_enabled: + node.meta["val"] = None + return node + # Don't actually run the function! We just want to trace the calls + # into a graph. We don't actualy want to change global autograd state. + return func(*args, **kwargs) + + +class ProxyTorchDispatchMode(TorchDispatchMode): + # Ensure this is read-only; this exists only for legacy reasons + @property + def enable_tracing(self) -> bool: + return True + + def __init__( + self, + tracer: _ProxyTracer, + tracing_mode: str, + pre_dispatch: bool = False, + _allow_fake_constant: bool = False, + _error_on_data_dependent_ops: bool = True, + ) -> None: + dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None + super().__init__(dk) + self.tracer = tracer + self.tracing_mode = tracing_mode + self.pre_dispatch = pre_dispatch + self._allow_fake_constant = _allow_fake_constant + self._error_on_data_dependent_ops = _error_on_data_dependent_ops + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + self._mode_key = torch._C._TorchDispatchModeKey.PROXY + # Every time we enter a mode, we maintain a stack telling us what the previous + # ProxyTorchDispatchMode state was (if there was any). + # This lets us properly reset the state on exit. + self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = [] + self.decomp_layers = 0 + from torch._inductor import config + + self.emulate_precision_casts = config.emulate_precision_casts + + @count + def __torch_dispatch__( + self, + func: OpOverload, + types: Tuple[torch._C._TensorMeta, ...], + args: Tuple[object, ...] = (), + kwargs: Optional[Dict[str, object]] = None, + ) -> object: + with set_original_aten_op(func): + kwargs = kwargs or {} + + if func in (prim.device.default,): + return func(*args, **kwargs) + + return proxy_call(self, func, self.pre_dispatch, args, kwargs) + + def __enter__(self) -> Self: + # Stash and store the previous proxy mode (there may or may not be one) + maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY) + self.enter_stack.append(maybe_prev_proxy_mode) + return super().__enter__() + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> Optional[bool]: + b = super().__exit__(exc_type, exc_value, traceback) + + # Re-enable the previous proxy mode, if there was one. + mb_previous_proxy_mode = self.enter_stack.pop() + if mb_previous_proxy_mode is not None: + _push_mode(mb_previous_proxy_mode) + + return b + + @classmethod + def is_infra_mode(cls) -> bool: + return True + + def _compute_proxy( + self, func: OpOverload, args: Tuple[object, ...], out: PySymType + ) -> Proxy: + n_args = tuple( + get_proxy_slot(a, self.tracer).force().node + if isinstance(a, py_sym_types) + else a + for a in args + ) + + # func doesn't have a __torch_function__ that Proxy can interpose, so + # we gotta do it manually + n_out = self.tracer.create_node("call_function", func, n_args, {}) # type: ignore[arg-type] + p_out = fx.Proxy(n_out, self.tracer) + set_meta(p_out, out) + return p_out + + def __sym_dispatch__( + self, + func: OpOverload, + types: Tuple[torch._C._TensorMeta, ...], + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> object: + # Peephole optimize multiply by one + # NB: be careful not to trigger guards here! + if func == operator.mul: + if isinstance(args[1], int) and args[1] == 1: + return args[0] + elif isinstance(args[0], int) and args[0] == 1: + return args[1] + + # For speed, we assume there are no nested data structures + # (otherwise we could use tree_map) + # We also assume there are no keyword arguments. + assert not kwargs + out = func(*args, **kwargs) + + # If func returned a constant, we don't need to trace; we have + # determined that the result is constant (no matter if the inputs + # were symbolic) and it is no longer necessary to trace the + # computation. This could occur if func triggered some guards. + if isinstance(out, py_sym_types): + p_out_thunk = thunkify( + self.tracer, self._compute_proxy, func=func, args=args, out=out + ) + set_proxy_slot(out, self.tracer, p_out_thunk) + + return out + + +class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer): + script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] + symnode_tracker: MutableMapping[PySymType, _PySymProxyType] + tensor_tracker: MutableMapping[Tensor, _ProxyTensor] + sympy_expr_tracker: Dict[sympy.Symbol, object] + torch_fn_metadata: Optional[OpOverload] + torch_fn_counts: Dict[OpOverload, int] + enable_thunkify: bool = False + + def __init__(self, graph: fx.graph.Graph) -> None: + super().__init__(graph) + self.symnode_tracker = weakref.WeakKeyDictionary() + self.tensor_tracker = WeakTensorKeyDictionary() + self.sympy_expr_tracker = {} + self.script_object_tracker = WeakIdKeyDictionary( + dict=None, ref_type=_WeakHashRef + ) + # Stores the torch function that was called during tracing + self.torch_fn_metadata = None + # Stores the counts for every torch function called. This is to help + # distinguish between different calls to the same torch function. + self.torch_fn_counts = {} + + +# TODO: I'm not sure what the point of this class is; you can just +# make_fx through a regular Interpreter +class DecompositionInterpreter(fx.Interpreter): + def __init__( + self, + module: fx.GraphModule, + new_graph: fx.Graph, + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, + **kwargs: object, + ) -> None: + super().__init__(module, **kwargs) # type: ignore[arg-type] + self.new_graph = new_graph + self.tracer = _GraphAppendingTracerEx(self.new_graph) + # Blegh + self.decomposition_table = decomposition_table or {} + self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") + + def placeholder( + self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] + ) -> object: + out = super().placeholder(target, args, kwargs) # type: ignore[arg-type] + proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer) + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + # TODO handle case where the first character of target is '*' + return out + + def get_attr( + self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] + ) -> object: + out = super().get_attr(target, args, kwargs) # type: ignore[arg-type] + proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer) + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + return out + + # call_function, call_method, call_module get traced automatically by the outer mode. + + def output( + self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] + ) -> object: + out = super().output(target, args, kwargs) # type: ignore[arg-type] + + def get_proxy_node(x: _ProxyTensor) -> fx.node.Node: + return x.proxy.node + + def unwrap(e: Tensor) -> Union[Tensor, fx.Node]: + return get_proxy_slot(e, self.tracer, e, get_proxy_node) + + self.new_graph.output(pytree.tree_map(unwrap, out)) + return out + + def run(self, *args: object, **kwargs: object) -> object: + # Should enter the mode at least once for being able to restore it later + # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025 + with decompose(self.decomposition_table), self.mode: + return super().run(*args, **kwargs) # type: ignore[arg-type] + + +def wrapper_and_args_for_make_fx( + func: Callable[..., R], args: Tuple[object, ...], kwargs: Dict[str, object] +) -> Tuple[Callable[[List[object]], R], List[object]]: + # make_fx doesn't support kwargs, so we need to do this flattening + # and then unflatten the args before calling func + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + def wrapped(flat_args: List[object]) -> R: + fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec) + return func(*fn_args, **fn_kwargs) + + return wrapped, flat_args + + +@contextmanager +def disable_autocast_cache() -> Generator[None, None, None]: + old_value = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + try: + yield + finally: + torch.set_autocast_cache_enabled(old_value) + + +class _ModuleNotInstalledAsSubmoduleError(NameError): + pass + + +# Base class for inline _ModuleStackTracer.__init__.AttrProxy +class _AttrProxy: + def reset_proxy_mapping(self, base: Module, path: str) -> None: + pass + + +class _ModuleStackTracer(PythonKeyTracer): + r"""Customized version of PythonKeyTracer that retains module stack + information in node.meta["nn_module_stack"]. + + FX symbolic trace actually does this already, but it relies on `self.root` + being the actual module being traced. Since make_fx traces a lambda of our + creation, things don't work properly. + + So for this version we hold onto a reference to the original module + (scope_root) and use that to match the path. Also when we see, + A + / \ + B C + \ / + D + we want to record the path as A.B.D by recording only one path. + See Note [Preserving the nn module stack metadata during export non-strict mode] # noqa: W605 + """ + + def __init__(self, scope_root: GraphModule) -> None: + super().__init__() + self.scope_root = scope_root + self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary() + self.attr_proxy_map: WeakKeyDictionary[Module, _AttrProxy] = WeakKeyDictionary() + self.proxy_modules: WeakKeyDictionary[_AttrProxy, Module] = WeakKeyDictionary() + self.counter = 0 + + self.module_id_cache = defaultdict(list) + for name, mod in self.scope_root.named_modules(remove_duplicate=False): + self.module_id_cache[id(mod)].append(name) + + # Build a wrapper around _AttrProxy to provide the tracer. We can't + # store it on _AttrProxy itself beceause we mimic the underlying class + # (including its attributes). + tracer = self + + class AttrProxy(_AttrProxy): + def __init__(self, base: Module, path: str) -> None: + # Class is modified to be a subclass of torch.nn.Module + # Warning: We blow away our own attributes here to mimic the base class + # - so don't expect `self.x` to do anything useful. + self.__class__ = type( + base.__class__.__name__, + (self.__class__, base.__class__), + {}, + ) + self.__dict__ = base.__dict__ + self.__class__.__module__ = base.__class__.__module__ + self.__class__.__qualname__ = base.__class__.__qualname__ + self.reset_proxy_mapping(base, path) + + def reset_proxy_mapping(self, base: Module, path: str) -> None: + tracer.proxy_paths[self] = path + tracer.proxy_modules[self] = base + + def __getattr__(self, name: str) -> AttrProxy: + assert isinstance(self, Module) + # Calling into torch.nn.Module.__getattr__ with super(), + # That __getattr__ is patched to be module_getattr_wrapper in _symbolic_trace.py. + # which then calls into _ModuleStackTracer.getattr + attr_val = super().__getattr__(name) # type: ignore[misc] + if isinstance(attr_val, AttrProxy): + attr_val = tracer.proxy_modules[attr_val] + elif not isinstance(attr_val, Module): + return attr_val + if attr_val not in tracer.attr_proxy_map: + tracer.attr_proxy_map[attr_val] = AttrProxy( + attr_val, tracer.proxy_paths[self] + "." + name + ) + else: + # NOTE [caching AttrProxy]. Caching ensures a 1-1 mapping between AttrProxy and the actual attr_val. + # 1. We reset the proxy_mapping to solve the diamond shape reference problem: we want to record the + # path as A.B.D instead of A.C.D (the purpose of _ModuleStackTracer). + # 2. Instead of creating a new AttrProxy, we just reset the proxy_mapping of existing one. This is to avoid + # dynamo creating multiple guards for the same attr_val but different AttrProxy when exporting + # a model that calls torch.compile (e.g when a model uses torch.cond.) + tracer.attr_proxy_map[attr_val].reset_proxy_mapping( + attr_val, tracer.proxy_paths[self] + "." + name + ) + return tracer.attr_proxy_map[attr_val] + + def get_base(self) -> Module: + return tracer.proxy_modules[self] + + @property + def _modules(self) -> Dict[str, AttrProxy]: + assert "_modules" in self.__dict__ + submodules = self.__dict__["_modules"] + assert isinstance(submodules, dict) + return { + key: AttrProxy(value, tracer.proxy_paths[self] + "." + str(key)) + for key, value in submodules.items() + } + + self.proxy_type = AttrProxy + + def path_of_module(self, mod: Module) -> str: + """ + Use tracked access path during tracing instead of the default BFS behavior. + Still use all the possible module paths to verify the result. + """ + if mod is self.scope_root: + return "" + + if isinstance(mod, _AttrProxy): + return self.proxy_paths[mod] + + try: + return Tracer.path_of_module(self, mod) + except NameError as e: + raise _ModuleNotInstalledAsSubmoduleError from e + + def getattr( + self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy] + ) -> object: + if not isinstance(attr_val, Module) or isinstance(attr_val, fx.GraphModule): + return super().getattr(attr, attr_val, parameter_proxy_cache) + if isinstance(attr_val, _AttrProxy): + return attr_val + + # See NOTE [caching AttrProxy]. + if attr_val not in self.attr_proxy_map: + self.attr_proxy_map[attr_val] = self.proxy_type(attr_val, attr) + else: + self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr) + return self.attr_proxy_map[attr_val] + + def trace( # type: ignore[override] + self, root: Union[Module, Callable], concrete_args: Optional[Dict[str, object]] + ) -> fx.Graph: + res = super().trace(root, concrete_args) + + # Since we are making _AttrProxy mimic the original + # submodule, when someone registers a module directly + # to the tracer while tracing, the proxy object gets registered + # first. So we need to replace the proxy modules with the real ones + # This can happen during HOO tracing + proxy_module_names_to_be_replaced: List[Tuple[str, _AttrProxy]] = [] + for name, module in self.root.named_modules(): + if module in self.proxy_modules: + proxy_module_names_to_be_replaced.append((name, module)) + + def _delete_proxy_attr(obj: Module, target: str) -> bool: + # Copied from fx/graph_module.py + # Customized it for proxy type + atoms = target.split(".") + path, target_submod = atoms[:-1], atoms[-1] + assert isinstance(obj, Module) + mod = obj + + # Get the parent module + for item in path: + if not hasattr(mod, item): + return False + + mod = getattr(mod, item) + + if not isinstance(mod, (_AttrProxy, Module)): + return False + + if not hasattr(mod, target_submod): + return False + + # At least the leaf module should be proxy type. + if not isinstance(getattr(mod, target_submod), _AttrProxy): + return False + + delattr(mod, target_submod) + return True + + for proxy_module_name, proxy_module in proxy_module_names_to_be_replaced: + _delete_proxy_attr(self.root, proxy_module_name) + actual_module = self.proxy_modules[proxy_module] + _assign_attr(actual_module, self.root, proxy_module_name) + + return res + + def call_module( + self, + m: Module, + forward: Callable, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> None: + """PythonKeyTracer overrides call_module to avoid the scope handling, + but we actually want it. + """ + from torch._dynamo import OptimizedModule + + # FIXME (tmanlaibaatar) + # When we call torch.compile inside HOO, we will end up + # invoking a module that is not registered on the root. For + # now, we just inline them. But once we start supporting + # mark_strict in export, we do need to properly handle this. + # Right now, it doesn't matter because current non-strict + # use cases don't need to work with HOO. + if isinstance(m, (OptimizedModule, GraphModule)): + return forward(*args, **kwargs) + + try: + return Tracer.call_module(self, m, forward, args, kwargs) + except _ModuleNotInstalledAsSubmoduleError as e: + warnings.warn( + f"Unable to find the path of the module {m}. " + "This might be because the module was not properly registered " + "as a submodule, which is not good practice. We will trace " + "through the module without recording stack information." + ) + return forward(*args, **kwargs) + + def is_leaf_module(self, m: Module, module_qualified_name: str) -> bool: + return False + + def create_node(self, *args: object, **kwargs: object) -> fx.node.Node: + """ + Create node and add on metadata. + Add nn_module_stack here instead of TracerBase, + since calls to make_fx() might not want to record module stack metadata. + Add torch_fn by looking at torch_fn_metadata and torch_fn_counts. + Add stack_trace by filtering out forward() stack frames. + """ + node = super().create_node(*args, **kwargs) # type: ignore[arg-type] + + # nn_module_stack + if node.op not in ["placeholder", "output"]: + if "nn_module_stack" not in node.meta: + node.meta["nn_module_stack"] = self.module_stack + # convert nn_module_stack from Dict[key, (FQN, class)] -> Dict[str, Tuple[str, str]] + for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items(): + if isinstance(mod_cls, type): + node.meta["nn_module_stack"][key] = ( + fqn, + mod_cls.__module__ + "." + mod_cls.__qualname__, + ) + + # torch_fn + if ( + node.op == "call_function" + and self.torch_fn_metadata is not None + and "torch_fn" not in node.meta + ): + node.meta["torch_fn"] = ( + f"{self.torch_fn_metadata.__name__}_{self.torch_fn_counts[self.torch_fn_metadata]}", + f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}", + ) + + # stack_trace + if "stack_trace" not in node.meta and node.op not in ["placeholder", "output"]: + user_frame_summary = CapturedTraceback.extract().summary() + if user_frame_summary: + # we retain frames from forward() calls, or ops + # located in torch/__init__.py (e.g. sym_int, sym_constrain_range, vmap) + stack_trace = [ + frame + for frame in user_frame_summary + if ( + frame.name == "forward" + or frame.filename.endswith("torch/__init__.py") + ) + ] + # filter out forward() frames from fx/_symbolic_trace.py, export/_trace.py + # this is hardcoded, but leads to a much cleaner stack trace + stack_trace = [ + frame + for frame in stack_trace + if not ( + frame.filename.endswith("fx/_symbolic_trace.py") + or frame.filename.endswith("export/_trace.py") + ) + ] + if ( + stack_trace + ): # empty list for strict mode, dynamo should handle stack_trace + stack_trace = traceback.StackSummary.from_list(stack_trace) + node.meta["stack_trace"] = "".join(stack_trace.format()).strip() + + return node + + +class _MakefxTracer: + def __init__( + self, + decomposition_table: Optional[Mapping[OpOverload, Callable]], + tracing_mode: str, + _allow_non_fake_inputs: bool, + pre_dispatch: bool, + record_module_stack: bool, + _allow_fake_constant: bool, + _error_on_data_dependent_ops: bool, + ) -> None: + # Configurations that are used to initialize the context managers and their states. + # Should not modify them during tracing. + self.decomposition_table: Dict[OpOverload, Callable] = dict( + decomposition_table or {} + ) + self.decomposition_table.setdefault( + torch.ops.aten.sym_numel.default, torch._decomp.decompositions.sym_numel + ) + self.tracing_mode: str = tracing_mode + self._allow_non_fake_inputs: bool = _allow_non_fake_inputs + self.pre_dispatch: bool = pre_dispatch + self.record_module_stack: bool = record_module_stack + self._allow_fake_constant: bool = _allow_fake_constant + self._error_on_data_dependent_ops: bool = _error_on_data_dependent_ops + + # All context managers and their states should be initialized before tracing based on the inputs + # and configurations. After tracing, their states should be cleaned except for shape_env. + # Remember to specify how to intialize it from user inputs and from parent tracer whenever + # adding new modes in _MakefxTracer. + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext() + self.proxy_function_mode: Union[ + nullcontext, PreDispatchTorchFunctionMode + ] = nullcontext() + self.fx_tracer: Optional[PythonKeyTracer] = None + self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext() + self.torch_fn_metadata_mode: Union[ + nullcontext, TorchFunctionMetadataMode + ] = nullcontext() + + def _checkpoint_modes(self) -> List[Any]: + return [ + self.fake_tensor_mode, + self.proxy_mode, + self.proxy_function_mode, + self.fx_tracer, + self.python_dispatcher_mode, + self.torch_fn_metadata_mode, + ] + + def _restore_modes( + self, + prev_fake_tensor_mode: Optional[FakeTensorMode], + prev_proxy_mode: Union[nullcontext, ProxyTorchDispatchMode], + prev_proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode], + prev_fx_tracer: Optional[PythonKeyTracer], + prev_python_dispatcher_mode: Union[nullcontext, Any], + prev_torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode], + ) -> None: + self.fake_tensor_mode = prev_fake_tensor_mode + self.proxy_mode = prev_proxy_mode + self.proxy_function_mode = prev_proxy_function_mode + self.fx_tracer = prev_fx_tracer + self.python_dispatcher_mode = prev_python_dispatcher_mode + self.torch_fn_metadata_mode = prev_torch_fn_metadata_mode + + @contextmanager + def _init_modes_from_inputs( + self, f: Callable, args: Tuple[object, ...] + ) -> Generator[None, None, None]: + prev_modes = self._checkpoint_modes() + try: + # Avoid importing sympy at a module level + from .symbolic_shapes import ShapeEnv + + if hasattr(f, "_orig_mod") and self.record_module_stack: + scope_root = f._orig_mod + self.fx_tracer = _ModuleStackTracer(scope_root) + else: + self.fx_tracer = PythonKeyTracer() + + if self.tracing_mode == "fake": + import torch._dynamo + + fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args) + if fake_tensor_mode is None: + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=True, + allow_non_fake_inputs=self._allow_non_fake_inputs, + shape_env=ShapeEnv(), + static_shapes=True, + ) + self.fake_tensor_mode = fake_tensor_mode + elif self.tracing_mode == "symbolic": + import torch._dynamo + + fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args) + if fake_tensor_mode is None: + shape_env = ShapeEnv() + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=False, + allow_non_fake_inputs=self._allow_non_fake_inputs, + shape_env=shape_env, + ) + assert ( + fake_tensor_mode.shape_env is not None + ), "shape_env should be set if tracing with 'symbolic'" + self.fake_tensor_mode = fake_tensor_mode + else: + if not self.tracing_mode == "real": + raise AssertionError( + f"Unexpected tracing type: {self.tracing_mode}" + ) + + self._construct_modes_with_fx_tracer(self.fx_tracer) + yield + finally: + self._restore_modes(*prev_modes) + + def _construct_modes_with_fx_tracer(self, fx_tracer: _ProxyTracer) -> None: + self.proxy_mode = ProxyTorchDispatchMode( + fx_tracer, + self.tracing_mode, + pre_dispatch=self.pre_dispatch, + _allow_fake_constant=self._allow_fake_constant, + _error_on_data_dependent_ops=self._error_on_data_dependent_ops, + ) + + if self.pre_dispatch: + self.proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer) + + # pre-autograd tracing uses per-dispatch-key modes, + # which requires the python dispatcher + if self.tracing_mode == "symbolic" or self.pre_dispatch: + self.python_dispatcher_mode = enable_python_dispatcher() + + self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer) + + @contextmanager + def _init_modes_from_parent( + self, parent_tracer: _MakefxTracer + ) -> Generator[None, None, None]: + # By default, subtracer creates new modes based on parent tracer's config. + # However, there are cases where we want to share the same modes with parent tracer + # For example, fake_tensor_mode, we want the example value's fake_mode of parent graph and subgraphs to be the same. + prev_modes = self._checkpoint_modes() + try: + self.fake_tensor_mode = parent_tracer.fake_tensor_mode + + def _create_sub_fx_tracer(parent_tracer: _ProxyTracer) -> PythonKeyTracer: + if type(parent_tracer) == PythonKeyTracer: + return PythonKeyTracer() + elif type(parent_tracer) == _ModuleStackTracer: + return _ModuleStackTracer(parent_tracer.scope_root) + else: + raise RuntimeError( + f"Unexpected tracer type: {type(parent_tracer)}." + ) + + assert parent_tracer.fx_tracer is not None + self.fx_tracer = _create_sub_fx_tracer(parent_tracer.fx_tracer) + self._construct_modes_with_fx_tracer(self.fx_tracer) + yield + finally: + self._restore_modes(*prev_modes) + + def _trace_inner(self, f: Callable, *args: object) -> GraphModule: + phs = pytree.tree_map(lambda _: torch.fx._symbolic_trace.PH, args) + + def _wrap_fake(args: T) -> T: + arg_count = 0 + + def inner_wrap_fake(x: object) -> object: + nonlocal arg_count + # TODO: it would be nice to line these up with the names + # FX will choose for the placeholders, but we don't + # actually know what the names will be at this point yet + # NB: the Source here is actually meaningless + from torch._dynamo.source import ConstantSource + + assert self.fake_tensor_mode is not None + source = ConstantSource(f"input{arg_count}") + if isinstance(x, Tensor): + arg_count += 1 + return self.fake_tensor_mode.from_tensor(x, source=source) + # NB: don't match on bools + elif type(x) is int and self.tracing_mode == "symbolic": + assert ( + self.fake_tensor_mode.shape_env is not None + ), "shape_env should be set if tracing with 'symbolic'" + return self.fake_tensor_mode.shape_env.create_symintnode( + self.fake_tensor_mode.shape_env.create_symbol( + x, source, positive=None + ), + hint=x, + source=source, + ) + elif isinstance(x, torch.ScriptObject): + return torch._library.fake_class_registry.maybe_to_fake_obj( + self.fake_tensor_mode, x + ) + + assert not isinstance( + x, FakeScriptObject + ), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." + return x + + wrap_fn_map = { + "real": lambda x: x, + "fake": inner_wrap_fake, + "symbolic": inner_wrap_fake, + } + return pytree.tree_map(wrap_fn_map[self.tracing_mode], args) + + def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]: + if ( + not hasattr(inspect.unwrap(f), "__code__") + or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS + ): + # FX doesn't support varargs, so we gotta fake up a wrapper + # TODO: Would be nice to fix this at the source... + return fake_signature(f, len(phs)) + return f + + args = _wrap_fake(args) + func = _wrap_func(f, phs) + # We disable the autocast cache as the autocast cache causes type conversions on parameters to + # check a cache, which introduces untracked tensors into the graph + # + # We also disable tracing by any other tensor proxy-based tracers except the current. The + # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is + # thus irrelevant to any external functional trace. + proxy_mode: ProxyTorchDispatchMode = typing.cast( + ProxyTorchDispatchMode, self.proxy_mode + ) + with ExitStack() as stack: + stack.enter_context(decompose(self.decomposition_table)) + if self.fake_tensor_mode: + stack.enter_context(self.fake_tensor_mode) + stack.enter_context(self.python_dispatcher_mode) + stack.enter_context(self.proxy_function_mode) + stack.enter_context(self.torch_fn_metadata_mode) + stack.enter_context(proxy_mode) + stack.enter_context(disable_autocast_cache()) + stack.enter_context(_set_make_fx_tracer(self)) + + assert self.fx_tracer is not None + t = dispatch_trace( + wrap_key(func, args, self.fx_tracer, self.pre_dispatch), + tracer=self.fx_tracer, + concrete_args=tuple(phs), + ) + + # TODO: kind of a bad way to do it, should maybe figure out a better way + if self.tracing_mode == "symbolic": + assert self.fake_tensor_mode is not None + t.shape_env = self.fake_tensor_mode.shape_env + return t + + def trace(self, f: Callable, *args: object) -> fx.GraphModule: + with self._init_modes_from_inputs(f, args): + return self._trace_inner(f, *args) + + def trace_subgraph(self, f: Callable, *args: object) -> GraphModule: + # Create a new tracer based on parent's config + sub_tracer = _MakefxTracer( + self.decomposition_table, + "real", + self._allow_non_fake_inputs, + self.pre_dispatch, + self.record_module_stack, + self._allow_fake_constant, + self._error_on_data_dependent_ops, + ) + with sub_tracer._init_modes_from_parent(self): + return sub_tracer._trace_inner(f, *args) + + +_CURRENT_MAKE_FX_TRACER: Optional[_MakefxTracer] = None + + +@contextmanager +def _set_make_fx_tracer(tracer: _MakefxTracer) -> Generator[None, None, None]: + global _CURRENT_MAKE_FX_TRACER + prev_tracer = _CURRENT_MAKE_FX_TRACER + try: + _CURRENT_MAKE_FX_TRACER = tracer + yield + finally: + _CURRENT_MAKE_FX_TRACER = prev_tracer + + +def make_fx( + f: Callable, + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, + tracing_mode: str = "real", + _allow_non_fake_inputs: bool = False, + *, + pre_dispatch: bool = False, + record_module_stack: bool = False, + _allow_fake_constant: bool = False, + _error_on_data_dependent_ops: bool = True, +) -> Callable[..., GraphModule]: + """ + Given a function f, return a new function which when executed with valid + arguments to f, returns an FX GraphModule representing the set of operations that + were executed during the course of execution. + """ + + assert tracing_mode in ["real", "fake", "symbolic"] + + make_fx_tracer = _MakefxTracer( + decomposition_table, + tracing_mode, + _allow_non_fake_inputs, + pre_dispatch, + record_module_stack, + _allow_fake_constant, + _error_on_data_dependent_ops, + ) + + @functools.wraps(f) + def wrapped(*args: object) -> GraphModule: + return make_fx_tracer.trace(f, *args) + + return wrapped + + +def get_torch_dispatch_modes() -> List[TorchDispatchMode]: + return torch.utils._python_dispatch._get_current_dispatch_mode_stack() + + +# TODO: this is a legacy name, there is only ever one proxy mode as it's an +# infra mode +def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]: + return get_proxy_mode() + + +def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]: + """ + Current the currently active proxy tracing mode, or None if + we are not currently tracing. This includes pre-dispatch proxy + tracing. + """ + pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch( + torch._C._TorchDispatchModeKey.PROXY + ) + mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) + assert ( + pre_dispatch_mode is None or mode is None + ), f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}" + return pre_dispatch_mode or mode + + +def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs) -> R: + """ + Call into the currently active proxy tracing mode to do a + SymInt/SymFloat/SymBool dispatch trace on a function that operates on + these arguments. + """ + mode = get_proxy_mode() + assert mode + # Have to do it manually, because we're not doing the normal torch + # dispatch machinery which disables it for us + with disable_proxy_modes_tracing(): + # TODO: properly compute types + types: List[Type] = [] + return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value] + + +@contextmanager +def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, None]: + return _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY) + + +def maybe_handle_decomp( + proxy_mode: ProxyTorchDispatchMode, + op: OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> object: + if op in CURRENT_DECOMPOSITION_TABLE: + with proxy_mode: + proxy_mode.decomp_layers += 1 + out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs) + proxy_mode.decomp_layers -= 1 + return out + + return NotImplemented + + +def get_isolated_graphmodule( + func: Callable, + args: Tuple[object, ...], + kwargs: Dict[str, object], + tracing_mode: str = "real", + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, +) -> GraphModule: + """A helper function used to get the GraphModule for the given func. + + It's expected to be used in the ProxyTensor tracing context. + It detaches the args and kwargs from the current tracer so that the trace of + the current graph module can be created without any side-effects. + """ + wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs) + + with disable_proxy_modes_tracing(): + gm = make_fx( + wrapped, decomposition_table=decomposition_table, tracing_mode=tracing_mode + )(all_args) + return gm + + +def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None: + """A helper function for setting up unbacked_bindings on the destination FX graph.""" + from .symbolic_shapes import compute_unbacked_bindings + + # Can't use detect_fake_mode here, + # + # python test/distributed/_tensor/test_dtensor_compile.py -k + # test_tp_compile_fullgraph_is_seq_parallel_False + # + # will fail. Very strange, it probably isn't right for them to be using + # two fake modes there... + fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) + if fake_mode and fake_mode.shape_env: + if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out): + assert isinstance(out_proxy, Proxy), out_proxy + out_proxy.node.meta["unbacked_bindings"] = symbol_to_path diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6410be41c40308e4947bb97311eb64d3c41aa3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/recording.py @@ -0,0 +1,512 @@ +# mypy: allow-untyped-defs +import functools +import inspect +import itertools +import logging +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils._pytree as pytree + + +log = logging.getLogger(__name__) +trace_shape_events_log = torch._logging.getArtifactLogger( + __name__, "trace_shape_events" +) + + +__all__ = [ + "ShapeEnvEvent", + "record_shapeenv_event", + "replay_shape_env_events", + "FakeTensorMeta", + "shape_env_check_state_equal", + "NotEqualError", +] + +# [Note: Recording ShapeEnv Events] +# ================================= +# +# What is a ShapeEnv event? +# ------------------------- +# We consider a ShapeEnv event every function call (ShapeEnv method or +# independent function) that modifies the state of the ShapeEnv instance. +# Such calls are recorded alongside their positional and keyword arguments, +# so that it may be replayed over a different ShapeEnv instance. +# +# See [Note: ShapeEnv State Equality] for what is considered the state +# of a ShapeEnv instance. +# +# What is it for? +# --------------- +# ShapeEnv events recording is used for reconstructing the ShapeEnv in an +# arbitrary state in time. +# +# Being able to arbitrarily replay events like so is useful, mainly for +# translation validation bisection. i.e. if a ValidationException has been +# raised, find the earliest point in time where the translation validation +# fails. +# +# Besides that, it also allows us to inspect the given instance and, +# for example, check the guards that would actually be issued at that point. +# +# What kind of arguments can be stored in an event? +# ------------------------------------------------- +# There's no specific rule for what cannot be used as an argument. +# That said, pay special attention to the following cases: +# +# 1. Tensor inputs: there are some tests that check whether the inputs +# were garbage collected after execution. These will fail if there's +# an event that is holding a reference to those inputs. +# +# 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that +# will be automatically replaced by the new given ShapeEnv instance. +# +# 3. SymTypes arguments: they also hold references to ShapeEnv. So, +# whenever we see them, we create a new instance, replacing the +# ShapeEnv reference. +# +# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic +# shapes. That argument must be replaced when replaying the event at +# ShapeEnvEvent.run, since it has to reference a node from the given +# instance, and not from the recorded instance. + + +# Event class for reconstructing ShapeEnv at arbitrary time. +# +# Represents a method call that mutates ShapeEnv in a way that affects the +# issued guards, when ShapeEnv.produce_guards is called. +@dataclass +class ShapeEnvEvent: + # ShapeEnv method. + f: Callable + + # Arguments and keyword arguments called with. + args: Optional[List[Any]] = None + kwargs: Optional[Dict[str, Any]] = None + + # List of tracked_fakes at the time the method was called. + tracked_fakes: Optional[List[Any]] = None + + # Name of the captured event. + # Used for special handling of particular methods. + name: Optional[str] = None + + # Replay itself, but using shape_env as self. + def run(self, shape_env=None) -> Any: + from torch.fx.experimental.symbolic_shapes import ( + is_symbolic, + ShapeEnv, + SymTypes, + ) + + # Special handling for the constructor event. + if self.f is ShapeEnv: + assert shape_env is None and self.args is None and self.kwargs is not None + return ShapeEnv(**self.kwargs) + + assert shape_env is not None + args = list(self.args or []) + kwargs = dict(self.kwargs or {}) + + # Replace any argument of type ShapeEnv by the given one. + args, kwargs = pytree.tree_map_only( + ShapeEnv, lambda _: shape_env, (args, kwargs) + ) + + # Replace any argument of type SymTypes by a new instance, + # replacing its ShapeEnv reference. + args, kwargs = pytree.tree_map_only( + lambda x: isinstance(x, SymTypes) and is_symbolic(x), + lambda a: type(a)(a.node.with_shape_env(shape_env)), + (args, kwargs), + ) + + # Converts FX nodes using the mapping argument. + def maybe_convert_node(x: Any) -> Any: + if not isinstance(x, torch.fx.Node): + # Don't do anything to x if it's not an FX node. + return x + + # If, at some point, we created an FX node, it means that translation validation is on. + # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and + # we are tracking node names at shape_env.name_to_node. + assert hasattr(shape_env, "name_to_node") + name_to_node = shape_env.name_to_node # type: ignore[attr-defined] + assert x.name in name_to_node + return name_to_node[x.name] + + # Replaces the value of an specific argument by the result of fn. + def replacearg(index: int, key: str, fn: Callable): + if index < len(args): + args[index] = fn(args[index]) + if key in kwargs: + kwargs[key] = fn(kwargs[key]) + + if self.is_create_fx_call_function(): + # ShapeEnv.create_fx_call_function: + # "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv. + # They must be replaced, since a "call_function" FX node with this tuple as argument + # will be added to the FX graph of the new shape_env. + replacearg( + index=2, + key="args", + fn=lambda args: tuple(maybe_convert_node(a) for a in args), + ) + if self.is_evaluate_expr() or self.is_defer_runtime_assert(): + # ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert: + # "fx_node" parameter is an (optional) FX node that represents the evaluate expression. + # They must be replaced, since it will be part of a "call_function" FX node for + # torch._assert, which will be added to the FX graph of the new shape_env. + replacearg(index=3, key="fx_node", fn=maybe_convert_node) + + # Actually call the method with the converted arguments. + return self.f(*args, **kwargs) + + def __str__(self) -> str: + name = self.name if self.name is not None else self.f.__name__ + return f"event: {name} ({self.args}, {self.kwargs})" + + def is_create_fx_call_function(self) -> bool: + return self.name == "_create_fx_call_function" + + def is_evaluate_expr(self) -> bool: + return self.name == "evaluate_expr" + + def is_defer_runtime_assert(self) -> bool: + return self.name == "defer_runtime_assert" + + +NEST = 0 + + +# Extracts a ShapeEnv instance inside args and kwargs. +# Specifically, it looks for: +# 1. ShapeEnv arguments +# 2. SymInt, SymFloat, or SymBool arguments +# If we find more than one object of any of the above types, we +# also check that the ShapeEnv instance is the same for all of them. +def _extract_shape_env_and_assert_equal(args, kwargs): + from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes + + def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv: + if old is not None: + assert old is new, "call with different ShapeEnv" + return new + + shape_env = None + for val in itertools.chain(args, kwargs.values()): + if isinstance(val, ShapeEnv): + shape_env = assert_equal(shape_env, val) + if isinstance(val, SymTypes) and is_symbolic(val): + shape_env = assert_equal(shape_env, val.node.shape_env) + + return shape_env + + +# Decorator for recording the given function as a replayable event. +# +# This decorator should be used at every function that mutates the state of +# ShapeEnv in some way that affects the resulting issued guards (i.e. when +# ShapeEnv.produce_guards is called). +# +# save_tracked_fakes: saves a snapshot of the TrackedFake list. +# This is used when calling ShapeEnv.produce_guards at arbitrary points in time. +# +# When to save the list of TrackedFake? +# ===================================== +# We should save the list of TrackedFake whenever the translation validation +# bisection may actually stop and call the produce_guards method at the moment +# right after the recorded function was played. In other words, since the +# bisection bisects through torch._assert calls, we should save in all methods +# that adds a torch._assert call to the symbolic shapes FX graph. +# +# At the moment, there are 2 methods that save the list: +# - ShapeEnv.evaluate_expr +# - ShapeEnv.defer_runtime_assert +def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable: + def decorator(fn: Callable) -> Callable: + assert callable(fn) + args = inspect.getfullargspec(fn).args + assert args and args[0] == "self", ( + "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your " + "code so that it calls into a method on ShapeEnv" + ) + name = fn.__name__ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + assert isinstance(args[0], ShapeEnv) + + global NEST + + trace_shape_events_log.debug( + "%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs + ) + NEST += 1 + + def retlog(r): + trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r) + return r + + try: + if args[0].is_recording: # type: ignore[has-type] + # If ShapeEnv is already recording an event, call the wrapped + # function directly. + # + # NB: here, we skip the check of whether all ShapeEnv instances + # are equal, in favor of a faster dispatch. + return retlog(fn(*args, **kwargs)) + + # Retrieve an instance of ShapeEnv. + # Assumption: the collection of args and kwargs may not reference + # different ShapeEnv instances. + self = _extract_shape_env_and_assert_equal(args, kwargs) + + # If we are calling this function without any ShapeEnv instance + # alive in its arguments, we don't record and call the original. + if self is None: + return retlog(fn(*args, **kwargs)) + + # Otherwise, start recording and call the function. + with self._recording(): + # Take a snapshot of the current tracked_fakes. + tracked_fakes = ( + self._snapshot_tracked_fakes() if save_tracked_fakes else None + ) + # Record the event for 'fn'. + event = ShapeEnvEvent( + fn, list(args), kwargs, tracked_fakes, name=fn.__name__ + ) + # Play the event on this ShapeEnv. + # NB: It's important to put the event first, because running + # the event can trigger internal events that must be ordered + # after this event. However, if an exception happens, we do + # NOT want to have the event in the list, so pop it off from + # the record if an error happened + self.events.append(event) + try: + return retlog(event.run(self)) + except Exception: + self.events.pop() + raise + + except Exception: + log.error( # noqa: G201 + "failed while running %s(*%s, **%s)", + name, + args[1:], + kwargs, + exc_info=log.isEnabledFor(logging.INFO), + ) + raise + + finally: + NEST -= 1 + + return wrapper + + return decorator + + +# Replays the ShapeEnvEvents list. +# It assumes the first event is the constructor call. +# +# fn: transforms an old FX node into one corresponding to the newly created ShapeEnv. +def replay_shape_env_events(events): + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + constructor_event = events[0] + assert constructor_event.f == ShapeEnv + + # Constructs the new ShapeEnv. + shape_env = constructor_event.run() + + for event in events[1:]: + try: + # Actually replays each event. + # We need to call create_mapping_fn every time, since the node list might + # change after each event is replayed. + event.run(shape_env) + except Exception as e: + log.error("failed when running event: %s", event) + raise + + return shape_env + + +# FakeTensor metadata. +# This is to be used in place of FakeTensor placeholders when calling +# ShapeEnv.produce_guards. +@dataclass +class FakeTensorMeta: + tensor_size: Tuple[Union[int, torch.SymInt], ...] + tensor_stride: Tuple[Union[int, torch.SymInt], ...] + tensor_storage_offset: Union[int, torch.SymInt] + is_nested: bool + + def size(self) -> Tuple[Union[int, torch.SymInt], ...]: + return self.tensor_size + + def stride(self) -> Tuple[Union[int, torch.SymInt], ...]: + return self.tensor_stride + + def storage_offset(self) -> Union[int, torch.SymInt]: + return self.tensor_storage_offset + + def dim(self) -> int: + return len(self.tensor_size) + + @staticmethod + def from_fake(fake) -> "FakeTensorMeta": + return FakeTensorMeta( + fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested + ) + + +# [Note: ShapeEnv State Equality] +# =============================== +# +# What is considered ShapeEnv state? +# ---------------------------------- +# We consider to be the state of a ShapeEnv instance everything that +# is not in the inline tuple inside remove_nonstate_variables function. +# That is: the fields within ShapeEnv that modify the flow of execution +# of the program. +# +# So, for example: the replacements field might influence on how an +# expression is simplified. That, in turn, may result in a guard being +# statically known (i.e. not added). +# +# On the other hand, var_to_stack serves only changes what is printed +# in the screen, i.e. used only for debugging purposes. Therefore, we +# should not consider it when comparing states. +# +# What to do on NotEqualError? +# ---------------------------- +# Here are a few possible causes for getting a NotEqualError raised: +# +# 1. New field that does not belong in the ShapeEnv state. +# For example: log field of type ShapeEnvLoggerAdapter. Different +# ShapeEnv instances will always have different ShapeEnvLoggerAdapter +# instances, i.e. equality comparison would fail. +# Solution: add it to the inlined tuple inside remove_nonstate_variables +# function inside check_equal method. +# +# 2. New field that is not directly comparable across instances. +# For example: guards field of type List[ShapeGuard]. More specifically, +# the ShapeGuard type holds an expression and a stack information +# for debugging purposes. When replaying the even on a new ShapeEnv +# instance, the stack would be different, which would trigger this error. +# Solution: add a special case to the map_value function inside +# check_equal function. +# +# 3. Mutation of ShapeEnv on some not recorded function. +# If a mutation of the state of ShapeEnv happens inside a function +# that is not recorded (or that no caller in the stack is recorded), +# then, the replayed ShapeEnv won't catch that. +# Solution: decorate the function with record_shape_env_event. + + +# Checks whether the state of two ShapeEnv are equal w.r.t. the guards +# returned by ShapeEnv.produce_guards. +def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value): + # Collect and remove variables that don't necessarily represent the state + # of a ShapeEnv. Note: we copy the dictionary so that we don't modify the + # instance itself. + env1_vars = vars(env1).copy() + env2_vars = vars(env2).copy() + + for v in non_state_variable_names: + if v in env1_vars: + env1_vars.pop(v) + if v in env2_vars: + env2_vars.pop(v) + + # Function for transforming the mismatched values into string. + # Needed, since dict and set entries order might not be the same every time. + def value_to_str(value: Any) -> str: + if isinstance(value, dict): + return ( + "{" + + ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str)) + + "}" + ) + if isinstance(value, set): + return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}" + return str(value) + + # Compares env1_vars with env2_vars. + # Here, we allow the value of each field to be mapped, so that we appropriately + # compare the two values. + def compare_vars( + map_value: Callable[[str, Any], Any] + ) -> List[Tuple[str, str, str]]: + env1_set, env2_set = set(env1_vars), set(env2_vars) + + # First, compare the set of keys in each vars dictionary. + if env1_set != env2_set: + raise NotEqualError( + "field set mismatch:", + [ + ( + "found unique fields:", + str(sorted(env1_set - env2_set)), + str(sorted(env2_set - env1_set)), + ), + ], + ) + + # Then, sort the keys, and compare the mapped values of each key. + sorted_keys = list(env1_set) + sorted_keys.sort() + + mapped_dict = [ + (k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k])) + for k in sorted_keys + ] + + # Return a list of tuples representing the fields that did not match + # alongside their respective mapped values. + return [ + (f"{k}: values don't match.", value_to_str(val1), value_to_str(val2)) + for k, val1, val2 in mapped_dict + if val1 != val2 + ] + + # Accumulate the mismatching fields. + errors = compare_vars(map_value) + + if len(errors) > 0: + raise NotEqualError("field values don't match:", errors) + + +class NotEqualError(Exception): + def __init__( + self, + msg: str, + mismatched: List[Tuple[str, str, str]], + ) -> None: + details = "\n".join( + [ + "\n".join( + [ + f"==> {inner_msg}", + f" > Left: {str1}", + f" > Right: {str2}", + ] + ) + for inner_msg, str1, str2 in mismatched + ] + ) + + super().__init__( + f"""\ +ShapeEnv not equal: {msg} + +{details} +""" + ) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/refinement_types.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/refinement_types.py new file mode 100644 index 0000000000000000000000000000000000000000..a33ddf3710a4a77c3fa5e13fd69f0570b856a8a0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/refinement_types.py @@ -0,0 +1,17 @@ +# mypy: allow-untyped-defs +class Equality: + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + def __str__(self): + return f'{self.lhs} = {self.rhs}' + + def __repr__(self): + return f'{self.lhs} = {self.rhs}' + + def __eq__(self, other): + if isinstance(other, Equality): + return self.lhs == other.lhs and self.rhs == other.rhs + else: + return False diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/rewriter.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/rewriter.py new file mode 100644 index 0000000000000000000000000000000000000000..3647ca59153b4fa695f76d1718dfac51ee553609 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/rewriter.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import ast +import inspect +import textwrap +import copy +import functools +from types import FunctionType +from typing import cast, Union, Callable, Dict, Optional, Any +from torch.fx._symbolic_trace import Tracer +from torch.fx.graph import Graph +from torch._sources import normalize_source_lines +import torch + +class AST_Rewriter(ast.NodeTransformer): + """ + Take a FunctionType object representing a `forward` method, then + perform an AST rewrite to swap out nodes that are not symbolically + traceable with a callsite to the FX alternative. + + To support swapping out an AST node, define a new `visit` method on + that node. For more details, see: + https://docs.python.org/3/library/ast.html#ast.NodeTransformer + """ + + # This function checks for new keys added in the globals dict. TorchDynamo + # can insert new keys in the global dict and upset the check. Therefore, put + # a disable here. This function is an optimization pass and not really + # suitable for dynamo tracing anyways. + @torch._dynamo.disable + def rewrite(self, fn: FunctionType): + + # Normalize the source lines + sourcelines, _ = inspect.getsourcelines(fn) + sourcelines = normalize_source_lines(sourcelines) + source = ''.join(sourcelines) + normalized_str = textwrap.dedent(source) + + # Rewrite the original AST + source_ast = ast.parse(normalized_str) + dest_ast = ast.fix_missing_locations(self.visit(source_ast)) + + # Pull out the compiled function from the newly-created Module + code = compile(dest_ast, "", "exec") + globals_dict = copy.copy(fn.__globals__) + keys_before = set(globals_dict.keys()) + exec(code, globals_dict) + new_keys = list(set(globals_dict.keys()) - keys_before) + assert len(new_keys) == 1 + fn_compiled = globals_dict[new_keys[0]] + + # return the compiled function with the original globals + def change_func_globals(f, globals): + """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" + # __globals__ is a private member of the function class + # so we have to copy the function, f, all of its member, except f.__globals__ + g = FunctionType( + f.__code__, + globals, + name=f.__name__, + argdefs=f.__defaults__, + closure=f.__closure__, + ) + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined] + return g + # Return the correct FunctionType object + return change_func_globals(fn_compiled, globals=fn.__globals__) + + def visit_Assert(self, node): + """ + Swap out the Assert node (Python's `assert`) with a callsite to the + symbolically-traceable torch._assert function + """ + # Create the Call node + n = ast.parse('torch._assert()', mode='eval') + assert isinstance(n, ast.Expression) + call_node = n.body + assert isinstance(call_node, ast.Call) + msg = node.msg if node.msg else ast.Constant(value="", kind=None) + call_node.args = [node.test, msg] + + # Ensure that the new node conforms to the Python AST grammar + expr_wrapper = ast.Expr(value=call_node) + + # Return the new Call node to signify that we want to use it as + # a replacement for the original _assert node + return ast.copy_location(expr_wrapper, node) + + def visit_AnnAssign(self, node): + """ + Swap out Python's AnnAssign with an Assign node where the annotation function is called. + Example: + Original: + y: Tensor_Type(1,2,3, Dyn) = f2(x) + Output: + y = annotate(f2(x),Tensor_Type((1,2,3,Dyn))) + """ + return ast.Assign(targets=[node.target], value=ast.Call( + func=ast.Name(id='annotate', ctx=ast.Load()), + args=[node.value, node.annotation], keywords=[])) + + +class RewritingTracer(Tracer): + def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph: + return super().trace(_rewrite(root), concrete_args) + + +def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]: + if isinstance(fn, torch.nn.Module): + # Rewrite this module's `forward` as well as the `forward`s of + # all of this module's recursive descendents. Return the new, + # rewritten module hierarchy. + def rewrite_module(m : torch.nn.Module): + class RewrittenModule(torch.nn.Module): + def __init__(self, orig): + super().__init__() + for k, v in orig.__dict__.items(): + if isinstance(v, torch.nn.Module): + self.__dict__[k] = copy.copy(rewrite_module(v)) + else: + self.__dict__[k] = copy.copy(v) + RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward)) + return RewrittenModule(m) + return rewrite_module(fn) + else: + # Rewrite this single free function + return AST_Rewriter().rewrite(cast(FunctionType, fn)) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/schema_type_annotation.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/schema_type_annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7ab78706cb913d054fe36be7b9a17c2c8a46c9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/schema_type_annotation.py @@ -0,0 +1,112 @@ +# mypy: allow-untyped-defs +import torch +import torch.fx +import inspect +from typing import Any, Dict, Optional, Tuple +from torch.fx.node import Argument, Target +from torch._jit_internal import boolean_dispatched +from torch.fx.operator_schemas import _torchscript_type_to_python_type + +from torch.fx import Transformer + +class AnnotateTypesWithSchema(Transformer): + """ + Use Python function signatures to annotate types for `Nodes` within an FX graph. + This pulls out Python function signatures for: + + 1. Standard `torch.nn` Module calls + 2. `torch.nn.functional` calls + 3. Attribute fetches via `get_attr` + + Example usage: + + m = torchvision.models.resnet18() + + traced = torch.fx.symbolic_trace(m) + + traced = AnnotateTypesWithSchema(traced).transform() + + """ + def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True, + annotate_modules : bool = True, annotate_get_attrs : bool = True): + super().__init__(module) + self.annotate_functionals = annotate_functionals + self.annotate_modules = annotate_modules + self.annotate_get_attrs = annotate_get_attrs + + def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + python_ret_type = None + if self.annotate_functionals and target.__module__ == 'torch.nn.functional': + target_for_analysis = target + if target in boolean_dispatched: + # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have + # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` + # branches of the dispatch have exactly the same signature. If they do, use the `true` + # branch signature for analysis. Otherwise, leave this un-normalized + assert not isinstance(target, str) + dispatched = boolean_dispatched[target] + if_true, if_false = dispatched['if_true'], dispatched['if_false'] + # TODO: can we emit the union of these? What are the implications on TorchScript + # compilation? + if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation: + return super().call_function(target, args, kwargs) + target_for_analysis = if_true + + python_ret_type = self._extract_python_return_type(target_for_analysis) + + return_proxy = super().call_function(target, args, kwargs) + return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type + return return_proxy + + def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + python_ret_type = None + assert isinstance(target, str) + submod = self.fetch_attr(target) + if self.annotate_modules and hasattr(submod.__class__, '__name__'): + classname = submod.__class__.__name__ + if getattr(torch.nn, classname, None) == submod.__class__: + python_ret_type = self._extract_python_return_type(submod.forward) + return_proxy = super().call_module(target, args, kwargs) + return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type + return return_proxy + + def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + attr_proxy = super().get_attr(target, args, kwargs) + + if self.annotate_get_attrs: + module_itr = self.module + assert isinstance(target, str) + atoms = target.split('.') + for i, atom in enumerate(atoms): + if not hasattr(module_itr, atom): + raise RuntimeError(f'Node referenced nonextent target {".".join(atoms[:i])}!') + module_itr = getattr(module_itr, atom) + + maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr) + if maybe_inferred_ts_type.success(): + python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type()) + attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type + + return attr_proxy + + def _extract_python_return_type(self, target : Target) -> Optional[Any]: + """ + Given a Python call target, try to extract the Python return annotation + if it is available, otherwise return None + + Args: + + target (Callable): Python callable to get return annotation for + + Returns: + + Optional[Any]: Return annotation from the `target`, or None if it was + not available. + """ + assert callable(target) + try: + sig = inspect.signature(target) + except (ValueError, TypeError): + return None + + return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py new file mode 100644 index 0000000000000000000000000000000000000000..268f56e3214fb99e6344af127006a8b7e34c9c00 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/sym_node.py @@ -0,0 +1,1554 @@ +# mypy: allow-untyped-defs +""" +This file does three things: +- Contains the definition of SymNode +- Installs all the magic methods into SymBool, SymFloat, SymFloat at import time +- Does not depend on sympy at import time + +As this file is imported from within torch/__init__.py we do not want it to depend on SymPy +to avoid having to load SymPy at import time, as doing so is *very* slow. +""" + +import builtins +import itertools +import logging +import math +import operator +import sys +from functools import lru_cache, update_wrapper +from typing import Optional, Type, TYPE_CHECKING, Union + +import torch + +# NB: The sym_* functions are used via getattr() and must be imported here. +from torch import ( # noqa: F401 + sym_float, + sym_ite, + sym_max, + sym_min, + sym_not, + SymBool, + SymFloat, + SymInt, +) + + +if TYPE_CHECKING: + from torch.fx.experimental.symbolic_shapes import ShapeEnv + +log = logging.getLogger(__name__) +sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node") + + +__all__ = ["SymNode", "method_to_operator", "magic_methods"] + + +from torch.types import py_sym_types as SymTypes + + +def _to_symtype(t): + if t is bool: + return SymBool + if t is int: + return SymInt + if t is float: + return SymFloat + return t + + +# TODO: An incomplete list +# 1. Set variables to be equal when we do equality +# 2. Specialize on 0/1 when we do subtraction +class SymNode: + """ + This is a type erased SymInt/SymFloat which we use to do actual operations. + End users don't touch this. Magic methods are NOT defined on this object. + """ + + def __init__( + self, + expr, + shape_env, + pytype, + hint: Optional[Union[int, float, bool]], + constant=None, + fx_node=None, + ): + self._expr = expr + self.shape_env = shape_env + self.pytype = pytype + + # What's the difference between hint and constant? + # + # - A constant is known to be invariant across invocations of the model; + # it will always be this value. We only really know this when we + # encounter an honest-to-goodness literal (when wrapping it into + # a SymNode, we set constant.) Most of the time, constant is None + # + # - A hint is a *particular* value from the particular run we are + # tracing, but it may vary the next time around. It's useful to + # keep this around, as if we need a concrete value from a SymNode, + # we will return the hint and guard on the expression that produced + # it giving the same hint next time around. The hint is not + # guaranteed to be set either: if you have an unbacked SymNode, + # there won't be any hint; it was the result of some tensor-dependent + # computation, but we don't know what it actually is because we + # haven't actually run the tensor computation. + # + # If _hint is None, we will query maybe_evaluate_static(compute_hint=True) + # in hopes that we've learned enough about the unbacked symints to + # discharge the hint; otherwise, you're likely to just error out. + # + # (A previous version of this system had some optimizations to only + # recompute when it was possible we had learned enough about the + # unbacked symint that a hint was now possible, but as we added more + # potential refinements to unbacked symints this got harder to keep + # in sync, so we've deleted it for now.) + + def compute_hint(): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + # This occasionally gets exercised by, e.g., + # convert_shape_to_symint. It's just a nicety so you don't HAVE + # to have a correct hint on hand when making a SymNode. + # Don't attempt to compute for unbacked, this can be quite + # expensive. + if free_unbacked_symbols(self.expr): + return None + hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True) + if hint is not None: + hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint + return hint + + if hint is not None: + assert type(hint) is pytype or type(hint) is _to_symtype(pytype), ( + "Cannot create SymNode of type " + f"{pytype} with incompatible hint of type {type(hint)}" + ) + if self.shape_env and self.shape_env._translation_validation_enabled: + # This is technically not TV, but this assert is expensive so + # let's only do it when we're already doing expensive things + computed_hint = compute_hint() + assert ( + hint == computed_hint + ), f"{hint} != {computed_hint} (for {self.expr})" + else: + hint = compute_hint() + self._hint = hint + self.constant: Optional[Union[int, float, bool]] = constant + + # Record the FX node of the current node if we are doing translation + # validation. They will be used for building the input assertions for + # the translation validation problem. + tx_validation_en = ( + self.shape_env and self.shape_env._translation_validation_enabled + ) + self.fx_node = tx_validation_en and fx_node + + def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode": + return SymNode( + self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node + ) + + def _value_eq(self, other: "SymNode") -> bool: + # Purposely don't include the shape_env in the eq. + return ( + self._expr == other._expr + and self.pytype == other.pytype + and self._hint == other._hint + and self.constant == other.constant + and self.fx_node == other.fx_node + ) + + def _value_hash(self) -> int: + # Purposely don't include the shape_env in the hash. + return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node)) + + @property + def expr(self): + return self.shape_env.replace(self._expr) + + @property + def hint(self): + return self._hint + + def has_hint(self): + return self._hint is not None + + def require_hint(self, fallback=None): + if self._hint is None: + if fallback is not None: + return fallback + # NB: we expect this to raise + return self.shape_env.size_hint(self.expr) + return self._hint + + def maybe_as_int(self): + if self.expr.is_number: + return int(self.expr) + else: + return None + + # NB: This does conversions, not sure if this is good or not + def maybe_as_float(self): + import sympy + + if isinstance(self.expr, sympy.Float): + return float(self.expr) + else: + return None + + def maybe_as_bool(self): + import sympy + + if self.expr is sympy.true: + return True + elif self.expr is sympy.false: + return False + else: + return None + + def is_int(self): + return self.pytype is int + + def is_float(self): + return self.pytype is float + + def is_bool(self): + return self.pytype is bool + + def is_nested_int(self): + # Unbacked SymInts cannot be nested int today + return ( + self._hint is not None + and isinstance(self._hint, SymInt) + and self._hint.node.is_nested_int() + ) + + def wrap_int(self, num): + assert type(num) is int + import sympy + + return SymNode( + sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num + ) + + def wrap_float(self, num): + assert type(num) is float + import sympy + + return SymNode( + sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num + ) + + def wrap_bool(self, num): + assert type(num) is bool + import sympy + + return SymNode( + sympy.true if num else sympy.false, + self.shape_env, + bool, + num, + constant=num, + fx_node=num, + ) + + def clone(self): + return self + + def str(self): + return f"{self.expr}" + + def __str__(self): + return self.str() + + def __repr__(self): + rep = [ + f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}", + ] + if self._hint is not None: + rep.append(f"hint={self._hint}") + if self.constant is not None: + rep.append(f"constant={self.constant}") + if self.fx_node is not None: + rep.append(f"fx_node={self.fx_node}") + return ", ".join(rep) + ")" + + def _graph_repr(self) -> builtins.str: + # Representation used by GraphModule to create a pythonic version of a graph + return self.str() + + # These methods call the metaprogrammed methods, they're hand written + # here so we get good stack traces + def abs(self) -> "SymNode": + return self._abs() # type: ignore[attr-defined] + + def pos(self) -> "SymNode": + return self._pos() # type: ignore[attr-defined] + + def round(self, ndigits=None) -> "SymNode": + return self._round(ndigits) # type: ignore[attr-defined] + + def trunc(self) -> "SymNode": + return self._trunc() # type: ignore[attr-defined] + + def add(self, other) -> "SymNode": + return self._add(other) # type: ignore[attr-defined] + + def sub(self, other) -> "SymNode": + return self._sub(other) # type: ignore[attr-defined] + + def mul(self, other) -> "SymNode": + return self._mul(other) # type: ignore[attr-defined] + + def mod(self, other) -> "SymNode": + return self._mod(other) # type: ignore[attr-defined] + + def float_pow(self, other) -> "SymNode": + return self._float_pow(other) # type: ignore[attr-defined] + + def pow_by_natural(self, other) -> "SymNode": + return self._pow_by_natural(other) # type: ignore[attr-defined] + + def and_(self, other) -> "SymNode": + return self._and_(other) # type: ignore[attr-defined] + + def or_(self, other) -> "SymNode": + return self._or_(other) # type: ignore[attr-defined] + + def float_truediv(self, other) -> "SymNode": + return self._float_truediv(other) # type: ignore[attr-defined] + + def int_truediv(self, other) -> "SymNode": + return self._int_truediv(other) # type: ignore[attr-defined] + + def int_floordiv(self, other) -> "SymNode": + return self._int_floordiv(other) # type: ignore[attr-defined] + + def lshift(self, other) -> "SymNode": + return self._lshift(other) # type: ignore[attr-defined] + + def rshift(self, other) -> "SymNode": + return self._rshift(other) # type: ignore[attr-defined] + + def sym_not(self) -> "SymNode": # noqa: F811 + return self._sym_not() # type: ignore[attr-defined] + + def eq(self, other) -> "SymNode": + return self._eq(other) # type: ignore[attr-defined] + + def ne(self, other) -> "SymNode": + return self._ne(other) # type: ignore[attr-defined] + + def gt(self, other) -> "SymNode": + return self._gt(other) # type: ignore[attr-defined] + + def lt(self, other) -> "SymNode": + return self._lt(other) # type: ignore[attr-defined] + + def le(self, other) -> "SymNode": + return self._le(other) # type: ignore[attr-defined] + + def ge(self, other) -> "SymNode": + return self._ge(other) # type: ignore[attr-defined] + + def floor(self) -> "SymNode": + return self._floor() # type: ignore[attr-defined] + + def is_integer(self) -> "SymNode": + return self._is_integer() # type: ignore[attr-defined] + + def sym_float(self) -> "SymNode": # noqa: F811 + return self._sym_float() # type: ignore[attr-defined] + + def sym_int(self) -> "SymNode": + return self._sym_int() # type: ignore[attr-defined] + + def ceil(self) -> "SymNode": + return self._ceil() # type: ignore[attr-defined] + + def neg(self) -> "SymNode": + return self._neg() # type: ignore[attr-defined] + + def sym_min(self, other) -> "SymNode": # noqa: F811 + return self._sym_min(other) # type: ignore[attr-defined] + + def sym_max(self, other) -> "SymNode": # noqa: F811 + return self._sym_max(other) # type: ignore[attr-defined] + + def sym_ite(self, then_val, else_val) -> "SymNode": + return self._sym_ite(then_val, else_val) # type: ignore[attr-defined] + + def is_contiguous(self, sizes, strides) -> "SymNode": + return self._is_contiguous(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode": + return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined] + + def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode": + return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined] + + # Make C++ happy + def sym_or(self, other): + return self.or_(other) + + def sym_and(self, other): + return self.and_(other) + + # There is no int_truediv available from C++ + def truediv(self, other): + return self.float_truediv(other) + + def floordiv(self, other) -> "SymNode": + return self.int_floordiv(other) + + # We didn't bind integer pow in C++ + def pow(self, other): + return self.float_pow(other) + + def is_non_overlapping_and_dense(self, sizes, strides): + return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] + + def int_(self): + return self.guard_int("", 0) # NB: uses Python backtrace + + # You can manually trigger a guard with this function + def guard_int(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) + try: + return int(r) + except Exception: + log.warning("Failed to convert to int: %s", r) + raise + + def guard_float(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) + try: + return float(r) + except Exception: + log.warning("Failed to convert to float: %s", r) + raise + + def guard_bool(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) + try: + return bool(r) + except Exception: + log.warning("Failed to convert to bool: %s", r) + raise + + def expect_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + if ( + self.has_hint() + and not free_unbacked_symbols(self.expr) + and not self.shape_env.prefer_deferred_runtime_asserts_over_guards + ): + # OK to generate guards + return self.guard_bool(file, line) + # Generate a deferred runtime assert (this might actually end up doing + # a regular guard if we can!) + # TODO: file/line here is very important, because the assert has been + # deferred so you can't backtrace easily + return self.shape_env.defer_runtime_assert( + self.expr, f"{file}:{line}", fx_node=self.fx_node + ) + + def expect_size(self, file, line): + from torch.fx.experimental.symbolic_shapes import _advise_is_size + + b = self.ge(self.wrap_int(0)) + # Generate a deferred runtime assert + r = b.expect_true(file, line) + # Refine compile time range, but only if it's unbacked. + # If you refine range for hinted variables, you can end up making + # improper deductions since compile time reasoning may be + # incompatible with runtime reasoning. + if r and not self.has_hint(): + _advise_is_size(SymInt(self)) + return r + + def guard_size_oblivious(self, file, line): + """ + Like guard_bool, but if we encounter unbacked symbols, if those symbols + are size-like, we will treat them as >= 2 for the purposes of the analysis. + + This CHANGES the runtime semantics, but all size-oblivious sites have been + audited to ensure that the runtime semantics don't change in a material way. + Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping + an unbacked one size, or a tensor reporting as non-contiguous even if it's + contiguous if it would have been reported contiguous due to being empty. + """ + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.shape_env.evaluate_expr( + self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True + ) + try: + return bool(r) + except Exception: + log.warning("Failed to convert to bool: %s", r) + raise + + def bool_(self): + return self.guard_bool("", 0) + + def is_symbolic(self): + return True + + def nested_int(self): + return None + + def is_constant(self): + return False + + +# TODO: this probably needs the sizes-strides eval functions +METHOD_TO_OPERATOR = { + "pos": operator.pos, + "abs": operator.abs, + "add": operator.add, + "and": operator.and_, + "ceil": math.ceil, + "eq": operator.eq, + "floor": math.floor, + "trunc": math.trunc, + "int_floordiv": operator.floordiv, + "ge": operator.ge, + "gt": operator.gt, + "is_integer": lambda x: x.is_integer(), + "le": operator.le, + "lshift": operator.lshift, + "lt": operator.lt, + "mod": operator.mod, + "mul": operator.mul, + "ne": operator.ne, + "neg": operator.neg, + "or": operator.or_, + "float_pow": operator.pow, + "pow_by_natural": operator.pow, + "round": builtins.round, + "rshift": operator.rshift, + "sub": operator.sub, + "sym_float": sym_float, + "sym_ite": sym_ite, + "sym_max": sym_max, + "sym_min": sym_min, + "sym_not": sym_not, + "float_truediv": operator.truediv, + "int_truediv": operator.truediv, +} + +unary_magic_methods = { + "abs", + "sym_float", + "sym_int", + "ceil", + "floor", + "neg", + "sym_not", + "pos", + "trunc", +} + + +# Adding math ops: sqrt, cos, sin, ... +def _get_sym_node_fn(name): + def fn(self): + return getattr(self, f"_sym_{name}")() + + return fn + + +math_op_names = ( + "sqrt", + "cos", + "cosh", + "sin", + "sinh", + "tan", + "tanh", + "asin", + "acos", + "atan", +) +for name in math_op_names: + sym_name = f"sym_{name}" + priv_sym_name = f"_{sym_name}" + setattr(SymNode, sym_name, _get_sym_node_fn(name)) + METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name) + unary_magic_methods.add(sym_name) + __all__.append(sym_name) + + +# Unary methods that are not magic methods +unary_nonmagic_methods = { + "is_integer", +} + +unary_methods = unary_magic_methods | unary_nonmagic_methods + +# Most methods are only registered on SymInt and SymFloat +# Some methods are only be registered on SymBool +only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"} +# Methods that implicitly convert SymBool into SymInt +bool_becomes_int_magic_methods = {"add", "sub", "mul"} +# Methods that are also on SymBool, in addition to on SymInt and SymFloat +also_bool_magic_methods = {"eq"} +bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods + +# Methods that are only for float +only_float_magic_methods = {"is_integer", "round", "sym_int"} + + +magic_methods_on_operator_with_trailing_underscore = {"and", "or"} + + +always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} + +for name in math_op_names: + sym_name = f"sym_{name}" + always_float_magic_methods.add(sym_name) + + +always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} +always_bool_magic_methods = { + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + "and", + "or", + "sym_not", + "is_non_overlapping_and_dense", + "is_integer", +} + +# Methods that have a `__foo__` as well as `__rfoo__` + + +def _sympy_float_truediv(a, b): + from torch.utils._sympy.functions import FloatTrueDiv + + return FloatTrueDiv(a, b) + + +def _sympy_int_truediv(a, b): + from torch.utils._sympy.functions import IntTrueDiv + + return IntTrueDiv(a, b) + + +def _sympy_floordiv(a, b): + from torch.utils._sympy.functions import FloorDiv + + return FloorDiv(a, b) + + +def _sympy_mod(a, b): + from torch.utils._sympy.functions import Mod, PythonMod + + if a.is_nonnegative and b.is_nonnegative: + return Mod(a, b) + else: + return PythonMod(a, b) + + +def _sympy_pow_by_natural(a, b): + from torch.utils._sympy.functions import PowByNatural + + return PowByNatural(a, b) + + +def _sympy_float_pow(a, b): + from torch.utils._sympy.functions import FloatPow + + return FloatPow(a, b) + + +def _sympy_and(a, b): + import sympy + + return sympy.And(a, b) + + +def _sympy_or(a, b): + import sympy + + return sympy.Or(a, b) + + +def _sympy_lshift(a, b): + from torch.utils._sympy.functions import LShift + + return LShift(a, b) + + +def _sympy_rshift(a, b): + from torch.utils._sympy.functions import RShift + + return RShift(a, b) + + +reflectable_magic_methods = { + "add": operator.add, + "sub": operator.sub, + "mul": operator.mul, + "mod": _sympy_mod, + "pow_by_natural": _sympy_pow_by_natural, + "float_pow": _sympy_float_pow, + "and": _sympy_and, + "or": _sympy_or, + "float_truediv": _sympy_float_truediv, + "int_truediv": _sympy_int_truediv, + "int_floordiv": _sympy_floordiv, + "lshift": _sympy_lshift, + "rshift": _sympy_rshift, +} + + +def _floor_ceil_helper(a, fn): + import sympy + + if isinstance(a, sympy.Mul): + aa = a.args + if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer: + coef = sympy.Integer(aa[0]) + if aa[0] == coef: # structural equality test + return coef * aa[1] + if ( + isinstance(a, sympy.Float) + and a == sympy.Integer(a) + or isinstance(a, sympy.Integer) + ): + return sympy.Integer(a) + return fn(a) + + +def _sympy_floor(a): + from torch.utils._sympy.functions import FloorToInt + + return FloorToInt(a) + + +# NB: this is Python trunc semantics which returns an int. Do NOT use this to +# represent torch.trunc (which is float to float) +def _sympy_trunc(a): + from torch.utils._sympy.functions import TruncToInt + + return TruncToInt(a) + + +def _sympy_ceil(a): + from torch.utils._sympy.functions import CeilToInt + + return CeilToInt(a) + + +def _sympy_eq(a, b): + import sympy + + return sympy.Eq(a, b) + + +def _sympy_ne(a, b): + import sympy + + return sympy.Ne(a, b) + + +def _sympy_gt(a, b): + import sympy + + return sympy.Gt(a, b) + + +def _sympy_lt(a, b): + import sympy + + return sympy.Lt(a, b) + + +def _sympy_le(a, b): + import sympy + + return sympy.Le(a, b) + + +def _sympy_ge(a, b): + import sympy + + return sympy.Ge(a, b) + + +def _sympy_min(a, b): + from torch.utils._sympy.functions import Min + + return Min(a, b) + + +def _sympy_max(a, b): + from torch.utils._sympy.functions import Max + + return Max(a, b) + + +def _sympy_ite(a, t, f): + import sympy + + return sympy.Piecewise((t, a), (f, True)) + + +current_module = sys.modules[__name__] + + +def _get_sym_math_fn(name): + def fn(a): + import torch.utils._sympy.functions + + return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a) + + return fn + + +for name in math_op_names: + priv_sympy_name = f"_sympy_{name}" + fn = _get_sym_math_fn(name) + fn.__qualname__ = fn.__name__ = priv_sympy_name + setattr(current_module, priv_sympy_name, fn) + +del fn, name, priv_sympy_name # type: ignore[possibly-undefined] + + +def _sympy_abs(a): + import sympy + + return sympy.Abs(a) + + +def _sympy_round(number, ndigits=None): + from torch.utils._sympy.functions import RoundDecimal, RoundToInt + + if ndigits is None: + return RoundToInt(number) + else: + return RoundDecimal(number, ndigits) + + +def _sympy_sym_float(a): + from torch.utils._sympy.functions import ToFloat + + # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly + # reports that it is an integer + return ToFloat(a) + + +def _sympy_is_integer(a): + import sympy + + from torch.utils._sympy.functions import ToFloat + + return sympy.Eq(ToFloat(sympy.floor(a)), a) + + +magic_methods = { + **reflectable_magic_methods, + "sym_not": operator.invert, + "pos": operator.pos, + "eq": _sympy_eq, + "ne": _sympy_ne, + "gt": _sympy_gt, + "lt": _sympy_lt, + "le": _sympy_le, + "ge": _sympy_ge, + "floor": _sympy_floor, + "trunc": _sympy_trunc, + "sym_float": _sympy_sym_float, + "ceil": _sympy_ceil, + "neg": operator.neg, + "sym_min": _sympy_min, + "sym_max": _sympy_max, + "sym_ite": _sympy_ite, + "abs": _sympy_abs, + "round": _sympy_round, + "is_integer": _sympy_is_integer, +} + + +for name in math_op_names: + sym_name = f"sym_{name}" + magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}") + +del name, sym_name, math_op_names, current_module # type: ignore[possibly-undefined] + + +def sympy_is_contiguous(sizes, strides): + dim = len(sizes) + return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1))) + + +def sympy_is_contiguous_generic(sizes, strides, dim_order): + import sympy + + dim = len(sizes) + + if len(dim_order) != dim: + return sympy.false + + is_contiguous = sympy.true + z = sympy.Integer(1) + # Contiguous if the strides make sense (or the dim is size 1) + for d in dim_order: + is_contiguous &= sympy.Eq(sizes[d], sympy.Integer(1)) | sympy.Eq(strides[d], z) + z *= sizes[d] + # OR if any size is zero + for d in range(dim): + is_contiguous |= sympy.Eq(sizes[d], sympy.Integer(0)) + return is_contiguous + + +# NB: There is a TODO in C++ to allow omitting the batch dim. If that +# happens you will need to refactor this + + +def sympy_is_channels_last_contiguous_2d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_contiguous_3d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): + import sympy + + from torch.utils._sympy.functions import Max + + dim = len(sizes) + + if dim != len(dim_order): + return sympy.false + + m = sympy.Integer(0) + r = sympy.true + + # special case for trivial C dimension. default to NCHW + r &= sympy.Ne(strides[1], 0) + + for d in dim_order: + r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m) + # Fallback to NCHW as default layout for ambiguous cases + # This is the flaw of implicit memory_format from strides. + # N111 tensor with identical strides for size 1 dimension; + # Two cases could lead us here: + # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) + # b. N11W contiguous Tensor sliced on the W-dimension. + # ([N,1,1,1]@[W,W,W,W]) + if d == 0: + r &= sympy.Ne(m, strides[1]) + # This is necessary to: + # 1. distinguish the memory_format of N1H1; + # [H, 1, 1, 1] channels_last stride + # [H, H, 1, 1] contiguous stride + # 2. permutation of 1C1W: + # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) + # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as + # channels_last + m = strides[d] * Max(sizes[d], 1) + + return r + + +def sympy_is_channels_last_strides_2d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_strides_3d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides): + from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator + + return IsNonOverlappingAndDenseIndicator(*sizes, *strides) + + +sizes_strides_methods = { + # TODO: These could also be done with indicators, maybe it is better + # for reasoning to do it that way + "is_contiguous": sympy_is_contiguous, + "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d, + "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d, + "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d, + "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d, + "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator, +} + +alternate_impl_if_hinted_methods = { + "sym_min": builtins.min, + "sym_max": builtins.max, +} + + +def to_node(self, num): + if isinstance(num, SymTypes): + return num.node + elif type(num) is bool: + return self.wrap_bool(num) + elif type(num) is int: + return self.wrap_int(num) + elif type(num) is float: + return self.wrap_float(num) + else: + # NotImplemented is important so that Python tries the + # other magic method + return NotImplemented + + +def wrap_node(x): + # TODO: let C++ also take advantage of this + if isinstance(x, SymNode) and x.constant is not None: + return x.constant + if x.is_int(): + return SymInt(x) + elif x.is_float(): + return SymFloat(x) + elif x.is_bool(): + return SymBool(x) + else: + raise AssertionError(f"unrecognized return type {x}") + + +def method_to_operator(method): + return METHOD_TO_OPERATOR[method] + + +def _make_node_magic(method, func): + func = lru_cache(256)(func) + + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"{method}_" + else: + method_attr = method + + def binary_magic_impl(self, other): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + from torch.fx.experimental.symbolic_shapes import safe_expand + + op = method_to_operator(method) + + out_hint = None + if self.hint is not None and other.hint is not None: + out_hint = op(self.hint, other.hint) + + alternate_impl = alternate_impl_if_hinted_methods.get(method) + if alternate_impl and out_hint is not None: + return to_node(self, alternate_impl(wrap_node(self), wrap_node(other))) + + if get_proxy_mode(): + return to_node( + self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) + ) + assert isinstance(other, SymNode) + try: + if method == "mod": + from torch.utils._sympy.functions import Mod, PythonMod + + # Special handling for mod that requires access to the value + # ranges + shape_env = self.shape_env + if ( + self.expr.is_nonnegative + or shape_env.bound_sympy(self.expr).lower >= 0 + ) and ( + other.expr.is_nonnegative + or shape_env.bound_sympy(other.expr).lower >= 0 + ): + out = Mod(self.expr, other.expr) + else: + out = PythonMod(self.expr, other.expr) + else: + # TODO: consider constant prop here + out = func(self.expr, other.expr) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) + raise + out = safe_expand(out) + sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out) + pytype: Type + # This is not strictly correct. In Python, a**b may return complex when + # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This + # returns a float while both arguments are ints: 2**(-1). Also, max and + # min do not type promote. To avoid having data-dependent control flow + # here, we just set the type to float if one of the args is a float. In + # case of a type mismatch, we assume that it will be detected during + # evaluation. + if method in always_float_magic_methods: + pytype = float + elif method in always_bool_magic_methods: + pytype = bool + elif self.pytype is float or other.pytype is float: + pytype = float + else: + pytype = self.pytype + + if ( + pytype is not None + and out_hint is not None + and not isinstance(out_hint, SymTypes) + ): + out_hint = pytype(out_hint) + + # Create a FX node that corresponds to the operation being applied to + # this node. + fx_node, _ = self.shape_env._create_fx_call_function( + op, (self.fx_node, other.fx_node) + ) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + def unary_magic_impl(self): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + from torch.fx.experimental.symbolic_shapes import safe_expand + + op = method_to_operator(method) + if get_proxy_mode(): + return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) + # TODO: consider constant prop here + expr = self.expr + if method == "floor" or method == "ceiling": + expr = self.shape_env._simplify_floor_div(expr) + + try: + out = func(expr) + except Exception: + log.warning("failed to eval %s(%s)", method, expr) + raise + sym_node_log.debug("%s %s -> %s", func, expr, out) + out_hint = None + if self.hint is not None: + out_hint = op(self.hint) + out = safe_expand(out) + pytype: Type + if method in always_int_magic_methods: + pytype = int + elif method in always_bool_magic_methods: + pytype = bool + elif method in always_float_magic_methods: + pytype = float + else: + pytype = self.pytype + + fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,)) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + if method in unary_methods: + setattr(SymNode, f"_{method_attr}", unary_magic_impl) + elif method == "sym_ite": + + def sym_ite_impl(pred_node, then_node, else_node): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + from torch.fx.experimental.symbolic_shapes import safe_expand + + out_hint = then_node.hint if pred_node.hint else else_node.hint + if get_proxy_mode(): + return to_node( + pred_node, + handle_sym_dispatch( + sym_ite, + ( + wrap_node(pred_node), + wrap_node(then_node), + wrap_node(else_node), + ), + {}, + ), + ) + + try: + out = func(pred_node.expr, then_node.expr, else_node.expr) + except Exception: + log.warning( + "failed to eval %s(%s, %s, %s)", + method, + pred_node.expr, + then_node.expr, + else_node.expr, + ) + raise + + out = safe_expand(out) + fx_node, _ = pred_node.shape_env._create_fx_call_function( + sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node) + ) + return SymNode( + out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node + ) + + setattr(SymNode, f"_{method_attr}", sym_ite_impl) + elif method == "round": + + def round_impl(self, ndigits=None): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + from torch.fx.experimental.symbolic_shapes import safe_expand + + op = builtins.round + if get_proxy_mode(): + return to_node( + self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {}) + ) + + expr = self.expr + try: + out = func(expr, ndigits) + except Exception: + log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) + raise + + out = safe_expand(out) + + if ndigits is None: + pytype = int + else: + pytype = self.pytype + + out_hint = None + if self.hint is not None: + out_hint = op(self.hint, ndigits) + + # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the + # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here + # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The + # hack down below works, because all round function down the line all take ndigits=None as default in their + # signature. + # TODO: Remove the args construction below if a different sentinel is used by FX. + # ezyang(May 2024): LOL + args = [self.fx_node] + if ndigits is not None: + args.append(ndigits) + fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args)) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + setattr(SymNode, f"_{method_attr}", round_impl) + else: + setattr(SymNode, f"_{method_attr}", binary_magic_impl) + + +def _make_node_sizes_strides(method, func): + # NB: don't LRU cache, lots of arguments + + def sizes_strides_impl(self, sizes, strides): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = getattr(sys.modules[__name__], method) + if get_proxy_mode(): + return to_node( + self, + handle_sym_dispatch( + op, + ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), + {}, + ), + ) + size_exprs = [s.expr for s in sizes] + stride_exprs = [s.expr for s in strides] + try: + out = func(size_exprs, stride_exprs) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs) + raise + # bool is never expandable + + size_hints = [] + out_hint = None + for s in sizes: + if s.hint is None: + break + size_hints.append(s.hint) + else: + stride_hints = [] + for s in strides: + if s.hint is None: + break + stride_hints.append(s.hint) + else: + out_hint = op(size_hints, stride_hints) + + # NB: This is the indicator function, not the actual bool! + pytype: Type + if method.endswith("_indicator"): + pytype = int + else: + pytype = bool + return SymNode(out, self.shape_env, pytype, out_hint) + + setattr(SymNode, f"_{method}", sizes_strides_impl) + + # TODO: This is technically hotpath, but in the ideal end state + # guards on this will resolve at a higher level so you never + # spend time in this code + def sizes_strides_user(sizes, strides): + import sympy + + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) + + for a in itertools.chain(sizes, strides): + if isinstance(a, SymInt): + return wrap_node( + getattr(a.node, method)( + [to_node(a.node, b) for b in sizes], + [to_node(a.node, b) for b in strides], + ) + ) + if method == "is_non_overlapping_and_dense_indicator": + return eval_is_non_overlapping_and_dense(sizes, strides) + else: + # TODO: this is an awful implementation + return bool( + func( + [sympy.sympify(a) for a in sizes], + [sympy.sympify(a) for a in strides], + ) + ) + + # Skip for is_non_overlapping_and_dense_indicator + if not hasattr(sys.modules[__name__], method): + setattr(sys.modules[__name__], method, sizes_strides_user) + + +for method, func in magic_methods.items(): + _make_node_magic(method, func) + +for method, func in sizes_strides_methods.items(): + _make_node_sizes_strides(method, func) + + +def _make_user_magic(method, user_type): + # User magic takes care of wrapping the other operand into a node, + # so that our internal logic can assume everything is nodes + + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"sym_{method}" + else: + method_attr = method + + def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]): + if isinstance(x, (int, float, bool)): + return x + if isinstance(x, SymBool): + return x.node.guard_bool("", 0) + raise AssertionError("expect to be called with constant SymBools") + + def is_constant(x): + if isinstance(x, (int, float, bool)): + return True + if isinstance(x, (SymInt, SymFloat, SymBool)): + return x.node.is_constant() + return False + + # Promotion rules for binary operations. NB: we preserve PYTHON semantics + # - if args are same type, do nothing + # - if one arg is float, promote other arg to float + # - nb: this applies to floordiv, even though output is integral + # (it's still float) + # - pow is funny business + # - if both ints + # - trigger a guard on exponent >= 0 + # - if non-negative, output is int + # - otherwise, output is float + # - otherwise, promote other arg to float + # - nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + # - equality is pain: Python does the fancy thing where it unpacks the + # mantissa from the float and then compares that against the int. + # Which means it is able to tell that + # 9007199254740993 != 9007199254740992. (rather than if the LHS was + # promoted to float, in which case it would have truncated to the RHS + # and subsequently been equal). We'll model this exactly by having + # special mixed type equality operations. Unfortunately, we need to + # do this for all comparison operations (maybe I'll only implement + # compare) + # - sym_ite mumble mumble really shouldn't allow mixed but whatever + + if method in bool_becomes_int_magic_methods: + + def promote(x): + """Implements True+True=2, which works in python but not sympy""" + if isinstance(x, SymBool): + return SymInt(x.node.wrap_int(int(x))) + return x + + else: + + def promote(x): + return x + + def promote2(self, other): + # TODO: Remove eq and other relations from this list. + # CPython has fancy implementations for these to get as much precision + # as possible instead of just promoting to float64 and praying, so we + # need to handle them specially too. + # Also, note that int_truediv doesn't go through this path: both + # arguments are "int" so there isn't any promotion + if method not in [ + "add", + "sub", + "mul", + "mod", + "float_pow", + "float_truediv", + "int_floordiv", + "sym_min", + "sym_max", + # TODO: remove these + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + ]: + return self, other + f_self = isinstance(self, (float, torch.SymFloat)) + f_other = isinstance(other, (float, torch.SymFloat)) + if f_self or f_other: + if not f_self: + self = torch.sym_float(self) + if not f_other: + other = torch.sym_float(other) + return self, other + + # Before and after performing the operation, check if any operands are constant. + # If so, extract out the constant values first. If `self` itself is a + # constant, then "redispatch" by calling back into the operator. Sometimes + # this means that operations involving SymBool return plain bools. + # Alternatively, we could also rewrap into constant Symbool (i.e. by + # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that + # today for no particular reason. + def unary_magic_impl(self): + self = promote(self) + if is_constant(self): + return (method_to_operator(method))(get_constant(self)) + return wrap_node(getattr(self.node, method_attr)()) + + def binary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented + sym_node_log.debug("MAGIC %s %s %s", method, self, other) + self = promote(self) + other = promote(other) + self, other = promote2(self, other) + if is_constant(self): + return (method_to_operator(method))(get_constant(self), other) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(self.node, method_attr)(other_node)) + return get_constant(ret) if is_constant(ret) else ret + + def rbinary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented + self = promote(self) + other = promote(other) + self, other = promote2(self, other) + if is_constant(self): + return (method_to_operator(method))(get_constant(self), other) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(other_node, method_attr)(self.node)) + return get_constant(ret) if is_constant(ret) else ret + + if method in unary_magic_methods: + setattr(user_type, f"__{method}__", unary_magic_impl) + elif method in unary_nonmagic_methods: + orig = getattr(user_type, method) + setattr(user_type, method, update_wrapper(unary_magic_impl, orig)) + elif method == "sym_ite": + + def sym_ite_magic_impl(pred, then_val, else_val): + pred_node = pred.node + then_node = to_node(pred_node, then_val) + else_node = to_node(pred_node, else_val) + if then_node is NotImplemented or else_node is NotImplemented: + return NotImplemented + assert ( + isinstance(then_node, SymNode) + and isinstance(else_node, SymNode) + and then_node.pytype == else_node.pytype + ) + ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) + return get_constant(ret) if ret.node.is_constant() else ret + + setattr(user_type, f"__{method}__", sym_ite_magic_impl) + elif method == "round": + + def round_magic_impl(self, ndigits=None): + if is_constant(self): + return builtins.round(get_constant(self), ndigits) + + return wrap_node(getattr(self.node, method)(ndigits)) + + setattr(user_type, f"__{method}__", round_magic_impl) + else: + setattr(user_type, f"__{method}__", binary_magic_impl) + if method in reflectable_magic_methods: + setattr(user_type, f"__r{method}__", rbinary_magic_impl) + + +for method, func in magic_methods.items(): # type: ignore[assignment] + if method in only_bool_magic_methods: + _make_user_magic(method, SymBool) + continue + if method in only_float_magic_methods: + _make_user_magic(method, SymFloat) + continue + if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods: + _make_user_magic(method, SymBool) + _make_user_magic(method, SymInt) + _make_user_magic(method, SymFloat) + +del method +del func diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..369708d2c2dbc0c886773ad3d7fbf42c63cfd8f9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py @@ -0,0 +1,5603 @@ +# mypy: ignore-errors + +""" +``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with +our symbolic shapes reasoning system that is used heavily in torch.compile. Although +this is not generally considered public API, when writing framework code in PyTorch +as well as extensions to PyTorch (e.g., in custom operator implementations), you may +need to make use of these APIs to setup dynamic shapes support appropriately. +""" + +import builtins +import collections +import functools +import inspect +import itertools +import logging +import math +import operator +import os +import re +import sys +import threading +import traceback +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum +import atexit +from typing import ( + Any, + cast, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, + TYPE_CHECKING +) +from typing_extensions import TypeAlias + +import torch +import torch.fx +import torch.fx.traceback as fx_traceback +from torch.fx.experimental import _config as config + +from torch.fx.experimental.recording import ( + FakeTensorMeta, + ShapeEnvEvent, + record_shapeenv_event, + replay_shape_env_events, + shape_env_check_state_equal +) +from torch.fx.experimental.sym_node import SymNode, SymTypes +from torch._logging import trace_structured, structured + +# NB: The sym_* functions are used via getattr() and must be imported here. +from torch import SymBool, SymFloat, SymInt +from torch._guards import ShapeGuard, Source, TracingContext +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils._sympy.functions import ( + Application, FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt +) +from torch.utils._sympy.solve import try_solve +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError +from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._traceback import format_frame, CapturedTraceback +from torch._utils_internal import signpost_event +from torch._subclasses.meta_utils import is_sparse_any +import torch.utils._pytree as pytree +from torch.utils._sympy.symbol import SymT, make_symbol, symbol_is_type + +from torch._logging import LazyString + +if TYPE_CHECKING: + from torch._dynamo.source import TensorPropertySource + +InputList = List +DimList = List + +log = logging.getLogger(__name__) + +import sympy +from sympy.printing.str import StrPrinter +from sympy.printing.precedence import precedence, PRECEDENCE + +class GuardOnDataDependentSymNode(RuntimeError): + cond: sympy.Expr + + def __init__(self, cond, *args): + super().__init__(*args) + self.cond = cond + +class PendingUnbackedSymbolNotFound(RuntimeError): + pass + +aten = torch._ops.ops.aten # type: ignore[has-type] + +__all__ = [ + "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int", + "guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr", + "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node", + "is_concrete_bool", "is_nested_int", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY", + "has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext", + "StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true", + "guard_size_oblivious", "check_consistent", + "compute_unbacked_bindings", "ConvertIntKey", + "rebind_unbacked", "resolve_unbacked_bindings", "is_accessor_node", +] + +# FX node metadata keys for symbolic shape FX graph. +SHAPEENV_EVENT_KEY = "shapeenv_event" +CURRENT_NODE_KEY = "current_node" + + +def log_lru_cache_stats(wrapped_f): + log.debug("lru_cache_stats %s: %s", wrapped_f.__name__, wrapped_f.cumulative_cache_info()) + + +# Wrapper on lru_cache that reports statistics at process end +def lru_cache(maxsize): + def inner(f): + wrapped_f = functools.lru_cache(maxsize)(f) + old_cache_clear = wrapped_f.cache_clear + prev_hits = 0 + prev_misses = 0 + + # TODO: There's a ref-cycle here (wrapped_f -> cumulative_cache_info + # -> wrapped_f) but cannot be solved with weakref as wrapped_f is not + # weakref'able on some versions of Python + + def cumulative_cache_info(): + cur = wrapped_f.cache_info() + return functools._CacheInfo( + prev_hits + cur.hits, + prev_misses + cur.misses, + cur.maxsize, + cur.currsize, + ) + + def new_cache_clear(): + nonlocal prev_hits, prev_misses + cur = wrapped_f.cache_info() + prev_hits += cur.hits + prev_misses += cur.misses + old_cache_clear() + + wrapped_f.cache_clear = new_cache_clear + wrapped_f.cumulative_cache_info = cumulative_cache_info + if log.isEnabledFor(logging.DEBUG): + atexit.register(log_lru_cache_stats, wrapped_f) + return wrapped_f + + return inner + +# These are modules that contain generic code for interacting with ShapeEnv +# which are unlikely to identify a particular interesting guard statement +@lru_cache(None) +def uninteresting_files() -> Set[str]: + import torch._inductor.sizevars + import torch._library.fake_impl + import torch._subclasses.meta_utils + import torch._subclasses.fake_tensor + mods = [ + sys.modules[__name__], + torch.fx.experimental.recording, + torch.fx.experimental.sym_node, + torch.fx.interpreter, + torch, + torch._inductor.sizevars, + torch._library.fake_impl, + torch._subclasses.meta_utils, + torch._subclasses.fake_tensor, + ] + return {inspect.getfile(m) for m in mods} + +# We don't bother with the metaclass as all of the dispatching logic happens +# entirely from Python +# +# Didn't bother with ancestors for now, unlikely to have multiple modes for +# symints right now + +class ConstraintViolationError(RuntimeError): + pass + +def has_symbolic_sizes_strides(elem) -> bool: + return elem._has_symbolic_sizes_strides + +Int = Union[torch.SymInt, int] + +def create_contiguous(shape: Sequence[Int]) -> List[Int]: + strides: List[Int] = [1] + for dim in reversed(shape[:-1]): + strides.append(dim * strides[-1]) + return list(reversed(strides)) + +def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int: + """ + Retrieve the hint for an int (based on the underlying real values as observed + at runtime). If no hint is available (e.g., because data dependent shapes), + if fallback is not None, use that instead (otherwise raise an error). + """ + if isinstance(a, torch.SymInt): + return a.node.require_hint(fallback) + assert type(a) is int, a + return a + +Scalar = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool] + +def has_hint(a: Scalar) -> bool: + if isinstance(a, SymTypes): + return a.node.has_hint() + return True + +def is_concrete_int(a: Union[int, SymInt]) -> bool: + r""" Utility to check if underlying object + in SymInt is concrete value. Also returns + true if integer is passed in. + + Args: + a (SymInt or int): Object to test if it int + """ + assert isinstance(a, (SymInt, int)) + + if isinstance(a, int): + return True + + if isinstance(a.node.expr, sympy.core.numbers.Integer): + return True + + return False + +# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime. +# So make sure only type checker evaluates this alias. +# Xref: https://www.internalfb.com/diff/D53324783 +SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean" + +def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool: + """ + Perform a guard on a symbolic boolean expression in a size oblivious way. + This is typically used when a non-oblivious test would result in a guard + on a data dependent value of which we don't know the value of at compile time. + When a guard is tested this way, we may diverge in behavior from how regular + PyTorch semantics would treat it. For more information, see + https://github.com/pytorch/pytorch/pull/118579 + """ + if isinstance(expr, torch.SymBool): + return expr.node.guard_size_oblivious("", 0) + else: + assert isinstance(expr, bool), expr + return expr + +def check_consistent(new, old) -> None: + """ + Test that two "meta" values (typically either Tensor or SymInt) have + the same values, e.g., after retracing. If we don't understand the + quantities in question, we'll just skip the consistency check. + """ + # TODO: do boolean equality test too, see + # https://github.com/pytorch/pytorch/issues/124110 + scalar_types = (torch.SymInt, torch.SymFloat, int, float) + + if isinstance(new, torch.Tensor): + assert isinstance(old, torch.Tensor) + torch._check(old.dim() == new.dim(), lambda: f"{old.shape} != {new.shape} (old != new)") + # Do this manually so that each individual test is irrefutable + # (TODO: should be a helper for this, maybe sym_eq? That + # gives us a compound expression and I'm not sure it + # simplifies right now) + for i, j in zip(old.shape, new.shape): + torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)") + # NB: bool is subclass of int + elif isinstance(new, scalar_types) and not isinstance(new, bool): + assert isinstance(old, scalar_types) and not isinstance(old, bool), f"{old} != {new}" + torch._check(old == new, lambda: f"{old} != {new} (old != new)") + +def resolve_unbacked_bindings(shape_env, bindings): + if bindings is None: + return None + return { + shape_env.unbacked_renamings.get(k, k): v + for k, v in bindings.items() + } + +def rebind_unbacked(shape_env, n: torch.fx.Node, result): + """ + Suppose we are retracing a pre-existing FX graph that previously had + fake tensor propagation (and therefore unbacked SymInts). When we retrace, + we re-propagate fake tensors, which results in new unbacked SymInts. + When this happens, we need to tell the shape environment about the equivalence + of the old and new unbacked SymInts. Pass us the old torch.fx.Node (which + has the old binding information) and the new result (which we can extract the + new unbacked SymInts out from). + """ + from torch._dynamo.tensor_version_op import _tensor_version + + # Inputs never need rebinding + if n.op == "placeholder": + return + + if bindings := resolve_unbacked_bindings(shape_env, n.meta.get("unbacked_bindings")): + for raw_u0, path in bindings.items(): + u1 = pytree.key_get(result, path) + # tensor_version ops get specialized after AOTAutograd, it's OK, + # we don't actually want to do asserts on them. This is all a bit + # questionable though + if isinstance(u1, int) and n.target is _tensor_version: + log.info("rebind_unbacked: discard _tensor_version %s %s -> %s", raw_u0, path, u1) + continue + raw_u1 = u1.node.expr + # Simplify SymBool binding + if ( + isinstance(raw_u1, sympy.Piecewise) and + len(raw_u1.args) == 2 and + raw_u1.args[0][0] == 1 and + isinstance(eq := raw_u1.args[0][1], sympy.Eq) and + isinstance(new_raw_u1 := eq.lhs, sympy.Symbol) and + shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1)) and + eq.rhs == 1 and + raw_u1.args[1] == (0, True) + ): + # This is what the pattern match above is testing + repacked = _sympy_cast_symbool_to_symint_guardless(sympy.Eq(new_raw_u1, 1)) + assert repacked == raw_u1, f"{repacked} != {raw_u1}" + # Cancel the to_int(to_bool(x)). This is sound because x in + # [0, 1] + raw_u1 = new_raw_u1 + assert isinstance(raw_u1, sympy.Symbol) + # The old and new could be the same if you improperly hit the memo + # while retracing. Make sure you updated FakeTensorMode.epoch + assert raw_u0 != raw_u1, f"{raw_u0} possible memo disaster" + # Reuse the OLD symbol name + shape_env._rename_unbacked_to(raw_u1, raw_u0) + +# NB: You could try to expand this to cover more cases by simply +# detecting whenever you have an int output, but this is a bit +# dangerous in case someone adds a function that returns an int but is +# mutating. So manually whitelist for now. +def is_accessor_node(node: torch.fx.Node) -> bool: + # Dynamo only exercised condition + if ( + node.op == "call_method" + and isinstance(node.args[0].meta.get("example_value"), torch.Tensor) + and node.target in ["size", "stride", "storage_offset", "item"] + ): + return True + if node.op == "call_function" and node.target in [ + torch.ops.aten.sym_size, + torch.ops.aten.sym_size.default, + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_stride, + torch.ops.aten.sym_stride.default, + torch.ops.aten.sym_stride.int, + torch.ops.aten.sym_storage_offset, + torch.ops.aten.sym_storage_offset.default, + torch.ops.aten.sym_numel.default, + ]: + return True + return False + +def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean: + r""" Canonicalize a boolean expression by transforming it into a lt / le + inequality and moving all the non-constant terms to the rhs. + We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr + recursively + nb. sympy.Rel.canonical is not good enough https://github.com/sympy/sympy/issues/25924 + + Args: + expr (sympy.Expr): Expression to canonicalize + """ + # Canonicalise an inequality by transforming it into a lt / le + # inequality and moving all the non-constant terms to the rhs + # We canonicalise And / Ors / Not via cnf + # nb. Relational.canonical in sympy is broken + # https://github.com/sympy/sympy/issues/25924 + + if not isinstance(expr, (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne)): + return expr + + if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)): + expr = sympy.logic.boolalg.to_cnf(expr) + return _canonicalize_bool_expr_impl(expr) + +def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: + """ + After canonicalization, we are guaranteed to have eliminated Ge/Gt relations + (rewriting them to Le/Lt, respectively). + """ + if isinstance(expr, (sympy.And, sympy.Or)): + return type(expr)(*map(canonicalize_bool_expr, expr.args)) + + opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le} + if isinstance(expr, tuple(opposite.keys())): + rhs = expr.lhs - expr.rhs + t = opposite[type(expr)] + else: + assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne)) + rhs = expr.rhs - expr.lhs + t = type(expr) + + def is_neg(t): + return t.is_negative or (isinstance(t, sympy.Mul) and t.args[0].is_negative) + + lhs = 0 + rhs = _reduce_to_lowest_terms(rhs) + if isinstance(rhs, sympy.Add): + pos = [] + neg = [] + for term in rhs.args: + if is_neg(term): + neg.append(-term) + else: + pos.append(term) + lhs = sympy.Add(*neg) + rhs = sympy.Add(*pos) + elif is_neg(rhs): + # lhs == 0 + lhs, rhs = -rhs, 0 + return t(lhs, rhs) + + +def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr: + """ + Eliminates any integer factor from a given expression. + E.g., 6x + 4y reduces to 3x + 2y. + + Useful when an expression is == or != to 0. + """ + def integer_coefficient(x): + if isinstance(x, sympy.Integer): + return abs(int(x)) + elif isinstance(x, sympy.Mul): + return math.prod([abs(int(arg)) for arg in x.args if isinstance(arg, sympy.Integer)]) + else: + return 1 + + if isinstance(expr, sympy.Add): + atoms = expr.args + factor = functools.reduce(math.gcd, map(integer_coefficient, atoms)) + atoms = [x / factor for x in atoms] + return sympy.Add(*atoms) + else: + return expr / integer_coefficient(expr) + + +def is_concrete_bool(a: Union[bool, SymBool]) -> bool: + r""" Utility to check if underlying object + in SymBool is concrete value. Also returns + true if integer is passed in. + Args: + a (SymBool or bool): Object to test if it bool + """ + assert isinstance(a, (SymBool, bool)) + + if isinstance(a, bool): + return True + + if isinstance(a.node.expr, (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse)): + return True + + return False + +def is_nested_int(s): + return isinstance(s, torch.SymInt) and s.node.is_nested_int() + +def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]: + if isinstance(val, SymTypes): + # This allow applies to the jagged layout NestedTensor case as + # nested ints are not symbolic + if is_symbolic(val): + yield val.node.expr + elif isinstance(val, sympy.Basic): + yield val + elif isinstance(val, (int, float, bool)): + pass + elif isinstance(val, (tuple, list)): + for s in val: + yield from _iterate_exprs(s) + elif is_sparse_any(val): + yield from _iterate_exprs(val.size()) + elif isinstance(val, torch.Tensor): + yield from _iterate_exprs(val.size()) + yield from _iterate_exprs(val.stride()) + yield from _iterate_exprs(val.storage_offset()) + elif val is None: + pass + else: + raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}") + +def free_symbols(val: Union[SymInt, sympy.Expr, torch.Tensor]) -> Set[sympy.Symbol]: + if val is None: + return set() + itr = _iterate_exprs(val) + # we need at least 1 to call union, so we hand code the identity + try: + first_expr = next(itr) + except StopIteration: + return set() + + return first_expr.free_symbols.union(*(e.free_symbols for e in itr)) + +def has_free_symbols(val: Union[SymInt, torch.Tensor]) -> bool: + """Faster version of bool(free_symbols(val))""" + return not all(e.is_number for e in _iterate_exprs(val)) + +# Like free_symbols, but filtered to only report unbacked symbols +def free_unbacked_symbols(x): + # NB: keep synced with is_unbacked_symint + return {s for s in free_symbols(x) if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))} + +# WARNING: Don't use this on Dynamo produced graphs, they don't have meta +# setup! +def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]: + if ( + "val" in node.meta and + isinstance(node.meta["val"], torch.SymInt) and + isinstance(node.meta["val"].node.expr, sympy.Symbol) and + (node.op == "placeholder" or free_unbacked_symbols(node.meta["val"].node.expr)) + ): + return node.meta["val"].node.expr + return None + +def find_symbol_binding_fx_nodes(graph): + r = {} + # NB: Prefer first occurrence of symbol + for node in graph.nodes: + if is_symbol_binding_fx_node(node) and node.meta["val"].node.expr not in r: + r[node.meta["val"].node.expr] = node + return r + + +# Analogous to ConvertIntSource +@dataclass(frozen=True) +class ConvertIntKey: + def __str__(self) -> str: + return ".cast_symbool_to_symint_guardless()" + + def get(self, b: bool) -> int: + """Get the int value from bool""" + return cast_symbool_to_symint_guardless(b) + + +@dataclass(frozen=True) +class CallMethodKey: + name: str + + def __str__(self) -> str: + return f".{self.name}()" + + def get(self, o: Any) -> Any: + """Call the method on object""" + return getattr(o, self.name)() + + +@dataclass(frozen=True) +class InnerTensorKey: + inner_name: str + + def __str__(self) -> str: + return f".{self.inner_name}" + + def get(self, o: Any) -> Any: + """Get the inner tensor attribute""" + return getattr(o, self.inner_name) + + +@dataclass(frozen=True) +class DivideByKey: + divisor: int + + def __str__(self) -> str: + return f".__floordiv__({self.divisor})" + + def get(self, o: int) -> int: + """Divide object by divisor""" + return o // self.divisor + + +def compute_unbacked_bindings(shape_env, example_value, old_example_value=None, peek=False): + """ + After having run fake tensor propagation and producing example_value + result, traverse example_value looking for freshly bound unbacked + symbols and record their paths for later. It is an error if + we have allocated an unbacked SymInt but it cannot be found in + example_value. (NB: this means if you have a multi-output + function, you must call this on the tuple of tensor output, you + cannot wait!) + + The peek parameter lets you check out what the bindings are without + changing the affected list. This is primarily useful for ensuring + unbacked_var_to_val is promptly populated when propagate_real_tensors is on. + """ + if shape_env is None: + return + fs = shape_env.pending_fresh_unbacked_symbols + pending = set(fs) + if pending: + if not peek: + log.info("compute_unbacked_bindings %s", fs) + fs.clear() + + def free_unbacked_symbols_with_path( + a, path, real=None + ) -> Dict[sympy.Symbol, pytree.KeyPath]: + r = {} + if isinstance(a, (tuple, list)): + for i in range(len(a)): + r.update( + free_unbacked_symbols_with_path( + a[i], path + (pytree.SequenceKey(i),), + real=real[i] if real is not None else None + ) + ) + elif is_traceable_wrapper_subclass(a): + # TODO: Determine if this is correct + attrs, _ = a.__tensor_flatten__() + for attr in attrs: + sub = getattr(a, attr) + r.update( + free_unbacked_symbols_with_path(sub, path + (InnerTensorKey(attr),)) + ) + elif isinstance(a, torch.Tensor): + r.update( + free_unbacked_symbols_with_path( + a.size(), path + (CallMethodKey("size"),), + real=a.real_tensor.size() if a.real_tensor is not None else None + ) + ) + r.update( + free_unbacked_symbols_with_path( + a.stride(), path + (CallMethodKey("stride"),), + real=a.real_tensor.stride() if a.real_tensor is not None else None + ) + ) + r.update( + free_unbacked_symbols_with_path( + a.storage_offset(), path + (CallMethodKey("storage_offset"),), + real=a.real_tensor.storage_offset() if a.real_tensor is not None else None + ) + ) + + # NB: Intentionally access _expr, not expr, do not want + # simplification! + elif ( + isinstance(a, (torch.SymInt, torch.SymFloat)) + and isinstance(s := a.node._expr, sympy.Symbol) + and s in pending + ): + r[s] = path + if real is not None: + shape_env.set_unbacked_var_to_val(s, real) + pending.remove(s) + # When an unbacked SymInt is perfectly divisible by an integer + # constant, we replace it with the integer constant to improve + # reasoning capabilities. However, in synthetic examples, it is + # then possible that the factor never is explicitly allocated. + # Fortunately, we can compute it by division. + elif ( + isinstance(a, torch.SymInt) + and isinstance(s := a.node._expr, sympy.Mul) + and len(s.args) == 2 + and isinstance(lhs := s.args[0], sympy.Integer) + and isinstance(rhs := s.args[1], sympy.Symbol) + and rhs in pending + ): + # TODO: DivideByKey needs to test divisibility at runtime! + r[s] = path + (DivideByKey(int(lhs)),) + if real is not None: + shape_env.set_unbacked_var_to_val(s, real // int(lhs)) + pending.remove(rhs) + # The annoyance here arises from the fact that SymBool is + # allocated by allocating a SymInt and then testing if it's equal + # to one. So you have a complicated binding site logic for this. + elif ( + isinstance(a, torch.SymBool) + and isinstance(s := a.node._expr, sympy.Eq) + # This must match create_unbacked_symbool EXACTLY + and isinstance(s.lhs, sympy.Symbol) + and s.rhs == 1 + and s.lhs in pending + ): + r[s.lhs] = path + (ConvertIntKey(),) + if real is not None: + shape_env.set_unbacked_var_to_val(s, int(real)) + pending.remove(s.lhs) + + return r + + symbol_to_path = free_unbacked_symbols_with_path(example_value, ()) + if not peek and pending: + extra = ( + repr((example_value.stride(), example_value.storage_offset())) + if isinstance(example_value, torch.Tensor) + else "" + ) + raise PendingUnbackedSymbolNotFound( + f"Pending unbacked symbols {pending} not in returned outputs {example_value} {extra}.\n" + "Did you accidentally call new_dynamic_size() or item() more times " + "than you needed to in your fake implementation?\n" + "For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit" + ) + + # Why do we have to do some rebinding here? If the original FX node + # wasn't a binding site because you had a memo hit, but post + # translation you aren't a memo hit anymore, there's now a new binding + # site... but we know (because it's the same FX node) that the value + # is actually the same, they're just not obviously equal anymore. + # + # The logic here is written carefully, because unlike the + # bind_unbacked case, we are not guaranteed to have a symbol for + # old_sym. If we have a symbol, do regular rename unbacked to; but if + # we don't, we need to specially eliminate the fresh unbacked symbol + # (NB: we are /trusting/ that the memoization is correct, and that we + # don't need to generate a new runtime assert. This is load bearing, + # as repropagation can happen after we've frozen runtime asserts.) + if old_example_value is not None: + for keypath in symbol_to_path.values(): + old_sym = pytree.key_get(old_example_value, keypath) + new_sym = pytree.key_get(example_value, keypath) + if ( + isinstance(new_sym, SymTypes) and + isinstance(new_s := new_sym.node.expr, sympy.Symbol) + ): + if isinstance(old_sym, SymTypes) and (old_s := old_sym.node.expr) != new_s: + if isinstance(old_s, sympy.Symbol): + shape_env._rename_unbacked_to(new_s, old_s) + else: + shape_env._eliminate_unbacked(new_s, old_s) + elif not isinstance(old_sym, SymTypes): + shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym)) + + return symbol_to_path + +def definitely_true(a): + """ + Returns True only if we can tell that a is True, possibly introducing + a guard in the process. If a depends on some unbacked SymInt, we may + return False even though there may exist a possible value of the SymInt + that would cause the expression to return True. + + When is it appropriate to use definitely_true? First, if you can use + a higher level combinator like parallel_or/parallel_and, prefer using + those instead, they are definitely safe (modulo short-circuiting). + Second, it can be used if the program would behave equivalently if + definitely_true always returned False (parallel_or/parallel_and are + examples of this pattern, modulo short-circuiting). Finally, it even + be OK if the program wouldn't behave equivalently, so long as the + change is semantics preserving. It can be semantics preserving if + the program errors in more cases than it did previously (but otherwise + behaves identically), or if it changes some quantity in a way that + doesn't matter (e.g., strides often fall in this bucket.) + """ + if isinstance(a, SymBool): + if a.node.has_hint(): + return guard_bool(a) + else: + return False + return bool(a) + +def definitely_false(a): + """ + Returns True only if we can tell that a is False, possibly introducing + a guard in the process. If a depends on some unbacked SymInt, we may + return False even though there may exist a possible value of the SymInt + that would cause the expression a to be False. See definitely_true + for more usage guidance. + """ + if isinstance(a, SymBool): + if a.node.has_hint(): + return not guard_bool(a) + else: + return False + return not bool(a) + +def statically_known_true(x: Union[bool, SymBool]) -> bool: + """Returns True if x can be simplified to a constant and is true. + + .. note:: + This function doesn't introduce new guards, so the expression may end + up evaluating to true at runtime even if this function returns False. + + Args: + x (bool, SymBool): The expression to try statically evaluating + + """ + if isinstance(x, SymBool): + expr = x.node.expr + shape_env = x.node.shape_env + try: + simplified = shape_env._maybe_evaluate_static(expr) + if simplified is not None: + return bool(simplified) + except Exception: + log.debug("Could not simplify %s", expr) + return False + assert isinstance(x, bool) + return x + + +def parallel_or(*args): + """ + Evaluate the logical OR of several arguments, avoiding guarding on + unbacked SymInts if another argument is definitely True. + """ + if any(statically_known_true(a) for a in args): + return True + if any(definitely_true(a) for a in args): + return True + return any(args) + +def parallel_and(*args): + """ + Evaluate the logical FALSE of several arguments, avoiding guarding on + unbacked SymInts if another argument is definitely False. + """ + if any(statically_known_true(torch.sym_not(a)) for a in args): + return False + if any(definitely_false(a) for a in args): + return False + return all(args) + +def sym_eq(x, y): + """ + Like ==, but when run on list/tuple, it will recursively test equality + and use sym_and to join the results together, without guarding. + """ + if (isinstance(x, tuple) and isinstance(y, tuple)) or (isinstance(x, list) and isinstance(y, list)): + if len(x) != len(y): + return False + return functools.reduce(operator.and_, map(sym_eq, x, y), True) + elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)): + return x == y + else: + raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}") + +def guard_scalar(a): + if isinstance(a, (SymBool, bool)): + return guard_bool(a) + elif isinstance(a, (SymInt, int)): + return guard_int(a) + elif isinstance(a, (SymFloat, float)): + return guard_float(a) + else: + raise AssertionError(f"unrecognized scalar {a}") + + +def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int): + shape_env.constrain_symbol_range(s, compiler_min, compiler_max) + + +def _advise_is_size(a): + """ + Don't use this directly; use torch._check_is_size instead. + + This is a softer version of _constrain_range_for_size (with min=0, + max=Inf). Instead of forcibly constraining a variable (and erroring if we + failed to constrain it), it will simply advise us that a size is + constrained in some way. We will always defer a runtime assert for this + constraint if we cannot prove it at compile-time, but we we only + *sometimes* learn useful extra information at compile-time with this + information. This is in contrast to constrain_range_for_size, where if + you don't call that on a fresh unbacked symint, chances are we will choke. + + TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed + code. Right now this is only really used in code with AOTAutograd trace + through, so it is not a big problem that this isn't supported, but in + principle all of this code should be Dynamo'able too. + + TODO: I didn't support min/max because I didn't have a use case where this + actually helped. In principle we can support it, it just makes the + implementation below more complicated. + """ + + # This must always succeed, because the sole allowed caller _check_is_size + # was responsible for expect_true'ing this + # This assert triggers expensive sym compute, do not do it until its cheap. + # assert a >= 0 + + # NB: it's important not to constrain range for size for *hinted* SymInts, + # because it is not only unsound, it will immediately trip our asserts + # that hints have to be consistent with static analysis! If you somehow + # have an unbounded SymInt that later constrains to 1, this will be + # inconsistent with the range + if ( + isinstance(a, SymInt) + and isinstance(a.node, SymNode) + and isinstance(a.node.expr, sympy.Symbol) + and a.node.shape_env.is_unbacked_symint(a.node.expr) + ): + _constrain_range_for_size(a) + +def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None): + """ + This function is NOT INTENDED to be used by itself. + """ + + if isinstance(a, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat/SymBool is nyi") + + assert isinstance(a, SymInt), "can only constrain range for SymInt" + assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + + a.node.shape_env._constrain_range_for_size(a.node.expr, min, max) + + +# inclusive both ways +def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): + """ + Applies a constraint that the passed in SymInt must lie between min-max + inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning + that it can be used on unbacked SymInts). If min/max are None, we assume + that the dimension is unbounded in that direction. Repeated application + of constrain_range intersects the ranges. This is a fairly low level API + that doesn't have a lot of safety guarantees (TODO: provide higher level + APIs). + + Currently, we use this API in the following circumstance: when we allocate + an unbacked SymInt, denoting an integer quantity which is data dependent, + we ordinarily do not know anything about what values it may take. This + means that any sort of guard on it will immediately fail. However, in + many cases, we know something about the unbacked SymInt: for example, we + know that nonzero(x).size(0) must be >= 0. We use constrain_range to + narrow the possible range, declaring that negative symbols are impossible. + This permits to definitely answer True to queries like 'nnz >= 0', even if + we don't know what the actual (hinted) value of 'nnz' is. In fact, we + actually use constrain_range to unsoundly discharge common guards: for an + unbacked SymInt produced by nonzero, we will also assume that it is not + equal to 0/1 (even though these are perfectly possible values at runtime), + because we generally expect graphs that are valid for N=2 to also be valid + for N=1. + """ + if min is None: + min = -int_oo + if max is None: + max = int_oo + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + "received min={min} and max={max}" + ) + + if isinstance(a, int): + if not (min <= a <= max): + raise ValueError(f"Invalid value {a} for range [{min}:{max}]") + return + + a.node.shape_env._constrain_range(a.node.expr, min, max) + +def constrain_unify(a: torch.SymInt, b: torch.SymInt) -> None: + """ + Given two SymInts, constrain them so that they must be equal. NB: + this will not work with SymInts that represent nontrivial expressions + (yet!) + """ + if not isinstance(a, SymInt): + if not isinstance(b, SymInt): + assert a == b + return + else: + shape_env = b.node.shape_env + else: + shape_env = a.node.shape_env + + shape_env._constrain_unify(a, b) + +# Assume that a boolean is true for the purposes of subsequent symbolic +# reasoning. This will keep track of corresponding runtime checks to verify +# that the result is upheld: either as a regular guard, or as a special set +# of asserts which are triggered when an unbacked SymInt is allocated. +# +# DO NOT use this function for these cases: +# +# - This is inappropriate for "branching" conditions (where both +# true and false result in valid programs). We will always assume +# the condition evaluates true, and so it will never be possible +# to trace the false condition when you use it. For true branching +# on unbacked SymInts, you must use torch.cond; if you incorrectly +# use expect_true in this case, you will make the false branch +# unreachable (as we will simply assume that only the true branch +# is ever exercised). +# +# - This is inappropriate for situations where you know some other system +# invariant guarantees that this property holds, since you don't +# really need to insert a runtime check in that case. Use something +# like constrain_range in that case. +# +# This API has a hitch. To avoid having to reimplement error reporting +# capabilities, this function CAN return False. The invariant is that +# the surrounding code must raise an error when this function returns +# False. This is quite low level, so we recommend using other functions +# like check() which enforce this in a more intuitive way. +# +# By the way, this name is a nod to the __builtin_expect macro, +# which is used similarly (but unlike __builtin_expect, you MUST fail +# in the unlikely branch.) (I think expect is a good name; in recent +# versions of C++, this is replaced with [[likely]], which is weaker +# and not accurate for this function!) +def expect_true(a, skip: int = 0): + if isinstance(a, SymBool): + # TODO: check perf implications of this + frame = inspect.currentframe() + for _ in range(skip + 1): # always run this loop at least once + frame = frame.f_back + return a.node.expect_true(frame.f_code.co_filename, frame.f_lineno) + assert type(a) is bool, a + return a + +def guard_bool(a): + if isinstance(a, SymBool): + return a.node.guard_bool("", 0) # NB: uses Python backtrace + assert type(a) is bool, a + return a + +def guard_int(a): + if isinstance(a, SymInt): + return a.node.guard_int("", 0) # NB: uses Python backtrace + assert type(a) is int, a + return a + +def guard_float(a): + if isinstance(a, SymFloat): + return a.node.guard_float("", 0) # NB: uses Python backtrace + assert isinstance(a, float), a + return a + +# Given a GraphModule, return all the FakeTensors for all the placeholders +def fx_placeholder_vals(gm): + return [n.meta['val'] for n in gm.graph.nodes if n.op == "placeholder"] + +def fx_placeholder_targets(gm): + return [n.target for n in gm.graph.nodes if n.op == "placeholder"] + +# Given a GraphModule and arguments to run it with, evaluate that the guards +# for its associated ShapeEnv are satisfied by the passed arguments. This +# WILL check for duck sizing. +def eval_guards(gm, *args, ignore_static=True): + return gm.shape_env.evaluate_guards_for_args(fx_placeholder_vals(gm), args, ignore_static=ignore_static) + +def bind_symbols(gm, *args): + return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) + +class DimDynamic(Enum): + """ + Controls how to perform symbol allocation for a dimension. It is always + sound to default this to DYNAMIC, but the policies DUCK and STATIC can + result in better trace-time and compile-time performance, as they reduce + the number of allocated symbols and generally make your graph more static. + + NB: If we notice you've applied a constraint to the dimension, we will + force it to DYNAMIC for simplicity. + + DimDynamic is controlled by a variety of higher level UX features. + Currently: + + - In eager mode, the default policy is DUCK. + - The default is changed to STATIC with assume_static_by_default. + - An individual dim is marked DYNAMIC if you mark_dynamic_dim. + - In export mode, the default policy is STATIC. + - An individual dim is marked DYNAMIC if you specify it in + dynamic_shapes passed to export. + """ + # Treat the dimension symbolically + DYNAMIC = 0 + # Treat the dimension symbolically, but if its hint matches another + # dynamic dimension, unify the two symbols ("duck sizing") + DUCK = 1 + # Treat the dimension statically based on its hint + STATIC = 2 + # Treat the dimension as a size-like unbacked + SIZE_LIKE_UNBACKED = 3 + # Infer the strides from stride. If size is static, strides will be static as well. + INFER_STRIDE = 4 + + +# NB: These constraints affect both clients and backends: given some +# constraint C, the client must pass inputs that satisfy the constraint, +# while a backend must not introduce guards BEYOND this constraint. +# For clarity, we document the implications on both sides for both the client +# and the backend. +# +# NB: These constraints are on a *single* dimension. In principle, we could +# also have multi-dimension constraints, but our guess is that this is not +# actually useful and so we are not supporting it right now. +# +# NB: Strict constraints are typically only suitable for export, as in eager +# a backend like inductor may validly introduce extra, discretionary guards +# to improve performance of code. A StrictMinMaxConstraint would be brittle +# under future optimizations performed by inductor; we don't guarantee +# eager code with StrictMinMaxConstraint will keep working in the future! + +@dataclass(frozen=True) +class Constraint: + warn_only: bool + +@dataclass(frozen=True) +class StrictMinMaxConstraint(Constraint): + """ + For clients: the size at this dimension must be within 'vr' (which + specifies a lower and upper bound, inclusive-inclusive) AND it + must be non-negative and should not be 0 or 1 (but see NB below). + + For backends: there must not be any guards on this dimension which + are not implied by the given lower and upper bound. Regardless of + the lower bound, the backend can assume the size is non-negative + and that it is not 0 or 1. + + An unbounded StrictMinMaxConstraint can be thought of as a strict version + of "RelaxedUnspecConstraint". + + NB: Export will often unsoundly assume that a graph works for 0/1, even + though at trace time we assumed size is not 0 or 1. The idea is that + if we produce a graph that works for a range of values, it will be OK + for N=0/1 too. + """ + vr: ValueRanges + + def render(self, source: Source): + """Format the constrain equation""" + # TODO: better printing for -oo and oo + return f"{self.vr.lower} <= {source.name()} <= {self.vr.upper}" + +@dataclass(frozen=True) +class RelaxedUnspecConstraint(Constraint): + """ + For clients: no explicit constraint; constraint is whatever is implicitly + inferred by guards from tracing. + + For backends: there must exist at least TWO possible values for the + size at this dimension which satisfy the guards for this dimension. + + In other words, this constraint helps us distinguish between "we don't + care if this dimension specializes or not" versus "this dimension must be + unspecialized." However, this constraint doesn't say very much about what + specialization is permitted; for example, if we guard on a size being + even, this would still be acceptable under an unspec constraint. This + makes RelaxedUnspecConstraint useful for eager mode, where your backend compiler + may add constraints to otherwise dynamic dimensions; we can't assert that + there are NO guards as this is brittle because compilers should be able to + add extra constraints. If you want to assert that there are no guards, + use StrictMinMaxConstraint with an unbounded ValueRanges. + """ + def render(self, source: Source): + return f"RelaxedUnspecConstraint({source.name()})" + +# NB: None here indicates the client constraint is whatever is implicitly +# inferred by guards from tracing, and that a backend can add whatever guards +# it wants (including fully specializing the value). +DimConstraint = Union[StrictMinMaxConstraint, RelaxedUnspecConstraint, None] + +@dataclass(frozen=True) +class EqualityConstraint(Constraint): + """ + Represent and decide various kinds of equality constraints between input sources. + + A "source pair" is a pair of input sources for dynamic dimensions that + are specified equal. We represent `source_pairs` in a union-find forest + so that we can efficiently check whether two such sources are transitively equal. + + A "derived equality" relates an input source to an expression over a root. + The root can be another input source, corresponding to some dynamic dimension, + or a phantom symbol that does not directly represent any dynamic dimension. We + represent `derived_equalities` involving input sources in a transitively-closed map + so that we can efficiently check whether an input source is transitively equal to + a given expression over another input source. + (NOTE: In contrast, it is easy to decide whether an input source is transitively equal + to a given expression over a phantom symbol; such expressions are already in canonical + form and so the problem reduces to symbolic expression equality.) + """ + source_pairs: List[Tuple[Source, Source]] + derived_equalities: List[Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]] + phantom_symbols: List[sympy.Symbol] + + def __post_init__(self): + """Pre-processing to answer queries `is_equal` and `is_derived` below. + + Example: Suppose we are given: + source_pairs [a = b, b = c] + derived_equalities [d = c + 1, e = d - 1] + We first construct a union find with source_pairs: + _parents = {a: a, b: a, c: a} + Then we compute canonical symbolic expressions, recursively applying derived_equalities + until we bottom out: + _defs = {d: c + 1, e: (c + 1) - 1 aka c} + """ + + # self._parents is a map from input sources to input sources where, conceptually, + # these are directed edges in a union-find forest + _parents: Dict[Source, Source] = {} + object.__setattr__(self, "_parents", _parents) + # self._defs is a map from input sources to "canonical" symbolic expressions, + # i.e., unary expressions with symbols that corresponds to regular Dims (i.e., + # not derived Dims) + _defs: Dict[Source, sympy.Expr] = {} + object.__setattr__(self, "_defs", _defs) + + for source1, source2 in self.source_pairs: + # preprocess into a union-find forest + self._union(self._find(source1), self._find(source2)) + for source, root, fn in self.derived_equalities: + # preprocess into a transitively-closed map + # NOTE(avik): we reuse the union-find forest for canonicalizing input sources + if isinstance(root, sympy.Symbol): + self._defs[self._find(source)] = fn(root) + else: + self._defs[self._find(source)] = fn(self._rewrite(root)) + + def _find(self, source): + # chase edges to find the root of this equivalence class + if source in self._parents: + return self._find(self._parents[source]) + else: + return source + + def _union(self, root1, root2): + # merge two equivalence classes by adding an edge from one root to the other + if root1 != root2: + self._parents[root1] = root2 + + def _rewrite(self, src): + # always represent the given source by the root of its equivalence class + src = self._find(src) + if src in self._defs: + # simply look up the definition if it exists + # NOTE(avik): This works because definitions are always transitively-closed; + # otherwise we would have to do recursive rewriting. + return self._defs[src] + else: + # otherwise, create a symbol representing the source + return sympy.Symbol(src.name()) + + def is_equal(self, source1, source2): + return ( + # check whether source1 and source2 have the same root + self._find(source1) == self._find(source2) or + # check whether source1 is derived equal to source2 + self.is_derived(source1, source2, lambda x: x) + ) + + def is_derived(self, src, symbol_src, fn): + # check whether both src and symbol_src have the same definition + return self._rewrite(src) == fn(self._rewrite(symbol_src)) + + +def _assert_symbol_context(symbolic_context): + assert isinstance(symbolic_context, SymbolicContext), "Invalid symbolic_context object" + assert type(symbolic_context) is not SymbolicContext, "Illegal usage of symbolic_context ABC" + +def _is_supported_equivalence(expr): + # Currently supported Dim ops are linear expressions with integer coefficients. + # So check that expr only contains +, *, ints, and a single occurrence of a symbol. + # (See also documentation of dynamic_shapes._DerivedDim.) + if isinstance(expr, (sympy.Add, sympy.Mul)): + if len(expr.args) > 2: + return False + lhs, rhs = expr.args + return ( + (_is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or + (isinstance(lhs, sympy.Integer) and _is_supported_equivalence(rhs)) + ) + return isinstance(expr, sympy.Symbol) + +def _has_uninterpretable_sympy_function(expr) -> bool: + """ + Add functions that our sympy interpreter can't reify into FX nodes + """ + return expr.has( + torch.utils._sympy.functions.ToFloat, + torch.utils._sympy.functions.TruncToInt, + torch.utils._sympy.functions.CeilToInt, + ) + +@dataclass(frozen=True) +class SymbolicContext: + """ + Data structure specifying how we should create symbols in + ``create_symbolic_sizes_strides_storage_offset``; e.g., should + they be static or dynamic. + + This is an abstract base class because we are probably going to add + another version of this that says "use exactly these SymInts, don't + allocate fresh symbols." + """ + + +@dataclass(frozen=True) +class StatelessSymbolicContext(SymbolicContext): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``. + This will cause fresh symbols to be allocated + """ + dynamic_sizes: DimList[DimDynamic] + dynamic_strides: DimList[DimDynamic] = None + constraint_sizes: DimList[DimConstraint] = None + constraint_strides: DimList[DimConstraint] = None + # If the tensor is a view, this should be populated for the base. It contains + # information on how to allocate symbols when recursively fakeifying the base + # during view fake-ification. + view_base_context: Optional[SymbolicContext] = None + # TODO: add storage offset and stride symbolic_context + + def __post_init__(self): + if self.dynamic_strides is None: + object.__setattr__(self, 'dynamic_strides', [DimDynamic.INFER_STRIDE] * len(self.dynamic_sizes)) + if self.constraint_sizes is None: + object.__setattr__(self, 'constraint_sizes', [None] * len(self.dynamic_sizes)) + if self.constraint_strides is None: + object.__setattr__(self, 'constraint_strides', [None] * len(self.dynamic_sizes)) + assert all(stride in (DimDynamic.INFER_STRIDE, DimDynamic.DYNAMIC, DimDynamic.DUCK) for stride in self.dynamic_strides) + + +# note [Tensor Fakification and Symbol Caching] +# +# As of the time of this note, dynamo creates a fresh fake tensor mode for backends. +# The reason we do this is because there are certain classes of operations, namely, +# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor +# state at the end of a dynamo trace is different than the fake tensor state at the beginning +# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation, +# view relationships, etc. +# +# As we create a new fake mode, we also lose the memoization that comes with it. Rather than +# transfer the memoization cache, we instead transfer the shape env. However, with this +# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in +# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across +# recompilations. +# +# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass +# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext. +# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is +# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors +# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env +# is used. +# TODO(voz): Shape env validation +@dataclass(frozen=True) +class StatefulSymbolicContext(StatelessSymbolicContext): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by a cache of Source:Symbol. A cache hit + will reuse a stored symbol, and a cache miss will write to this cache. + + This behaves like StatelessSymbolicContext, except the cache supersedes the + other values - dynamic_sizes and constraint_sizes will not be read if we cache + hit. + + It is the cache owners responsibility to maintain the lifecycle of the cache + w/r/t different shape_envs, clearing, etc. + """ + tensor_source: Source = None + # Why is this keyd on int first? + # That integer is actually the id of the shape_env. This cache short-circuits symbol + # creation, and we must store it per shape env. Now, while tracing invariants are a single + # shape env per tracing context, and every new frame gets a new shape_env. So where would we have + # multiple shape envs? The answer lies in recording. When we are replaying, replay_shape_env_events + # is invoked, and creates a new shape_env. Replaying events against this new shape_env will + # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never + # get recorded in var_to_val, etc. + # TODO(voz): consider a weakref to the shape_env here + shape_env_to_source_to_symbol_cache : Dict[int, Dict["TensorPropertySource", "sympy.Expr"]] = None + + def __post_init__(self): + super().__post_init__() + # The None default is annoying, but required because of dataclass limitations + assert self.tensor_source is not None + if not self.shape_env_to_source_to_symbol_cache: + object.__setattr__(self, 'shape_env_to_source_to_symbol_cache', {}) + + +@dataclass(frozen=True) +class SubclassSymbolicContext(StatefulSymbolicContext): + """ + The correct symbolic context for a given inner tensor of a traceable tensor subclass + may differ from that of the outer symbolic context. This structure allows for this + flexibility, with inner symbolic contexts mapped via attr -> symbolic context. + """ + inner_contexts: Dict[str, SymbolicContext] = None + + def __post_init__(self): + super().__post_init__() + if self.inner_contexts is None: + self.inner_contexts = {} + + +def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool: + if isinstance(val, (int, float, bool)): + return False + return val.node.is_symbolic() + +IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) + +@lru_cache(256) +def safe_expand(r): + if hasattr(r, 'expand'): + try: + return sympy.expand(r) + except RecursionError: + log.warning("RecursionError in sympy.expand(%s)", r) + return r + else: + return r + +def error(): + raise AssertionError("shouldn't be hit") + + +# TODO: Deduplicate this with torch/_prims_common/__init__.py +def eval_is_non_overlapping_and_dense(sizes, strides): + return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides))) + +def _eval_is_non_overlapping_and_dense(sizes, strides): + dim = len(sizes) + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + # or it is a 0/1 element tensor + if dim == 1: + return strides[0] == 1 or sizes[0] < 2 + + # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous + # Sorts (length, stride) pairs by stride + lengths_and_strides = sorted( + zip(sizes, strides), key=operator.itemgetter(1) + ) + + # Unlike the C++ code, we don't move the 0/1 size dimensions to the + # end. So we have to keep going for this code. + expected_stride = 1 + for length, stride in lengths_and_strides: + + if length == 1: + continue + + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +def _sympy_cast_symbool_to_symint_guardless(x: sympy.Expr) -> sympy.Expr: + return sympy.Piecewise((1, x), (0, True)) + + +def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt: + if isinstance(symbool, bool): + return 1 if symbool else 0 + int_sym = _sympy_cast_symbool_to_symint_guardless(symbool.node.expr) + return symbool.node.shape_env.create_symintnode(int_sym, hint=int(symbool.node.require_hint()) if has_hint(symbool) else None) + +SYMPY_INTERP = { + 'Abs': operator.abs, + 'Eq': operator.eq, + 'Ne': operator.ne, + 'Gt': operator.gt, + 'Lt': operator.lt, + 'Le': operator.le, + 'Ge': operator.ge, + 'Min': min, + 'Max': max, + 'Mod': operator.mod, + 'PythonMod': operator.mod, + 'FloorDiv': operator.floordiv, + 'TrueDiv': operator.truediv, + 'PowByNatural': operator.pow, + 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, + 'floor': math.floor, + 'ceiling': math.ceil, + 'FloorToInt': math.floor, + 'FloatPow': math.pow, + 'CeilToInt': math.ceil, + 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless, + 'RoundToInt': builtins.round, + 'RoundDecimal': builtins.round, + 'TruncToInt': math.trunc, + 'IntTrueDiv': operator.truediv, + 'FloatTrueDiv': operator.truediv, + 'ToFloat': builtins.float, +} + + +def _lru_cache(fn, maxsize=None): + """ + Wrapper around lru_cache that clears when new info about shapes has been + updated. + + Use lru_cache if the output is always the same, regardless of the + constraints we know now (i.e. evaluate_expr) + + Use _lru_cache otherwise. + + Also note that this depends on _update_version_counter being called on the + shape environment whenever the constraints are updated, otherwise the cache + will not be cleared. + """ + fn_cache = lru_cache(maxsize)(fn) + prior_version = 0 + + if config.validate_shape_env_version_key: + prior_key = None + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + nonlocal prior_version, prior_key + if prior_key is None: + prior_key = self._get_key() + + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + prior_key = self._get_key() + else: + assert prior_key == self._get_key(), \ + "ShapeEnv cache key changed without version being updated!" + + return fn_cache(self, *args, **kwargs) + + else: + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + nonlocal prior_version + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + + return fn_cache(self, *args, **kwargs) + + wrapper.cache_clear = fn_cache.cache_clear + wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined] + return wrapper + + +# This is pretty similar to ShapeGuard but it also comes with a message, +# and is exclusively used for things that MUST be true (unlike guards, +# which can evaluate False, in which case you just choose not to use +# a particular specialization) +@dataclass(frozen=True) +class RuntimeAssert: + expr: sympy.Expr + msg: str = field(repr=False) + stack: str = field(repr=False) + + +# Used for printing SymExprs in compile_fx +class SymExprPrinter(StrPrinter): + def _print_Float(self, expr): + return str(float(expr)) + + +class ShapeGuardPrinter(SymExprPrinter): + def __init__( + self, + symbol_to_source, + source_ref, + var_to_sources, + ): + super().__init__() + self.symbol_to_source = symbol_to_source + self.source_ref = source_ref + self.var_to_sources = var_to_sources + + def _print_Not(self, expr): + return 'not {}'.format(self.parenthesize(expr.args[0], PRECEDENCE["Not"])) + + def _print_And(self, expr): + return self.stringify(expr.args, " and ", PRECEDENCE["And"]) + + def _print_Or(self, expr): + return self.stringify(expr.args, " or ", PRECEDENCE["Or"]) + + def _print_Symbol(self, expr) -> str: + assert isinstance(expr, sympy.Symbol), str(type(expr)) + + def repr_symbol_to_source(): + return repr({ + symbol: [s.name() for s in sources] + for symbol, sources in self.symbol_to_source.items() + }) + + assert self.symbol_to_source.get(expr), ( + f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) " + f"not in {repr_symbol_to_source()}. If this assert is failing, it could be " + "due to the issue described in https://github.com/pytorch/pytorch/pull/90665" + ) + return self.source_ref(self.symbol_to_source[expr][0]) + + +class LoggingShapeGuardPrinter(ShapeGuardPrinter): + def __init__(self, var_to_sources): + super().__init__(var_to_sources, lambda n: n.name(), var_to_sources) + + +class DynamicDimConstraintPrinter(StrPrinter): + """ + Printer for dynamic dim constraints. + - Instead of symbol s_k it prints its source t.size()[i] + - Instead of Eq(_, _), Mod(_, _), etc. it prints _ == _, _ % _, etc. + + We use this to suggest code for specifying dynamic dim constraints. + """ + def __init__(self, symbol_to_source, source_name_to_debug_name): + super().__init__() + self.symbol_to_source = symbol_to_source + self.source_name_to_debug_name = source_name_to_debug_name + + def _print_Symbol(self, expr) -> str: + assert isinstance(expr, sympy.Symbol), str(type(expr)) + assert self.symbol_to_source.get(expr), ( + f"Unknown symbol {expr} created by constraints solver" + ) + return self.symbol_to_source[expr][0].name() + + def _print_Relational(self, expr): + return f'{self.parenthesize(expr.lhs, precedence(expr))} {expr.rel_op} {self.parenthesize(expr.rhs, precedence(expr))}' + + +class DimConstraints: + """ + Custom solver for a system of constraints on symbolic dimensions. + Solutions are "static" values or simplified "dynamic" constraints. + """ + + def __init__( + self, + symbol_to_source, + var_to_val, + marked_dynamic, + source_name_to_debug_name, + ): + # We try to solve systems of inequalities with 1 free variable. + self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set) + # Among them, we prioritize solving for a free variable that has equalities. + # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys() + # and removing a symbol from the former => removing it from the latter. + self._symbols_with_equalities: Set[sympy.Symbol] = set() + # A solution of a free variable with equalities becomes a substitution. + # We use these substitutions to simplify other constraints. + # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions. + self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {} + + # In general, constraints may have // and % operations. + # Of course, // can be expressed in terms of / and %. + # Our inequality solver can handle / but not %. So we need to transform them away. + # We do so by using the values of variables as hints to evaluate %. + # For soundness we record additional congruence guards and solve them separately. + self._var_to_val: Dict[sympy.Symbol, sympy.Integer] = var_to_val + self._congruences: Set[sympy.Expr] = defaultdict(set) + + # We do not try to (directly) solve inequalities with > 1 free variables. + # NOTE: free variables in these inequalities cannot also be in _substitutions. + self._multivariate_inequalities: Set[sympy.Expr] = set() + + # We park external equalities between free variables here. + self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = [] + + # Solutions come in two forms: + # - (static) specializations + # - (dynamic) inequalities / congruences + self._static_results: Set[str] = set() + self._dynamic_results: Set[str] = set() + + # printer for solutions + self._dcp = DynamicDimConstraintPrinter(symbol_to_source, source_name_to_debug_name) + + # inconsistencies found on substituting with concrete values / static solutions + self._inconsistencies: List[str] = [] + + # symbols that are marked dynamic + self._marked_dynamic = marked_dynamic + + # track supported sympy functions and subtract from list of all sympy functions + self._supported_sympy_functions: Set[sympy.Function] = { + Application, + Mod, + PythonMod, + FloorDiv, + } + self._enumerate_sympy_functions() + + def rewrite_with_congruences(self, s, expr): + """ + Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. + This leaves rational operators (in particular of the form b / d) that our inequality solver can handle. + We solve the added congruences separately (using our congruence solver, see below). + """ + def mod_handler(*args): + # Suppose that we have an expression of the form b % d with free variable s. + # Using the value of s as a "hint," we can evaluate b % d to a value k. + # Then we can rewrite b % d to k while adding the guard b % d == k. + + # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is complete IFF + # the original expression always evaluates to a constant value (i.e., it does not vary with s). + # In other words, + # - solutions of s with the rewritten expression are guaranteed to also be solutions of s with + # the original expression; + # - while it may be possible to find solutions of s with the original expression that are not + # solutions with the rewritten expression, in that case the original expression cannot evaluate + # to the same value for all solutions of s. + # + # Should we be worried about this incompleteness? No, because of the following reasons: + # 1. It unblocks dramatic simplification that would not be otherwise possible with current tech + # (i.e., "don't let perfect be the enemy of the good"). + # 2. We already have a tradition of using hints to add guards in the compiler for making progress. + # 3. We have not yet seen a counterexample arise in practice! In particular, any congruence guards + # we generate (or simplify to) seem to be of the form b % d == k where k is a constant. + # + # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by all s >= 2. + # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we + # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution! + base, divisor = args + base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor) + mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(self._var_to_val) + congruence = (base - mod_reduced) % divisor + if congruence != 0: + self._congruences[s].add(congruence) + return mod_reduced + + def floor_div_handler(*args): + # Suppose that we have an expression of the form b // d with free variable s. + # Using the value of s, we can evaluate b % d to a value k. + # Then we can rewrite b // d to (b - k) / d, while adding the guard b % d == k. + + # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d + # and eliminating b % d as above. + base, divisor = args + base, divisor = self.rewrite_with_congruences(s, base), self.rewrite_with_congruences(s, divisor) + mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace(self._var_to_val) + congruence = (base - mod_reduced) % divisor + if congruence != 0: + self._congruences[s].add(congruence) + # NB: Must not be CleanDiv, it needs to be regular sympy division + # so inequality solver works. This is sort of problematic for + # is_integer tests though haha + return (base - mod_reduced) / divisor + + if expr.has(Mod): + expr = expr.replace(Mod, mod_handler) + # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative + # arguments should be OK. + if expr.has(PythonMod): + expr = expr.replace(PythonMod, mod_handler) + if expr.has(FloorDiv): + expr = expr.replace(FloorDiv, floor_div_handler) + return expr + + def _enumerate_sympy_functions(self): + module = torch.utils._sympy.functions + all_functions = set() + for attr in dir(module): + if isinstance(func := getattr(module, attr), sympy.FunctionClass): + all_functions.add(func) + self._unsupported_sympy_functions = all_functions.difference(self._supported_sympy_functions) + + def _has_unsupported_sympy_function(self, expr) -> bool: + """ + Tracks list of sympy.Functions the export solver doesn't know how to handle. + """ + return expr.has(*self._unsupported_sympy_functions) + + def add(self, expr) -> bool: + """Add an expression to the set of constraints. + + Return whether the expression is a trivial constraint (i.e., an obvious tautology). + """ + if expr == sympy.true: + return True + orig_expr = expr + orig_reduced = orig_expr.xreplace(self._var_to_val) + # TODO(avik): https://github.com/pytorch/pytorch/issues/101093 + # It is possible that `expr` will fail the consistency check because of + # precision errors. Specifically, on substituting its free symbols with + # their concrete values, we might end up comparing floats. Until we have + # a fix for this issue, we delay raising such failures. See solve(). + if orig_reduced == sympy.false: + self._inconsistencies.append(f"{orig_expr} is inconsistent!") + if isinstance(expr, sympy.Ne) or self._has_unsupported_sympy_function(expr): + # we're not going to do anything useful with these, so drop them + return False + free_symbols = expr.free_symbols + assert free_symbols, f"Did not expect constraint with no free variables: {expr}" + if len(free_symbols) > 1: + # multivariate: record and move on + self._multivariate_inequalities.add(expr) + else: + # univariate: can solve these immediately + s = next(iter(free_symbols)) + # eliminate // and % (see documentation of `rewrite_with_congruences` above) + old_n_congruences = len(self._congruences[s]) + expr = self.rewrite_with_congruences(s, expr) + new_n_congruences = len(self._congruences[s]) + if expr == sympy.true: + return old_n_congruences == new_n_congruences + reduced = expr.xreplace(self._var_to_val) + if reduced == sympy.false: + self._inconsistencies.append( + f"{expr}, obtained by rewriting {orig_expr} with congruences, " + "is inconsistent!" + ) + if isinstance(expr, sympy.Eq): + # special status for symbols that have equalities (see `solve` below) + self._symbols_with_equalities.add(s) + self._univariate_inequalities[s].add(expr) + return False + + def add_equality(self, source, expr): + """Add an equality constraint""" + if expr.is_number: + # specialization, right here + self._static_results.add(f"{source.name()} == {expr}") + else: + # these will resolve to either specializations or dynamic equality constraints + self._symbolic_equivalences.append((source, expr)) + + def _reduce_congruences(self): + reduced_congruences = {} + for s, congruences in self._congruences.items(): + remainder_modulus_pairs = [] + congruences_to_check = set() + for congruence in congruences: + base, divisor = congruence.args + # We are given a congruence of the form base % divisor == 0 with a free variable s. So: + # - we transform this into an equation of the form base = divisor * tmp; + # - we solve this equation for s to get a linear solution with free variable tmp. + tmp = sympy.Symbol("reduce_congruences_tmp", integer=True) + symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s]) + # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear + # for how to interpret the results. + if s == symbol: + # This means the solution is of the form s = modulus*tmp + remainder. + modulus, remainder = sympy.polys.polytools.div(solution, tmp) + if isinstance(modulus, sympy.Integer) and isinstance(remainder, sympy.Integer): + # Make sure 0 <= remainder <= modulus. + remainder = remainder % modulus + remainder_modulus_pairs.append((remainder, modulus)) + continue + # This means that we did not get a unique solution to the equation. + # No problem, we will check it. + congruences_to_check.add(congruence) + # Finally we solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i). + # The solution will be a congruence of the form s = r mod m. + # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT. + if remainder_modulus_pairs: + remainder, modulus = sympy.ntheory.modular.solve_congruence(*remainder_modulus_pairs) + reduced_congruences[s] = {(s - remainder) % modulus} + substitution = {s: modulus * sympy.Symbol("tmp", integer=True) + remainder} + reduced_congruences[s].update( + congruence for congruence in congruences_to_check + if not sympy.checksol(congruence, substitution) + ) + else: + reduced_congruences[s] = congruences_to_check + + return reduced_congruences + + def _raise_inconsistencies(self): + if self._inconsistencies: + msg = "\n".join(self._inconsistencies) + self._inconsistencies.clear() + raise ValueError(f"The following inconsistencies were found:\n{msg}") + + def solve(self): + """Solve the system of constraint equations to find simplified constraints + """ + self._raise_inconsistencies() + # as long as there are symbols with equalities, solve for them + # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols) + while self._symbols_with_equalities: + s = self._symbols_with_equalities.pop() + exprs = self._univariate_inequalities.pop(s) + solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) + if isinstance(solution, sympy.And): + solution = next((arg for arg in solution.args if isinstance(arg, sympy.Eq)), solution) + assert isinstance(solution, sympy.Eq), f"Expected an equality constraint for {s}, got {solution}" + symbol, val = solution.args + assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}" + # because this is univariate, the solution is a specialization + self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}") + # add this as a substitution to simplify other constraints + self._substitutions[s] = val + + # simplify multivariate inequalities: some of them will now become univariate! + multivariate_inequalities = self._multivariate_inequalities + self._multivariate_inequalities = set() + for expr in multivariate_inequalities: + self.add(expr.xreplace({s: self._substitutions[s]})) + self._raise_inconsistencies() + + # solve linear congruences + # NOTE(avik): We do not need to solve them for symbols that have already been specialized. + reduced_congruences = self._reduce_congruences() + for s, congruences in reduced_congruences.items(): + for congruence in congruences: + # any congruence that cannot be checked becomes a dynamic constraint as well + if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}): + if self._is_supported_congruence(congruence): + base, divisor = congruence.args + tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}" + tmp = sympy.Symbol(tmp_name, integer=True) + from torch._dynamo.source import ConstantSource + self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)] + r = try_solve(sympy.Eq(base, divisor * tmp), s) + self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1]))) + + # remaining symbols have only pure inequalities (no equalities) + for s, exprs in self._univariate_inequalities.items(): + try: + solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) + # because this is univariate, the solution is a dynamic (range) constraint + if isinstance(solution, sympy.Or): + solution = next(iter(arg for arg in solution.args if arg.xreplace(self._var_to_val))) + if isinstance(solution, sympy.And): + for arg in solution.args: + self._dynamic_results.add(self._dcp.doprint(arg)) + else: + self._dynamic_results.add(self._dcp.doprint(solution)) + except (NotImplementedError, AssertionError) as e: + log.warning("Failed to reduce inequalities: %s", e) + for expr in exprs: + self._dynamic_results.add(self._dcp.doprint(expr)) + + # simplify symbolic equivalences: some of them will now become specializations! + symbolic_equivalences = self._symbolic_equivalences + self._symbolic_equivalences = [] + for source, expr in symbolic_equivalences: + self.add_equality(source, expr.xreplace(self._substitutions)) + + # remaining symbolic equivalences become dynamic equality constraints + for source, expr in self._symbolic_equivalences: + self._dynamic_results.add(f"{source.name()} == {self._dcp.doprint(expr)}") + + @classmethod + def _is_supported_congruence(cls, congruence): + base, divisor = congruence.args + # Congruences that can be currently expressed with supported Dim ops are + # of the form (x + a) % b == 0, where x is a Dim and a and b are constants. + # This allows us to derive x as b*y - a for some Dim y. + # (See also documentation of dynamic_shapes._DerivedDim.) + if isinstance(base, sympy.Add): + lhs, rhs = base.args + cond = ( + (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer)) or + (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol)) + ) + else: + cond = isinstance(base, sympy.Symbol) + cond = cond and isinstance(divisor, sympy.Integer) + return cond + + def forced_specializations(self): + """Returns a dictionary of the names of symbols to their specialized value + """ + def debug_name(src): + name = src.name() + if self._dcp.source_name_to_debug_name: + return f"{self._dcp.source_name_to_debug_name[name]} = {name}" + else: + return name + + return { + debug_name(self._dcp.symbol_to_source[s][0]): val + for s, val in self._substitutions.items() + if s in self._marked_dynamic + } + + def _is_derived_dim(self, dim): + return isinstance(dim, torch.export.dynamic_shapes._DerivedDim) + + def _is_dim(self, dim): + return ( + isinstance(dim, torch.export.dynamic_shapes._Dim) + and not isinstance(dim, torch.export.dynamic_shapes._DerivedDim) + ) + + def _process_derived_dim_roots( + self, + results: Dict[str, Dict[str, Any]], + name_to_dim: Dict[str, Any], + ) -> None: + ''' + Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots, + and 2) root swapping. + + 1) Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests + dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final + suggested fixes handle this correctly, but we can get intermediate results that look like + {"dy": {"eq": "dx + 1"}, "dx": {"eq": "2 * _dx + 1, "min": 3, "max": 15}} + and this routine prettifies this by unifying to a single root, and making each suggestion + either a derived dim or min/max range, not both. + + 2) With suggested fixes for derived dims, roots can be swapped, + e.g. dx, dx - 1 -> dy + 1, dy. Here we don't want to print out the attached name, + since this leads to messages like "dx - 1 = Dim("dx - 1", ...)". + Instead we evaluate the new root value, and remove results for its derivations. + + First we find all the original roots (specified in dynamic_shapes), that are found in the + values of results (i.e. used for computing suggesting fix values). These original roots + (suppose `dx`) are either specialized, unchanged, refined, or swapped + (expressed as a derived dim). If any of the first 3 cases happen, we suggest `dx`'s value + in results, and remove suggestions for derivations of `dx`, assuming the derived relation + is valid. If swapped, we find the new root, and use the fix to evaluate `dx`'s new value, + and then do the same with `dx`'s derivations. + + Assuming the originally specified derived relations are correct is valid, because: + 1) if the relations are plain wrong (e.g. input shape = (6, 4) with spec (dx, dx - 1)) + produce_guards() will catch this and crash before hand. + 2) if the relations are numerically correct but do not match the emitted guard, + for example: + + def forward(self, x, y): + return x.reshape([-1]) + y # guard: s0 * 2 = s1 + inputs = (torch.randn(6, 2), torch.randn(12)) + dx = Dim("dx", min=2, max=32) + dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )} # this matches values but not op + + then this leads to 2 linear equations, and a) produce_guards() is able to solve for + the unique solution of dx = 6 and specialize, and b) the export constraint solver will + raise an issue due to range constraints (a unique solution means not all values in a + range satisfy a guard) and also force specializations. + ''' + from torch.export.dynamic_shapes import Dim + + def _check_same_range(c, dim): + # returns True if c & dim are both min/max ranges with same values + return ( + self._is_dim(dim) + and ("min" in c or "max" in c) + and ( + (dim.min < 2 and c.get("min", 2) == 2) + or dim.min == c.get("min", 2) + ) # let pass if analysis min = 2 and specified min = 0/1 + and dim.max == c.get("max", int_oo) + ) + + # 1) newly introduced roots + # this part we handle adding newly introduced roots + # these arise from guards like "x.shape[0] % 3 == 0" + # leading to suggested fixes like "dx = 3*_dx" + # extract _dx, and find appropriate min/max values + # + # before, we have something like: + # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2} + # we want instead: + # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3} + introduced_roots: Dict[str, str] = {} # map new root -> old root + for k, c in list(results.items()): + if "eq" in c and isinstance(c["eq"], sympy.Expr): # derived dim + root = next(iter(c["eq"].free_symbols)) + if str(root) not in name_to_dim: + introduced_roots[str(root)] = k + # calculate necessary min & max + modulus, remainder = sympy.polys.polytools.div(c["eq"], root) + c_min = c.get("min", 2) + min_ = math.ceil((c_min - remainder) / modulus) + c_max = c.get("max", int_oo) + max_ = math.floor((c_max - remainder) / modulus) + # create result & dim + results[str(root)] = {"min": min_, "max": max_} + name_to_dim[str(root)] = Dim(str(root), min=min_, max=max_) + # remove old root min/max bounds + c.pop("min", None) + c.pop("max", None) + + # alter derivations that depend on old root, to unify to new root + # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2 + for old_root in introduced_roots.values(): + for k, c in list(results.items()): + if ( + "eq" in c + and isinstance(c["eq"], sympy.Expr) + and str(symbol := next(iter(c["eq"].free_symbols))) == old_root + ): # derived dim with root = old_root + new_root_expr = results[str(old_root)]["eq"] # dx=3*_dx+1 + new_expr = c["eq"].subs({symbol: new_root_expr}) # dy=(3*_dx+1)+1 + c["eq"] = new_expr + + # 2) root swapping + # collect all the original roots that are used for calculating values of suggested fixes + # this consists of: + # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim + # 2) {"dy": "dx + 1"} -> dx: root for suggested fix + modified_roots: Set[str] = set() + for k, c in results.items(): + if k not in name_to_dim: # _dynamo.export() may handle source directly + continue + if self._is_dim(name_to_dim[k]) and ("min" in c or "max" in c): # case 1) + modified_roots.add(k) + elif "eq" in c and isinstance(c["eq"], sympy.Expr): # case 2) + root = next(iter(c["eq"].free_symbols)) + assert root is not None + modified_roots.add(str(root)) + + # exclude newly introduced roots, we've already processed these + modified_roots = modified_roots.difference(introduced_roots) + + # evaluate the new value for each root + # this is now either 1) unchanged, 2) refined with a new range, + # or 3) specialized to a concrete value + modified_root_values: Dict[str, Dict[str, Any]] = {} + for root in modified_roots: + swapped_root = True + if root in results: + c = results[root] + if ( + ("min" in c or "max" in c) # range + or isinstance(c["eq"], int) # specialized + ): + # here, the original root is a root Dim or concrete value in results. + # if it is a derived dim, it is swapped, and we handle that below. + if not _check_same_range(c, name_to_dim[root]): # ignore if unchanged + modified_root_values[root] = c + swapped_root = False + + if swapped_root: + # if the original root has been swapped in results, that means the new root + # is a range (if it had specialized, the original root would have too). + # find this new root, and solve for the original root's range. + for k, c in results.items(): + if k not in name_to_dim: + continue + dim = name_to_dim[k] + if dim.__class__.__name__ == "_DerivedDim" and dim.root.__name__ == root: + # only look for min/max root, otherwise root would have specialized + if "min" in c or "max" in c: + expr = sympy.sympify(k) + s = next(iter(expr.free_symbols)) + result = { + "min": try_solve(sympy.Eq(expr, c["min"]), s)[1], # type: ignore[arg-type] + "max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type] + } + if not _check_same_range(result, name_to_dim[root]): # ignore if unchanged + modified_root_values[root] = result + break + + # filter out results where the key is a derived dim (e.g. {"dx - 1" : 4}) + # we only want to suggest fixes for the root, to avoid derived names. + # also, remove anything in modified_roots, since we either add new modified values after this, + # or have decided they are unchanged. + for k in list(results.keys()): + if k not in name_to_dim: + continue + if self._is_derived_dim(name_to_dim[k]) or k in modified_roots: + del results[k] + + # update results with modified root values + # now results has the following properties: + # - only contains original roots as keys + # - each root is now either specialized, refined, or derived from another original root + results.update(modified_root_values) + + def prettify_results( + self, + original_signature: inspect.Signature, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + constraint_violation_error=None, + forced_specializations=None, + ): + """Format a message for constraint violation erros""" + from torch.export.dynamic_shapes import _get_dim_name_mapping + if not self._dcp.source_name_to_debug_name: + # nothing to do + return "" + + def transform(s, inverse=False): + for k, v in self._dcp.source_name_to_debug_name.items(): + s = s.replace(k, v) if not inverse else s.replace(v, k) + return s + + results = defaultdict(dict) + if dynamic_shapes is None: + dynamic_shapes = {} + + def flip(op): + if op == "<=": + return ">=" + if op == ">=": + return "<=" + if op == "<": + return ">" + if op == ">": + return "<" + assert op == "==" + return op + + def relation_with_digit(expr, op, digit): + if op == "<=": + results[expr]["max"] = digit + elif op == "<": + results[expr]["max"] = digit - 1 + elif op == ">=": + results[expr]["min"] = digit + elif op == ">": + results[expr]["min"] = digit + 1 + else: + assert op == "==" + results[expr]["eq"] = digit + + # retrieve dynamic shapes + name_to_dim = _get_dim_name_mapping(dynamic_shapes) + + for s in self._static_results.union(self._dynamic_results): + t = transform(s) + if t == s: + continue + left, op, right = re.split(r"( == | <= | >= | < | > )", t) + op = op.strip() + if op == "==" and left == right: + continue + if right.isdigit(): + relation_with_digit(left, op, int(right)) + elif left.isdigit(): + relation_with_digit(right, flip(op), int(left)) + else: + assert op == "==", t + results[left]["eq"] = sympy.sympify(right) + + # order forced specializations based on name + forced_specializations = { + k: forced_specializations[k] + for k in sorted( + forced_specializations.keys(), + key=lambda x: x.split(" = ")[1], + ) + } + + buf = "" + if forced_specializations: + debug_names = set() + for k in forced_specializations: + dim = name_to_dim[k.split(" = ")[0]] + if self._is_derived_dim(dim): + debug_names.add(dim.root.__name__) + else: + debug_names.add(dim.__name__) + + buf += ( + f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! " + 'For more information, run with TORCH_LOGS="+dynamic".\n' + ) + for s, val in forced_specializations.items(): + buf += f" - solving the guards generated for {s} resulted in a specialized value of {val}.\n" + + self._process_derived_dim_roots(results, name_to_dim) + + dims = [] + others = [] + + # order results by source name + results = { + k: results[k] for k in sorted( + results.keys(), + key=lambda x: transform(x, inverse=True), + ) + } + for k, c in results.items(): + if "eq" in c: + other = c["eq"] + if isinstance(other, int): + others.append(f"{k} = {other}") + elif _is_supported_equivalence(other): + others.append(f"{k} = {other}") + else: + min_ = c.get("min", None) + if min_ == 2: + min_ = None + max_ = c.get("max", None) + if min_ is not None and max_ is not None: + dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})") + elif min_ is not None: + dims.append(f"{k} = Dim('{k}', min={min_})") + elif max_ is not None: + dims.append(f"{k} = Dim('{k}', max={max_})") + else: + dims.append(f"{k} = Dim('{k}')") + + # results will get filtered out if no new suggestions, + # this can happen if guards are too complex. + # in that case don't suggest fix + if dims or others: + buf += "\nSuggested fixes:\n " + buf += "\n ".join(dims + others) + + return buf + + +TLS = threading.local() + + +@dataclass(frozen=True) +class ShapeEnvSettings: + """ + Encapsulates all shape env settings that could potentially affect + FakeTensor dispatch. Used when creating dispatch cache keys. + """ + + allow_scalar_outputs: bool + allow_dynamic_output_shape_ops: bool + assume_static_by_default: bool + specialize_zero_one: bool + duck_shape: bool + prefer_deferred_runtime_asserts_over_guards: bool + allow_complex_guards_as_runtime_asserts: bool + + +class ShapeEnv: + # This is a wrapper over the actual __init__ function. + # + # Where to add a new constructor parameter to ShapeEnv? + # ===================================================== + # This __init__ function should be used only for parameters related to event recording. + # These are parameters that we don't wish to pass down the road to new ShapeEnv instances + # created from replaying events. + # + # If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event + # recording, do so in the _init function. + def __init__( + self, *, + should_record_events: Optional[bool] = None, + tracked_fakes: Optional[List[Any]] = None, + **kwargs + ) -> None: + self._init(**kwargs) + + # Disable event recording when replaying. + kwargs["should_record_events"] = False + + from torch.fx.experimental.validator import translation_validation_enabled + self._translation_validation_enabled = translation_validation_enabled() + + # If not specified, enable event recording if both: + # - Translation validation is on + # - Translation validation bisection is not disabled + self.should_record_events = ( + should_record_events + if should_record_events is not None + else ( + self._translation_validation_enabled + and not config.translation_validation_no_bisect + ) + ) + + # Enable event recording check if both: + # - It should record events + # - The recording check is enabled + self.check_recorded_events = ( + self.should_record_events and config.check_shape_env_recorded_events + ) + + # This will make sure we only record the top-level function call. + self.is_recording = not self.should_record_events + # Keep track of the list of tracked fakes. + self.tracked_fakes = tracked_fakes + # List of events for reconstructing ShapeEnv at arbitrary points in time. + self.events: List[ShapeEnvEvent] = ( + [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] if self.should_record_events else [] + ) + + # FakeTensor per-ShapeEnv operation cache. This is used for caching + # operations that contain symbolic shapes which have guards on the + # ShapeEnv (so are ShapeEnv-dependent). + # + # NOTE: It's important that SymNodes in this cache have their ShapeEnv + # stripped otherwise you end up with cycles which can only be cleaned + # with the GC. + self.fake_tensor_cache: Dict[torch._subclasses.fake_tensor._DispatchCacheKey, + torch._subclasses.fake_tensor._DispatchCacheEntry] = {} + + # Pro-tip: if you add new field to ShapeEnv, this affects some accept + # tests. Accept their output with: + # + # EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal + # + def _init( + self, *, + allow_scalar_outputs=True, + allow_dynamic_output_shape_ops=True, + # NB: These are legacy configuration that help us make good choices + # when the constraint/dynamic dims are not explicitly passed to us. + # Ideally we will fix all call sites to be explicit and not have + # implicit choices, but this apparently was pretty involved. + assume_static_by_default=False, + # Note - On 0/1 specialization + # + # The following options affect decisions we make about eager + # specialization. Disabling them will increase trace time (as we do + # more symbolic reasoning) and can also harm the quality of generated + # code (because inductor may not be able to specialize for bounds + # being equal--although if we later respecialize because of a guard, + # your code may be just as good as it was before.) + # + # When True, eagerly specialize input sizes which have 0/1. + specialize_zero_one=True, + # When True, assume input sizes which have the same size are + # symbolically equal. + duck_shape: Optional[bool] = None, + # For debugging + co_fields=None, + # When True, whenever safe, we will generate a deferred runtime assert + # instead of a guard whenever we know that an expression must be True, + # otherwise it would be an error, even for backed SymInts (where we + # could ostensibly unconditionally generate guards). This is useful + # for export, where preventing "error checking" sizes from showing up + # in guards is helpful, since these guards in some sense are overly + # pedantic. See also https://github.com/pytorch/pytorch/issues/121749 + prefer_deferred_runtime_asserts_over_guards=False, + # When True, does not emit or raise constraint violation errors on + # implicit guards generated by ops, and defers to runtime assertions + # in the graph instead. For export. + allow_complex_guards_as_runtime_asserts=False, + # XXX Add any new settings that could affect FakeTensor evaluation + # to: torch._subclasses.fake_tensor._ShapeEnvSettings + ): + if duck_shape is None: + duck_shape = config.use_duck_shape + + self.settings = ShapeEnvSettings( + # Not directly used by ShapeEnv; indirectly used by FakeTensor + allow_scalar_outputs=allow_scalar_outputs, + allow_dynamic_output_shape_ops=allow_dynamic_output_shape_ops, + # End + assume_static_by_default=assume_static_by_default, + specialize_zero_one=specialize_zero_one, + duck_shape=duck_shape, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + ) + + self.guards: List[ShapeGuard] = [] + # Maps symbolic ints to their original concrete values + # Currently populated from tensors + self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} + # Like var_to_val, but only set when propagate_real_tensors is on. + # Used as last resort to avoid GuardOnDataDependent error + self.unbacked_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} + # Maps symbolic ints to their min/max range. These ranges + # are conservative: the int MUST fall in the range, but the + # range may contain ints which may not actually appear in + # practice + self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {} + self.source_name_to_debug_name: Dict[str, str] = {} + self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {} + self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {} + # Maps from sympy ints to expressions representing them + # Populated from equality guards (i.e. a.shape[0] == b.shape[0]) + self.replacements: Dict[sympy.Symbol, sympy.Expr] = {} + self.unbacked_renamings: Dict[sympy.Symbol, sympy.Symbol] = {} + # Set holds a % b expressions that evaluate to 0. + self.divisible: Set[sympy.Expr] = set() + # Set that holds "size-like" symbols. When we perform + # "size-oblivious" tests, these can be assumed to be >= 2. + self.size_like: Set[sympy.Symbol] = set() + # Duck-shaping says that if two input tensors have the same size, + # they get assigned the same symbolic variable + self.val_to_var: Dict[int, sympy.Expr] = {} + if specialize_zero_one: + self.val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)} + self.unbacked_symfloat_counter = itertools.count() + self.unbacked_symint_counter = itertools.count() + # Similar to guards, but these MUST evaluate to true and can + # only be evaluated at runtime midway through (i.e., they always + # involve unbacked symints) + # + # For efficiency reasons, we index in the following way. Suppose you have + # a runtime assert i0 + i1 <= s1. We pick the most recently allocated + # symbol in the source expression and add the assert to the list for + # that symbol e.g., {i1: [i0 + i1 <= s1]}. + # + # We access the runtime asserts in two situations: + # + # - When we are guarding on an expression, we will attempt to + # statically evaluate it, in case the unbacked SymInts can + # simplify away. If we have a runtime assert, we may be able + # to discharge the guard entirely. We only need to attempt + # runtime asserts that mention freevars of the expression in + # question. + # + # - When we are performing codegen (in Inductor for eager, or + # when finalizing the export FX graph), we need to know what + # extra runtime asserts to insert. Whenever an unbacked + # SymInt comes into scope, all runtime asserts involving it + # become eligible for insertion (so long as all of their other + # free unbacked symbols are also in scope). We technically + # can handle any choice of key by kicking inexpressible asserts + # to the next unbacked symbol to wait on, but if we choose the + # latest key, an assert will only show up at the moment when + # we can actually codegen it. + self.deferred_runtime_asserts: Dict[sympy.Symbol, List[RuntimeAssert]] = {} + # This exists so we can efficiently invalidate the cache (it's used as + # part of the cache key); otherwise we'd have to iterate through + # deferred_runtime_asserts to compute its length + self.num_deferred_runtime_asserts = 0 + self.log = log + self.log.debug("create_env") + self.frozen = False + self.runtime_asserts_frozen = False + self.dim_constraints: Optional[DimConstraints] = None + self.counter = collections.Counter() + # Mapping from sympy.Symbol to the number of guards which mention this + # symbol + self.symbol_guard_counter = collections.Counter() + # A selection of important fields on co_field; solely used for + # signpost_event + self.co_fields = co_fields if co_fields else {} + + # Whenever we allocate a fresh unbacked Symbol, we add it to this + # pending list. Unbacked symbol allocation can occur at unpredictable + # points during meta tensor propagation, but at some point, the we + # have to know what the binding site for an unbacked symbol is, and + # this is computed when we actually place the node in the graph. The + # important thing is that we always actually handle every unaccounted + # for unbacked symbol, so this list helps us keep track of them and + # then make sure they are all accounted for. + # + # We could potentially give rise to errors earlier by lexically + # scoping when we do propagation, and only allowing unbacked symbols + # to be allocated at this point in time. However this is inconvenient + # to do in Dynamo, because fake tensor propagation is far from when we + # analyze binding sites (set_example_value), so we do it in a more + # mutatey way. + # + # NB: fresh unbacked symbols NEVER get substitutions applied to them, + # they are binding sites! + self.pending_fresh_unbacked_symbols: List[sympy.Symbol] = [] + + # Version counter used to invalidate cached values + self._prev_cache_key = self._get_key() + self._version_counter = 0 + + # Cache for FX nodes. + # Maps an already built node a tuple of: + # 1. node's target + # 2. list of arguments + # This drastically reduces the size of the FX graph, avoiding + # duplicated nodes. + self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {} + self.source_to_symbol: Dict[str, sympy.Symbol] = {} + + # Suppose you want to replace an unbacked symbol with another + # unbacked symbol. This is error prone because you can cause + # references to unbacked symbols to time travel backwards. E.g., + # + # u1 = x.item() + # ... use of u1 ... + # u2 = y.item() + # u3 = z.item() + # torch._check(u1 == u2 + u3) + # + # If you replace u1 with u2 + u3, then the use of u1 now + # references u2 and u3 prior to them actually being bound at + # runtime. + # + # To control for this, we track the order unbacked symbols + # were allocated, and only allow substitutions if they respect + # the dependency from this order; an unbacked symbol can only + # be substituted with unbacked symbols that come before it in the + # order. + # + # This also imposes an ordering on the unbacked symbol binding + # sites themselves: you are not allowed to reorder unbacked symbol + # bindings. At the moment, this is not tracked, but we potentially + # could track this at the IR level using a higher order operator + # with something like effect token tracking. + self.unbacked_alloc_order: Dict[sympy.Symbol, int] = {} + + from torch.fx.experimental.validator import translation_validation_enabled + self._translation_validation_enabled = translation_validation_enabled() + + if self._translation_validation_enabled: + from torch.fx.experimental.validator import TranslationValidator + + self.validator = TranslationValidator() + self.graph = torch.fx.Graph() + # Create an output graph and start inserting before that. + # This is needed when 'deepcopy'-ing this object. + self.graph.inserting_before(self.graph.output(None)) + + # Mapping of each node name to the node itself. + # + # This is useful for matching an FX node from a recorded ShapeEnv.graph + # to the FX node of the ShapeEnv we are running the event on. + # + # Whenever you add a node to self.graph, you must add a mapping to this + # variable. Otherwise, the built FX graph on the replayed ShapeEnv will + # not be valid. + self.name_to_node: Dict[str, torch.fx.Node] = {} + + @property + def allow_scalar_outputs(self): + return self.settings.allow_scalar_outputs + + @property + def allow_dynamic_output_shape_ops(self): + return self.settings.allow_dynamic_output_shape_ops + + @property + def assume_static_by_default(self): + return self.settings.assume_static_by_default + + @property + def specialize_zero_one(self): + return self.settings.specialize_zero_one + + @property + def duck_shape(self): + return self.settings.duck_shape + + @property + def prefer_deferred_runtime_asserts_over_guards(self): + return self.settings.prefer_deferred_runtime_asserts_over_guards + + @property + def allow_complex_guards_as_runtime_asserts(self): + return self.settings.allow_complex_guards_as_runtime_asserts + + def check_equal(self, other: "ShapeEnv") -> None: + """Compare another ShapeEnv for equivalence + """ + # ShapeEnv fields that are not relevant for the outcome of + # ShapeEnv.produce_guards call: + # - Debugging variables + # - Translation validation related variables + # - Events recording related variables + non_state_variable_names = ( + "counter", + "log", + "var_to_stack", + "fx_node_cache", + "graph", + "validator", + "check_recorded_events", + "should_record_events", + "is_recording", + "tracked_fakes", + "events", + "source_name_to_debug_name", + "_prev_cache_key", + "_version_counter", + "dim_constraints", + ) + + # Mapping of the value of each to-be-compared field into the values that + # should actually be compared. + # + # You should modify this if, for example, the field that holds state and + # debugging information. e.g. ShapeGuard holds the actual guard (sympy.Expr) + # and the stack when it was added to the set of guards. In order to compare + # it, we throw away the stack information. + def map_value(key: str, value: Any) -> Any: + if key in ("unbacked_symfloat_counter", "unbacked_symint_counter"): + from copy import copy + + # For itertools.count(), we compare the next integer returned + # by the count iterators. Not that we need to copy the iterator + # first. Otherwise we are mutating the object. + return next(copy(value)) + elif key == "guards": + # Transform the list of ShapeGuard into a list of expressions. + return [g.expr for g in value] + elif key == "deferred_runtime_asserts": + # Transform the list of RuntimeAsserts into a list of expressions. + return {s: [ra.expr for ra in ras] for s, ras in value.items()} + elif key == "name_to_node": + # Compare just the set of keys is the same. + return set(value.keys()) + elif key in ("symbol_guard_counter", "pending_fresh_unbacked_symbols", "fake_tensor_cache"): + # Skip this for comparisons + return None + return value + + shape_env_check_state_equal(self, other, non_state_variable_names, map_value) + + def _snapshot_tracked_fakes(self) -> Optional[List[Any]]: + if self.tracked_fakes is None: + return None + + from torch._dynamo.variables.builder import TrackedFake + + def maybe_transform_fake(fake: TrackedFake): + inner_fake = fake.fake \ + if isinstance(fake.fake, (torch.SymInt, torch.SymFloat)) \ + else FakeTensorMeta.from_fake(fake.fake) + # Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a + # FakeTensorMeta for two reasons: + # 1. this is all the information we need when recording ShapeEnvEvents. + # 2. it works even if each TrackedFake changes its metadata. + return TrackedFake(inner_fake, fake.source, fake.symbolic_context) # type: ignore[arg-type] + + return [maybe_transform_fake(fake) for fake in self.tracked_fakes] + + def _last_event_index(self) -> int: + return len(self.events) - 1 + + @contextmanager + def _recording(self): + self.is_recording = True + try: + yield + finally: + self.is_recording = False + + @record_shapeenv_event() + def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr): + self._set_replacement(orig_s, new_s, "eliminate_unbacked") + + @record_shapeenv_event() + def set_unbacked_var_to_val(self, k: sympy.Symbol, v: int) -> None: + """Used only when propagate_real_tensors; registers a value for an + unbacked symbol, which can be used last resort to resolve hints.""" + self.unbacked_var_to_val[k] = sympy.sympify(v) + + # Unlike set_replacement, this records a shapeenv event + @record_shapeenv_event() + def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol): + assert isinstance(orig_s, sympy.Symbol), orig_s + assert isinstance(new_s, sympy.Symbol), new_s + assert free_unbacked_symbols(new_s), new_s + assert free_unbacked_symbols(orig_s), orig_s + dest = self.replacements.get(orig_s) + assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}" + self._set_replacement(orig_s, new_s, "rename_unbacked_to") + self.unbacked_renamings[orig_s] = new_s + if dest is not None: + self._set_replacement(new_s, dest, "rename_unbacked_to_dest") + + @record_shapeenv_event() + def _constrain_range_for_size(self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None): + if min is None: + min = 0 + if max is None: + max = int_oo + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + "received min={min} and max={max}" + ) + + self.constrain_symbol_range( + a, + compiler_min=min, + compiler_max=max, + ) + self.size_like.add(a) + + @record_shapeenv_event() + def _constrain_range(self, a: sympy.Expr, min: int, max: int): + if isinstance(a, sympy.Integer): + if not (min <= int(a) <= max): + raise ValueRangeError(f"Invalid value {int(a)} for range [{min}:{max}]") + return + + # TODO: Shouldn't we install a guard if the symbol is backed? Or is the + # semantics that this is an "unchecked" assert (but it this actually + # something useful? Might be better to restrict only for unbacked + # SymInt). + if isinstance(a, sympy.Symbol): + self.constrain_symbol_range( + a, + compiler_min=min, + compiler_max=max, + ) + + @record_shapeenv_event() + def _constrain_unify(self, a, b): + """ + Given two SymInts, constrain them so that they must be equal. NB: + this will not work with SymInts that represent nontrivial expressions + (yet!) + """ + # TODO: this does not install a deferred runtime assert yet + + # TODO: Maybe dedupe this with _maybe_guard_rel? + # Update Feb 2024: this is extra important to do, this doesn't handle + # unbacked replacements properly nor does it generate deferred runtime + # asserts + if not isinstance(a, SymInt): + if not isinstance(b, SymInt): + assert a == b + else: + assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + assert b.node.shape_env is self + self.replacements[b.node.expr] = sympy.Integer(a) + else: + # TODO: Actually, we can support this as long as one of them is a symbol. + # NB: We can't actually do "unification" as our operators are not + # injective + assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + assert a.node.shape_env is self + if not isinstance(b, SymInt): + self.replacements[a.node.expr] = sympy.Integer(b) + else: + assert a.node.shape_env is b.node.shape_env + assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + new_var = self._find(a.node.expr) + self.replacements[b.node.expr] = new_var + + def _ignore_fresh_unbacked_symbols_tls(self): + return getattr(TLS, "ignore_fresh_unbacked_symbols", False) + + @record_shapeenv_event() + def _ignore_fresh_unbacked_symbols_enter(self): + TLS.ignore_fresh_unbacked_symbols = True + + @record_shapeenv_event() + def _ignore_fresh_unbacked_symbols_exit(self): + TLS.ignore_fresh_unbacked_symbols = False + + @contextmanager + def ignore_fresh_unbacked_symbols(self): + """ + Indicates that the newly allocated unbacked SymInts are being + discarded + """ + self._ignore_fresh_unbacked_symbols_enter() + try: + yield + finally: + self._ignore_fresh_unbacked_symbols_exit() + + @record_shapeenv_event() + def freeze(self): + """Freeze this ShapeEnv to stop accumulating guards + + A frozen ShapeEnv will ignore any further guards generated on it and + only emit a warning which may lead to accuracy problems. + """ + self.frozen = True + + @record_shapeenv_event() + def freeze_runtime_asserts(self): + """Freeze this ShapeEnv to stop adding deferred runtime asserts. + + We will error if you try to install a new runtime assert when it is + frozen. This would indicate a lowering violation, or perhaps something + we know statically is already True but we are checking it again in a way + that is not clearly dischargeable. + """ + # self.prefer_deferred_runtime_asserts_over_guards = False + self.runtime_asserts_frozen = True + + def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]: + if not self._translation_validation_enabled: + return None + srcname = source.name() + if source not in self.source_to_symbol: + self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True) + return self.source_to_symbol[srcname] + + def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None: + if self._translation_validation_enabled: + self.validator.add_var(symbol, type) + + def _add_target_expr(self, expr) -> None: + if self._translation_validation_enabled: + self.validator.add_target_expr(expr) + + def _add_assertion(self, expr) -> None: + if self._translation_validation_enabled: + self.validator.add_assertion(expr) + + def _check_translation_validate(self) -> None: + if self._translation_validation_enabled: + self.validator.validate() + + @record_shapeenv_event() + def _create_fx_call_function( + self, + op: Callable, + args: Tuple, + ) -> Tuple[Optional[torch.fx.Node], bool]: + # Cache this tuple in order to avoid duplicated nodes. + node_key = (op, args) + # Flags whether the returned node was cached or not. + fresh = False + + if self._translation_validation_enabled and node_key not in self.fx_node_cache: + + # Presence of None in the arguments implies that we should ignore this operation. + if any(a is None for a in args): + # We check if we are not mixing SymNode that should not be ignored + # (fx_node is not None) with those that should (fx_node is None). + assert all(not isinstance(a, torch.fx.Node) for a in args) + return None, fresh + + fresh = True + + # If translation validation is enabled, all arguments must have its + # own FX node. + assert all(a is not None for a in args), f"missing arg in FX graph ({op.__name__}): {args}" + node = self.fx_node_cache[node_key] = self.graph.call_function(op, args) + self.name_to_node[node.name] = node + + return self.fx_node_cache.get(node_key, None), fresh + + def _create_fx_placeholder_and_z3var( + self, + symbol: sympy.Symbol, + type: Type, + ) -> Optional[torch.fx.Node]: + if not self._translation_validation_enabled: + return None + + node_key = (self.graph.placeholder, (symbol,)) + + # Check if we haven't added this symbol already. + # If so, skip the placeholder creation, as it + # generates invalid Python code. + if node_key not in self.fx_node_cache: + # Add a Z3 variable according to 'type'. + self._add_z3var(symbol, type) + # Create the FX placeholder out of a mangled name. + mangled_name = re.sub(r'[^a-zA-Z0-9]', '_', re.sub(r'[()]', '', symbol.name)) + node = self.fx_node_cache[node_key] = self.graph.placeholder(mangled_name) + self.name_to_node[node.name] = node + # Attach the 'symbol' to the placeholder so that we can retrieve + # the Z3 variable later. + node.meta["symbol"] = symbol + + return self.fx_node_cache[node_key] + + def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None: + if self._translation_validation_enabled and node is not None: + self.name_to_node.pop(node.name) + self.graph.erase_node(node) + + def _add_fx_node_metadata(self, node: torch.fx.Node) -> None: + from torch._dynamo.utils import get_current_node + + if self.should_record_events: + node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index() + node.meta[CURRENT_NODE_KEY] = get_current_node() + + def _suppress_guards_tls(self): + return getattr(TLS, "suppress_guards", False) + + @record_shapeenv_event() + def _suppress_guards_enter(self): + TLS.suppress_guards = True + + @record_shapeenv_event() + def _suppress_guards_exit(self): + TLS.suppress_guards = False + + @contextmanager + def suppress_guards(self): + """Context manager to ignore all guards generated inside""" + self._suppress_guards_enter() + try: + yield + finally: + self._suppress_guards_exit() + + def _get_key(self): + """ + Defines the current "state" of the guards we've accumulated in this ShapeEnv. + Determines when we need to invalidate our cache + """ + return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts, len(self.unbacked_var_to_val)) + + def _update_version_counter(self): + # The shape environment is queried orders of magnitude more often than + # it is changed, so we summarise the cache key into a linearly + # increasing version counter which is cheaper to check in _lru_cache + + # Only update version counter if the state actually changed + cur_key = self._get_key() + if self._prev_cache_key != cur_key: + self._prev_cache_key = cur_key + self._version_counter += 1 + + def _produce_dyn_sizes(self, + ex_size: Sequence[int], + source: Source, + symbolic_context: SymbolicContext + ) -> List[sympy.Expr]: + return self._produce_dyn_sizes_from_int_tuple(tuple(ex_size), source, symbolic_context) + + def _produce_dyn_sizes_from_int_tuple(self, + tensor_size: Tuple[int], + source: Source, + symbolic_context: SymbolicContext, + ) -> List[sympy.Expr]: + assert all(not is_symbolic(val) for val in tensor_size), f"Expect size to be a plain tuple of ints but got {tensor_size}" + from torch._dynamo.source import TensorPropertySource, TensorProperty + _assert_symbol_context(symbolic_context) + dynamic_dims = symbolic_context.dynamic_sizes + constraint_dims = symbolic_context.constraint_sizes + size = [] + for i, val in enumerate(tensor_size): + size.append(self.create_symbol( + val, + TensorPropertySource(source, TensorProperty.SIZE, i), + dynamic_dims[i], + constraint_dims[i], + symbolic_context=symbolic_context + )) + return size + + def create_symbolic_sizes_strides_storage_offset( + self, + ex: torch.Tensor, + source: Source, + *, + symbolic_context: Optional[SymbolicContext] = None, + ): + """ + Returns a list of symbolic sizes and strides for the given tensor. + We try our best to express stride in terms of the sizes, so as to not + introduce new symbolic variables. + """ + + ex_size = tuple(self._maybe_specialize_sym_int_with_hint(sz) for sz in ex.size()) + ex_stride = tuple(self._maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride()) + ex_storage_offset = self._maybe_specialize_sym_int_with_hint(ex.storage_offset()) + + return self._create_symbolic_sizes_strides_storage_offset( + ex_size, + ex_stride, + ex_storage_offset, + [_is_dim_dynamic(ex, i) for i in range(ex.dim())], + source, + symbolic_context=symbolic_context, + ) + + # Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic"). + # We create symbols in shape_env using the backed hints behind SymInt. + + # Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape. + # produce_guards will trigger specializations on the outer stuff + + # Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint(). + # + # It's probably good for now but it's important to note that this approach has implications for + # the original shape_env when checking guards in different order. + + # Example: + # --------- + # Consider a function "opt_f" as shown below: + + # @torch.compile() + # def opt_f(x: bool, y: Tensor): + # if x == True: + # return y + torch.randn([4]) + # else: + # return y + # Depending on the sequence of calls, we might install two different sets of guards: + + # 1. opt_f(False, y): + # - "x == False" (always works for any size y) + + # 2. opt_f(True, y): + # - Triggers recompilation and results in guards like: + # - "x == True and y.size(0) == 4" + # - (or "y.size(0) == 4 and x == True") + + # The order of checking the guards matters. In this specific example: + # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True, + # we may have an unnessary shape speciliazation for y. + def _maybe_specialize_sym_int_with_hint(self, maybe_sym) -> int: + assert isinstance(maybe_sym, (int, torch.SymInt)) + if is_symbolic(maybe_sym): + assert maybe_sym.node.shape_env is not self, \ + "expect the symbol is created from an shape env other than current one." + return maybe_sym.node.require_hint() + return maybe_sym + + @record_shapeenv_event() + def _create_symbolic_sizes_strides_storage_offset( + self, + ex_size: Sequence[int], + ex_stride: Sequence[int], + ex_storage_offset: int, + is_dim_dynamic: Sequence[bool], + source: Source, + *, + symbolic_context: Optional[SymbolicContext] = None, + ): + dim = len(ex_size) + + # Reimplement the legacy behavior + if symbolic_context is None: + constraint_sizes = [None] * dim + constraint_strides = [None] * dim + dynamic_dims = [] + dynamic_strides = [] + for i in range(dim): + # NB: This is encapsulation breaking! Legacy behavior was + # bad. + if is_dim_dynamic[i]: + r = DimDynamic.DYNAMIC + elif self.assume_static_by_default: + r = DimDynamic.STATIC + else: + r = DimDynamic.DUCK + dynamic_dims.append(r) + dynamic_strides.append(r) + dynamic_dims = [DimDynamic.DUCK] * dim + dynamic_strides = [DimDynamic.INFER_STRIDE] * dim + # symbolic_context is None - set one + symbolic_context = StatelessSymbolicContext( + dynamic_sizes=dynamic_dims, + dynamic_strides=dynamic_strides, + constraint_sizes=constraint_sizes, + constraint_strides=constraint_strides, + ) + # We got a StatelessSymbolicContext + _assert_symbol_context(symbolic_context) + constraint_sizes = symbolic_context.constraint_sizes + constraint_strides = symbolic_context.constraint_strides + dynamic_sizes = symbolic_context.dynamic_sizes + dynamic_strides = symbolic_context.dynamic_strides + + # TODO: make this configurable from outside symbolic_context; we made a symbolic_context + # decision here where if all sizes are static, we are going to + # specialize all of the inner strides/offset too. We don't have to + # do this, and arguably we should ALWAYS allow for dynamic offset, + # this is cheap. + # TODO: This should be DYNAMIC, using DUCK for BC + dynamic_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_sizes) else DimDynamic.DUCK + are_sizes_static = all(r == DimDynamic.STATIC for r in dynamic_sizes) + + assert len(dynamic_sizes) == dim, f"{len(dynamic_sizes)} != {dim}" + assert len(dynamic_strides) == dim, f"{len(dynamic_sizes)} != {dim}" + assert len(constraint_sizes) == dim + assert len(constraint_strides) == dim + + from torch._dynamo.source import TensorPropertySource, TensorProperty + size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context) + stride: List[Optional[sympy.Expr]] = [None] * len(size) + for i, val in enumerate(ex_stride): + if val in (0, 1): + stride[i] = sympy.Integer(val) + while any(x is None for x in stride): + candidates = { + ex_size[i] * ex_stride[i]: size[i] * stride[i] + for i in range(len(size)) + if stride[i] is not None and ex_stride[i] >= 0 + } + + # iterate over unbound strides in sorted order + def _nested_int_aware_sort(tup): + return ( + # Order nested ints by their coefficients. + # 1 here to order nested ints after non-nested-ints. + (1, tup[0].node.nested_int_coeff(), tup[1]) if is_nested_int(tup[0]) + else (0, *tup) + ) + val_list = sorted( + [(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None], + key=_nested_int_aware_sort, + ) + for _, i in val_list: + # Set stride to a candidate only for DimDynamic.INFER_STRIDE + if stride[i] is None and dynamic_strides[i] == DimDynamic.INFER_STRIDE and ex_stride[i] in candidates: + stride[i] = candidates[ex_stride[i]] + candidates[ex_size[i] * ex_stride[i]] = size[i] * stride[i] + + if any(x is None for x in stride): + # bind the smallest unbound stride to a new variable + val, i = min( + [ + (ex_stride[i], i) + for i in range(len(stride)) + if stride[i] is None + ], key=_nested_int_aware_sort + ) + # Set INFER_STRIDE to STATIC or DUCK depending on sizes + dyn_stride = dynamic_strides[i] + if dynamic_strides[i] == DimDynamic.INFER_STRIDE: + dyn_stride = DimDynamic.STATIC if are_sizes_static else DimDynamic.DUCK + stride[i] = self.create_symbol( + val, + TensorPropertySource(source, TensorProperty.STRIDE, i), + dynamic_dim=dyn_stride, + constraint_dim=constraint_strides[i], + symbolic_context=symbolic_context, + ) + assert all(x is not None for x in stride) + + sym_sizes = [ + self.create_symintnode( + sym, + hint=hint, + source=TensorPropertySource(source, TensorProperty.SIZE, i), + ) + for i, (sym, hint) in enumerate(zip(size, ex_size)) + ] + sym_stride = [] + for i, stride_expr in enumerate(stride): + # NB: Don't duck size the stride; instead use the expression + # we computed + assert stride_expr is not None + sym_stride.append(self.create_symintnode( + stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i))) + sym_storage_offset = self.create_symintnode( + self.create_symbol( + ex_storage_offset, + TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), + dynamic_dim=dynamic_offset, + constraint_dim=None, + symbolic_context=symbolic_context + ), + hint=ex_storage_offset, + source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET)) + return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset + + @record_shapeenv_event() + def create_symintnode( + self, + sym: "sympy.Expr", + *, + hint: Optional[int], + source: Optional[Source] = None, + ): + """Create a SymInt value from a symbolic expression + + If you know what the current hint value of the SymInt to be created + is, pass it into hint. Otherwise, pass None and we will make our best + guess + + """ + source_name = source.name() if source else None + + if self._translation_validation_enabled and source is not None: + # Create a new symbol for this source. + symbol = self._create_symbol_for_source(source) + assert symbol is not None + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, int) + + # Add an equality assertion for the newly created symbol and 'sym'. + self._add_assertion(sympy.Eq(symbol, sym)) + else: + fx_node = None + + if isinstance(sym, sympy.Integer): + if hint is not None: + assert int(sym) == hint + out = int(sym) + else: + # How can this occur? When we mark_unbacked, we end up with a real + # tensor that has hints for all sizes, but we MUST NOT create a + # SymNode with a hint, because we're hiding the hint from our eyes + # with the unbacked Symbol. And in fact, the hint compute may be + # inconsistent with size oblivious tests. + if free_unbacked_symbols(sym): + hint = None + out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node)) + return out + + @record_shapeenv_event() + def create_symfloatnode( + self, + sym: "sympy.Expr", + *, + hint: Optional[int], + source: Optional[Source] = None, + ): + """Create a SymFloat value from a symbolic expression""" + source_name = source.name() if source else None + + if self._translation_validation_enabled and source is not None: + # Create a new symbol for this source. + symbol = self._create_symbol_for_source(source) + assert symbol is not None + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, float) + + # Add an equality assertion for the newly created symbol and 'sym'. + self._add_assertion(sympy.Eq(symbol, sym)) + else: + fx_node = None + + if isinstance(sym, sympy.Float): + if hint is not None: + assert float(sym) == hint + out = float(sym) + else: + # You could give this the same treatment as SymInt above if + # you supported mark_unbacked on a float, but it's a kind of + # strange thing to do though because floats don't get 0/1 + # specialization anyway + if free_unbacked_symbols(sym): + assert hint is None, sym + out = SymFloat(SymNode(sym, self, float, hint, fx_node=fx_node)) + return out + + @record_shapeenv_event() + def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim): + """Create a SymInt wrapping a new unspecified symbol""" + return self.create_symintnode( + self.create_unspecified_symbol( + value, + source=source, + dynamic_dim=dynamic_dim, + ), + hint=value, + source=source, + ) + + def create_symboolnode(self, sym: "sympy.Expr"): + """Create a SymBool object from a sympy boolean expression""" + # This function is only being used in serialization, so we do not track it + # for validation. + return SymBool(SymNode(sym, self, bool, None)) + + def _log_create_unbacked_symbol(self, prefix: str, symbol, vr: ValueRanges): + is_debug = config.extended_debug_create_symbol is not None and str(symbol) in config.extended_debug_create_symbol.split(',') + fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug) + log.info( + "%s %s [%s, %s]%s (%s)%s", + prefix, symbol, vr.lower, vr.upper, maybe_user_loc, format_frame(fsummary), maybe_extra_debug, stack_info=is_debug + ) + + @record_shapeenv_event() + def create_unbacked_symfloat(self): + """Create a symbolic float without a hint value + """ + symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_FLOAT, next(self.unbacked_symfloat_counter)) + self.counter["create_unbacked_symbol"] += 1 + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = ValueRanges.unknown() + assert vr.is_float + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, float) + + self._log_create_unbacked_symbol("create_unbacked_symfloat", symbol, vr) + + return SymFloat(SymNode(symbol, self, float, None, fx_node=fx_node)) + + @record_shapeenv_event() + def create_unbacked_symint(self): + """Create a symbolic integer without a hint value + """ + symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True) + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.counter["create_unbacked_symbol"] += 1 + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = self._default_unspecified_value_range() + assert vr.is_int + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, int) + + self._log_create_unbacked_symbol("create_unbacked_symint", symbol, vr) + + return SymInt(SymNode(symbol, self, int, None, fx_node=fx_node)) + + def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool: + """Check if a sympy symbol matches the naming convention for unbacked symbols + """ + return symbol_is_type(symbol, SymT.UNBACKED_INT) + + @record_shapeenv_event() + def create_unbacked_symbool(self): + """Create a symbolic boolean without a hint value + """ + symbol: sympy.Symbol = make_symbol(SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True) + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.counter["create_unbacked_symbol"] += 1 + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = ValueRanges(0, 1) + assert vr.is_int + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) + + self._log_create_unbacked_symbol("create_unbacked_symbool", symbol, vr) + + return SymBool(SymNode(sympy.Eq(symbol, 1), self, bool, None, fx_node=fx_node)) + + @record_shapeenv_event() + def create_unspecified_symbol( + self, + val: Union[int, SymInt, float, SymFloat], + source: Source, + dynamic_dim: DimDynamic = DimDynamic.DUCK, + constraint_dim: DimConstraint = None, # NB: includes None + ) -> "sympy.Expr": + """Create a symbol with an unspecified value + + Compared to standard symbols we do not assume the value is positive, + nor do we specialze on zero or one values. + """ + # 'positive' is None for unspecified symbols, since we can't + # assume that it will be neither positive nor negative. + + # We don't want to specialize zero one val for unspecified symbol + # so that we can always get a new symbol despite val. + return self.create_symbol( + val, + source, + dynamic_dim, + constraint_dim, + positive=None, + do_not_specialize_zero_one=True, + symbolic_context=None) + + @record_shapeenv_event() + def create_symbol( + self, + val: int, + source: Source, + dynamic_dim: DimDynamic = DimDynamic.DUCK, + constraint_dim: DimConstraint = None, # NB: includes None + positive: Optional[bool] = True, + do_not_specialize_zero_one: bool = False, + symbolic_context=None, + ) -> "sympy.Expr": + """Create a new symbol which is tracked by this ShapeEnv + """ + # check if constraint_dim is actually static integer + if isinstance(constraint_dim, StrictMinMaxConstraint) and constraint_dim.vr.lower == constraint_dim.vr.upper: + dynamic_dim = DimDynamic.STATIC + if constraint_dim.vr.lower != val: + raise ConstraintViolationError( + f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, " + f"for {source.name()}" + ) + if symbolic_context: + symbolic_context.dynamic_sizes[source.idx] = dynamic_dim + symbolic_context.constraint_sizes[source.idx] = None + constraint_dim = None + + # see note [Tensor Fakification and Symbol Caching] + source_name = source.name() + if (isinstance(symbolic_context, StatefulSymbolicContext) + and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache): + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {} + + if (isinstance(symbolic_context, StatefulSymbolicContext) + and source_name + and (source_name in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)])): + return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] + + if dynamic_dim is DimDynamic.SIZE_LIKE_UNBACKED: + out = self.create_unbacked_symint().node.expr + self._constrain_range_for_size(out) + # TODO: maybe put the hint somewhere + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = out + return out + + if do_not_specialize_zero_one: + specialize_zero_one = False + else: + specialize_zero_one = self.specialize_zero_one + + assert isinstance(source, Source), f"{type(source)} {source}" + assert not (positive and val < 0), f"positive set for negative value: {val}" + # It's always sound to allocate a symbol as DYNAMIC. If the user + # constrained the symbol, force the symbolic_context to DYNAMIC, because our + # constraint code will do weird stuff if, e.g., it's duck shaped + if constraint_dim is not None: + dynamic_dim = DimDynamic.DYNAMIC + + if dynamic_dim is DimDynamic.STATIC: + out = sympy.Integer(val) + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = out + return out + + elif dynamic_dim is DimDynamic.DUCK: + # duck_shape can be used to globally turn off duck shaping, even + # if it was requested + duck = self.duck_shape + elif dynamic_dim is DimDynamic.DYNAMIC: + duck = False + else: + raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}") + + if val in (0, 1) and specialize_zero_one: + r = self.val_to_var[val] + elif not duck or val not in self.val_to_var: + # If we're not duck shaping, we always create a new symbol + # Even if we're duck shaping, if we haven't seen this particular + # value before, we also create a new symbol + if type(val) is int or is_nested_int(val): + sympy_expr = make_symbol(SymT.SIZE, len(self.var_to_val), positive=positive, integer=True) + else: + sympy_expr = make_symbol(SymT.FLOAT, len(self.var_to_val), positive=positive, real=True) + # We always associate vars to vals + if isinstance(val, int): + self.var_to_val[sympy_expr] = sympy.Integer(val) + elif isinstance(val, float): + self.var_to_val[sympy_expr] = sympy.Float(val) + else: + # Only used for jagged layout nested tensors + self.var_to_val[sympy_expr] = SingletonInt(val.node.nested_int(), coeff=val.node.nested_int_coeff()) + + # Do the appending later, because we always want to populate this + self.var_to_sources[sympy_expr] = [] + # Create a Z3 variable for the new symbol. + self._add_z3var(sympy_expr, int) + + if duck: + # Make sure to reuse this symbol for subsequent duck shaping + self.val_to_var[val] = sympy_expr + + if isinstance(val, int): + if positive: + # Add assertions for the newly created symbols + self._add_assertion(sympy_expr > 1) + + # Apply default range, which assumes not zero-one + self.var_to_range[sympy_expr] = self._default_value_range() + else: + self.var_to_range[sympy_expr] = self._default_unspecified_value_range() + + # Small performance optimization: if we have a min-max constraint, + # we can proactively narrow to that range + if isinstance(constraint_dim, StrictMinMaxConstraint): + assert not duck + self.var_to_range[sympy_expr] &= constraint_dim.vr + + vr = self.var_to_range[sympy_expr] + assert vr.is_int + + if val not in vr: + raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]") + + range_str = f"[{vr.lower}, {vr.upper}]" + elif isinstance(val, float): + self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) + range_str = f"[{vr.lower}, {vr.upper}]" + assert vr.is_float + else: + # Skip var_range logic for SingletonInt + # Only used for jagged layout nested tensors + range_str = "" + + r = sympy_expr + + is_debug = ( + config.extended_debug_create_symbol is not None and + str(sympy_expr) in config.extended_debug_create_symbol.split(',') + ) + maybe_more_info = "" + if not is_debug: + maybe_more_info = ( + ", for more info run with " + f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{sympy_expr}"' + ) + fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug) + self.log.info( + "create_symbol %s = %s for %s %s%s (%s)%s%s", + sympy_expr, val, source.name(), range_str, + maybe_user_loc, format_frame(fsummary), maybe_more_info, maybe_extra_debug, stack_info=is_debug + ) + + self.counter["create_symbol"] += 1 + else: + # This implements duck-shaping: input sizes that match are assigned + # the same symint + r = self.val_to_var[val] + self.log.debug("create_symbol %s duck sized %s", r, source.name()) + + if isinstance(r, sympy.Symbol): + r_sources = self.var_to_sources[r] + r_sources.append(source) + if not source.is_ephemeral() and r_sources[0].is_ephemeral(): + # prefer non-ephemeral source first since it may be guarded on later + r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0] + + # This ensures we get zeros in symbol_guard_counts, which makes + # some queries simpler (since we will accumulate mass on 0 this + # way) + self.symbol_guard_counter[r] = 0 + + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = r + return r + + def add_var_to_val(self, expr: sympy.Symbol, val: int): + """ Adds a new symbol to the symbolic environment. """ + log.debug("add_var_to_val %s %s", expr, val, stack_info=True) + assert expr not in self.var_to_val, f"{expr} already exists" + self.var_to_val[expr] = sympy.Integer(val) + + def _debug_name(self, source): + src_name = source.name() + return self.source_name_to_debug_name.get(src_name, src_name) + + def _render_range_for_constraint_violation(self, source, c): + if isinstance(c, StrictMinMaxConstraint): + lower, upper = c.vr.lower, c.vr.upper + default = self._default_value_range() + if lower <= default.lower: + lower = None + if upper >= default.upper: + upper = None + c_render = f"{self._debug_name(source)} = {source.name()} in the specified range" + if lower is not None and upper is not None: + c_render += f" {lower} <= {self._debug_name(source)} <= {upper}" + elif lower is None and upper is not None: + c_render += f" {self._debug_name(source)} <= {upper}" + elif lower is not None and upper is None: + c_render += f" {lower} <= {self._debug_name(source)}" + return c_render + return c.render(source) + + def produce_guards( + self, + placeholders, + sources, + source_ref=lambda n: n.name(), + *, + guards: List[ShapeGuard] = None, + input_contexts: Optional[DimList[SymbolicContext]] = None, + # Encodes user-specified input shape equations of the form s = s' and s = fn(s'). + # (See docs on EqualityConstraint for details of the encoding.) + equalities_inputs: Optional[EqualityConstraint] = None, + _simplified=False, + # Indicates if we should produce guards for known static values. + ignore_static=True, + ) -> List[str]: + """ + Generates a list of guards strings which, when evaluated in a context that + defines tensors for all the sources, returns True or False depending + on if the guards in the list evaluated to True or not. Primarily used by Dynamo, + but this is also helpful for manual testing of guards (see + evaluate_guards_for_args) + + For convenience in testing, a source is allowed to be a str, + in which case we will assume it is a LocalSource + + simplified lets you omit duck sizing, equality and 0/1 guards. + This is useful for testing when you don't care about the boilerplate + guards, and it may be helpful for user output too (be careful though; + some equality guards are nontrivial! It would be nice to get simplified + output to print them too). It's private because it's not + intended for normal use + """ + self.log.info("produce_guards") + + # Check if we get to the same ShapeEnv state by replaying the recorded events. + # This will create a new ShapeEnv instance, and call all recorded function + # calls on this new instance. Finally, it will check whether this new instance + # has equal state. + # + # It's important that we do it in the begining of this function, since it modifies + # self.dim_constraints through its execution. Changes that happen in this method + # aren't interesting, since this is the function call we wish to reproduce at the + # end. If we wish to simply reproduce ShapeEnv instances even after this call, + # this method should also be recorded. + if self.check_recorded_events: + shape_env = replay_shape_env_events(self.events) + self.check_equal(shape_env) + + assert len(placeholders) == len(sources), f"len({placeholders}) != len({sources})" + Tensorlike = (torch.Tensor, FakeTensorMeta) + + def _create_no_constraints_context(t): + return StatelessSymbolicContext( + # Ignored; only the constraints part is relevant below. + dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(), + dynamic_strides=[DimDynamic.INFER_STRIDE] * t.dim(), + constraint_sizes=[None] * t.dim(), + constraint_strides=[None] * t.dim() + ) + + # Expand optional inputs, or verify invariants are upheld + if input_contexts is None: + input_contexts = [ + _create_no_constraints_context(t) if isinstance(t, Tensorlike) + else None for t in placeholders + ] + else: + assert len(input_contexts) == len(placeholders) + for i, (t, context) in enumerate(zip(placeholders, input_contexts)): + if isinstance(t, Tensorlike): + if context is None: + input_contexts[i] = _create_no_constraints_context(t) + else: + assert isinstance(t, (SymInt, int, SymFloat, float)) + assert not isinstance(context, list) + + # It took a lot of sweat to figure out the algorithm here. Let's + # explain how it works. + # + # The ShapeEnv lifecycle looks something like this: + # + # - For each input, you either generate a fresh Sympy symbol (s0) to + # represent its value (a binding site), or you reuse some + # preexisting symbol or expression, skipping the symbol allocation + # (e.g., duck sizing to a preexisting symbol, or expressing a + # stride as a multiplication of a separate stride and size.) + # Naively, you might expect to bind a fresh Sympy symbol for + # every input, but this is fairly wasteful as most of these + # symbols immediately simplify away, and if you don't eagerly + # specialize, e.g., 0/1 symbols, you end up with very complicated + # expressions that are not optimizable in practice. + # + # - You perform some compute on these symbols, occasionally + # introducing guards on boolean expressions on these symbols. + # In particular, whenever we guard on equality (_maybe_guard_rel), + # we can simplify shapes; e.g., when s0 == s1 * 2, we can now + # replace all occurrences of s0 with s1 * 2. Sometimes, a + # boolean expression evaluation doesn't introduce a guard, as + # the guard is already entailed by the simplifications we have + # applied. + # + # - In the end, you have a bunch of replacements (saying how to + # simplify shapes) and a bunch of guards (all the equality guards + # are trivial, because they're covered by the replacements). + # + # From the ShapeEnv, we must generate a Python expression that, when + # evaluated on a set of inputs, tells us whether or not these boolean + # expressions would have evaluated in the same way. However, + # we cannot easily compute this, as we elide recording boolean + # expressions when we think they are vacuously true. Thus, we seek + # an approximation: we must generate an expression, if true, would have + # produced an "equivalent" ShapeEnv, which would answer guard + # expressions in the same way. + # + # Our notion of equivalence is a bit subtle. For example, consider + # the ShapeEnv created from an input of size (5, 4) versus (4, 4) + # (no other guards.) Duck sizing would generate (s0, s1) in the first + # case but (s0, s0) in the second. We do NOT assume that size + # variables are disjoint; so in fact a graph that assumes the input + # could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not + # vice versa. However, consider an analogous case (1,) versus (2,). + # Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT + # subsume the (1,) graph because we assume that any size variables + # is NOT 0/1 (and make simplifications according to this; e.g., if + # we queried s0 == 0, we would immediately return False without + # returning a guard.) + # + # So, it is perhaps easier to flip things on their head: the guard + # expressions we generate here say what simplifications are valid, + # and what are not. Below, we explain each of the guard expressions + # we generate + + # TODO: Make this more efficient by binding all the size/stride/offsets + # to locals before performing tests on them. + + from torch._dynamo.source import TensorPropertySource, TensorProperty + + # Actual codegen must be delayed as we don't necessarily know what + # the symbol mapping is + input_guards = [] + + symbol_to_source = collections.defaultdict(list) + symbol_to_constraints = collections.defaultdict(set) + constraint_violations : List[Tuple[bool, str, Callable[[], str]]] = [] + + def record_constraint_violation(warn_only, debug_name, msg, hint=None): + constraint_violations.append( + (warn_only, debug_name, lambda: f"{msg}{hint()}" if hint else msg) + ) + + def is_dim(src): + return isinstance(src, TensorPropertySource) and src.prop is TensorProperty.SIZE + + if equalities_inputs: + source_index = {} + for i, src in enumerate(sources): + source_index[src.name()] = i + + def get_expression(tensor_dim_src): + fake = placeholders[source_index[tensor_dim_src.base.name()]] + symint = fake.shape[tensor_dim_src.idx] + if isinstance(symint, torch.SymInt): + return symint.node.expr + else: + assert type(symint) is int, f"Expected int, got {type(symint)}" + return symint + + for src1, src2 in equalities_inputs.source_pairs: + expr1, expr2 = get_expression(src1), get_expression(src2) + # Check whether given input shape values satisfy a specified equation s = s'. + # - Raise when the equation was violated by the given input shape values. + # - Otherwise issue a guard to constrain them. + concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2)) + if not concrete_val: + raise ConstraintViolationError( + f"{src1.name()} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}" + " is not equal to " + f"{src2.name()} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}" + ) + + for src, root, fn in equalities_inputs.derived_equalities: + expr1 = get_expression(src) + # recall that root is either a phantom symbol or an input source + expr2, debug_name = ( + (root, self.var_to_sources[root][0].name()) if isinstance(root, sympy.Symbol) + else (get_expression(root), self._debug_name(root)) + ) + expr2_ = fn(expr2) + # Check whether given input shape values satisfy a specified equation s = fn(s'). + # - Raise when the equation was violated by the given input shape values. + # - Otherwise issue a guard to constrain them. + concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_)) + if not concrete_val: + raise ConstraintViolationError( + f"Expected input {src.name()} to be equal to " + f"{fn(sympy.Symbol(debug_name))}, " + f"where {debug_name} = {expr2.xreplace(self.var_to_val)}, " + f"but got {expr1.xreplace(self.var_to_val)}" + ) + + for phantom_symbol in equalities_inputs.phantom_symbols: + # we created additional phantom symbols that are not input shape dimensions + symbol_to_source[phantom_symbol].extend(self.var_to_sources[phantom_symbol]) + + # How do we know what the value of s0 is? Fresh variables can only be + # bound by inputs, so there MUST be some other input which binds the + # variable. If there is no such input, this is an error in our + # system. We record where all symbols come from, to help you diagnose + # why those symbols didn't occur. + # + # In fact, generally speaking it is only possible for the "outermost" + # user of a ShapeEnv to evaluate the guards, because some inputs may + # not be available to inner levels. For example, Dynamo can guard on + # tensors that never actually become graph arguments (they are + # pruned). In this case, only Dynamo knows about these arguments. + def track_symint(source, val, constraint=None): + log.debug("track_symint %s %s %s", LazyString(source.name), val, constraint) + assert not isinstance(val, SymInt) or is_symbolic(val) + + if isinstance(val, SymInt) and val.node.maybe_as_int() is not None: + val = val.node.maybe_as_int() + + if isinstance(val, SymInt): + s = val.node.expr + if isinstance(s, sympy.Symbol): + symbol_to_source[s].append(source) + if ( + constraint is not None + and not isinstance(constraint, RelaxedUnspecConstraint) + ): + symbol_to_constraints[s].add(constraint) + else: + constraint_violated = False + if isinstance(constraint, StrictMinMaxConstraint): + # try inferring the ranges of the expr s + sym_vrs = {x: self.var_to_range.get(x, None) for x in s.free_symbols} + if any(vr is None for vr in sym_vrs.values()): + # some of the free symbols in s don't have ranges + constraint_violated = True + elif isinstance(constraint, RelaxedUnspecConstraint): + if s.is_number: + i = int(s) + # Don't complain about 0/1 specialization, we + # expect to have to compile in this case anyway + if i not in (0, 1): + constraint_violated = True + if constraint_violated: + def hint(s): + sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(s) + return f"{sexpr}." + + var_with_range = self._render_range_for_constraint_violation(source, constraint) + msg = ( + f"Not all values of {var_with_range} are valid because " + f"{self._debug_name(source)} was inferred to be equal to " + ) + record_constraint_violation( + constraint.warn_only, + self._debug_name(source), + msg, + hint=functools.partial(hint, s), + ) + + input_guards.append((source, s)) + else: + s = sympy.Integer(val) + input_guards.append((source, s)) + constraint_violated = False + if isinstance(constraint, StrictMinMaxConstraint): + if not (s == constraint.vr.lower == constraint.vr.upper): # allow static constraints + constraint_violated = True + elif isinstance(constraint, RelaxedUnspecConstraint): + # Don't complain about 0/1 specialization, we + # expect to have to compile in this case anyway + if val not in (0, 1): + constraint_violated = True + if constraint_violated: + var_with_range = self._render_range_for_constraint_violation(source, constraint) + msg = ( + f"Not all values of {var_with_range} are valid because " + f"{self._debug_name(source)} was inferred to be a constant ({val})." + ) + record_constraint_violation(constraint.warn_only, self._debug_name(source), msg) + + def track_symfloat(source, val): + log.debug("track_symfloat %s %s", LazyString(source.name), val) + assert not isinstance(val, SymFloat) or is_symbolic(val) + + if isinstance(val, SymFloat) and val.node.maybe_as_float() is not None: + val = val.node.maybe_as_float() + + if isinstance(val, SymFloat): + s = val.node.expr + if isinstance(s, sympy.Symbol): + symbol_to_source[s].append(source) + input_guards.append((source, s)) + else: + s = sympy.Float(val) + input_guards.append((source, s)) + + for t, source, context in zip(placeholders, sources, input_contexts): + if isinstance(source, str): + from torch._dynamo.source import LocalSource + source = LocalSource(source) + assert isinstance(source, Source) + if t is None: + continue + if isinstance(t, (SymInt, int)): + track_symint(source, t) + continue + elif isinstance(t, (SymFloat, float)): + track_symfloat(source, t) + continue + assert isinstance(t, Tensorlike) + if is_traceable_wrapper_subclass(t): + from torch._dynamo.source import AttrSource + + assert isinstance(context, SubclassSymbolicContext) + + # For subclasses, we need to track symints on BOTH the outer + # and inner tensors. + sources_tensors_constraints = [ + (source, t, context.constraint_sizes, context.constraint_strides) + ] + attrs, _ = t.__tensor_flatten__() + for attr in attrs: + inner_t = getattr(t, attr) + inner_context = context.inner_contexts[attr] + sources_tensors_constraints.append(( + AttrSource(source, attr), + inner_t, + inner_context.constraint_sizes, + inner_context.constraint_strides + )) + else: + sources_tensors_constraints = [(source, t, context.constraint_sizes, context.constraint_strides)] + + for src, curr_t, constraint_size, constraint_stride in sources_tensors_constraints: + if is_sparse_any(curr_t): + for i, ss in enumerate(curr_t.size()): + property_source = TensorPropertySource(src, TensorProperty.SIZE, i) + track_symint(property_source, ss, constraint_size[i]) + else: + for i, ss in enumerate(curr_t.size()): + property_source = TensorPropertySource(src, TensorProperty.SIZE, i) + track_symint(property_source, ss, constraint_size[i]) + for i, ss in enumerate(curr_t.stride()): + property_source = TensorPropertySource(src, TensorProperty.STRIDE, i) + track_symint(property_source, ss, constraint_stride[i]) + track_symint(TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), curr_t.storage_offset()) + + # 1. Every input must equal the final simplified symbolic expression + # stored on the placeholder. Given a placeholder (s0*2, s1), + # if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3. + # This does a lot of work: it covers duck sizing and equality guards. + exprs = [] + self.dim_constraints = DimConstraints( + symbol_to_source, + self.var_to_val, + set(symbol_to_constraints.keys()), + self.source_name_to_debug_name, + ) + + if not _simplified: + for source, expr in input_guards: + if self._translation_validation_enabled: + # Ignore sources that were not turned into SymInts. + srcname = source.name() + if srcname in self.source_to_symbol: + self._add_target_expr(sympy.Eq(self.source_to_symbol[srcname], expr)) + + # Small optimization + if ( + isinstance(expr, sympy.Symbol) and + symbol_to_source.get(expr) and + source == symbol_to_source[expr][0] + ): + continue + + # This logic excludes static values found on tensors from guarding, because + # dynamo's check_tensor_fn does that (see guards.cpp). + # However, for non tensor sources, we still need to guard here. + if ignore_static and isinstance(source, TensorPropertySource): + if expr.is_number: + self.log.debug("Skipping guard %s", f"{source_ref(source)} == {expr}") + continue + + if is_dim(source): + self.dim_constraints.add_equality(source, expr) + + sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr) + exprs.append(f"{source_ref(source)} == {sexpr}") + if ( + isinstance(source, TensorPropertySource) + and source.prop is TensorProperty.SIZE + and equalities_inputs + and len(expr.free_symbols) == 1 + ): + symbol = next(iter(expr.free_symbols)) + if ( + isinstance(expr, sympy.Symbol) and + expr in symbol_to_constraints and + not equalities_inputs.is_equal(source, symbol_to_source[expr][0]) + ): + msg = ( + f"The values of {self._debug_name(source)} = {source.name()} and " + f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} " + "must always be equal." + ) + record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg) + + if ( + not isinstance(expr, sympy.Symbol) and + symbol in symbol_to_constraints and + not equalities_inputs.is_derived(source, symbol_to_source[symbol][0], lambda x: expr.xreplace({symbol: x})) + ): + src = symbol_to_source[symbol][0] + msg = ( + f"The values of {self._debug_name(source)} = {source.name()} must always be related to " + f"the values of {self._debug_name(src)} = {src.name()} by " + f"{self._debug_name(source)} = {expr.xreplace({symbol: sympy.sympify(self._debug_name(src))})}." + ) + record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg) + + # NB: Not necessary to report constraint violations here: + # constraints are guaranteed to be on symbols (we've already + # caught constants and non-atomic expressions), so we only + # have relational constraints, but we don't support those + # at the moment + + # 2. Every guard must evaluate to True (but remember many guards + # like s0 == s1*2 because trivial due to simplification) + issued = set() + + def issue_guard(guard: ShapeGuard) -> None: + expr = self.simplify(guard.expr) + + # Avoid re-issueing the same guard. + if expr in issued: + return + + issued.add(expr) + + try: + is_trivial = False + if any(is_dim(source) for s in expr.free_symbols for source in symbol_to_source[s]): + is_trivial = self.dim_constraints.add(expr) + guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr) + exprs.append(guard_expr) + self._add_target_expr(expr) + # A non-relational constraint on a single sizevar can violate + # a constraint + if not is_trivial and len(expr.free_symbols) == 1: + symbol = next(iter(expr.free_symbols)) + source = symbol_to_source[symbol][0] + constraints = symbol_to_constraints[symbol] + for c in constraints: + if isinstance(c, StrictMinMaxConstraint): + var_with_range = self._render_range_for_constraint_violation(source, c) + msg = ( + f"Not all values of {var_with_range} " + f"satisfy the generated guard {guard_expr}." + ) + record_constraint_violation(c.warn_only, self._debug_name(source), msg) + elif isinstance(c, RelaxedUnspecConstraint): + # This is fine, we allow guards here as long as it + # didn't constrain it to one value (we don't + # actually know this; this depends on our + # ValueRanges reasoning capability) + pass + else: + raise AssertionError(f"unrecognized constraint {c}") + except Exception: + self.log.warning("Failing guard allocated at: \n%s", ''.join(guard.stack.format())) + raise + + # First, issue all guards. + # This removes all the checks that follow from bounds + # We could simply emit those and also the bounds 2 <= size when necessary + for guard in (guards if guards is not None else self.guards): + if self._maybe_evaluate_static(guard.expr, axioms=()) is not None: + continue + issue_guard(guard) + + # Because there are guards that export's constraint solver can suggest good fixes for, that we may have + # deferred as runtime asserts, and that produce_guards() alone won't do anything with (e.g. divisiblity guards), + # we want to send runtime asserts to export's constraint solver too. These will still stay in the graph as asserts, + # but export's constraint solver can decide whether to do anything with them (i.e. raise an error and provide + # suggested fixes, or decide it's out of scope and leave as a runtime assert in the graph). + for ra in self.deferred_runtime_asserts.get(None, []): + if self._maybe_evaluate_static(ra.expr, axioms=()) is not None: + continue + expr = self.simplify(ra.expr) + self.dim_constraints.add(expr) + + # 3. Every symbol must be within its value range (this handles 0/1 + # specialization too). + for symbol, sources in symbol_to_source.items(): + r = self.var_to_range.get(symbol) + if r is None: + if symbol not in self.var_to_range: + continue + r = self.var_to_range[symbol] + + assert sources + bounds = [] + if r.lower not in (-sympy.oo, -int_oo): + if any(is_dim(source) for source in sources): + self.dim_constraints.add(sympy.Ge(symbol, r.lower)) + # Only print lower bound in simplified mode if it is not the + # default + if not _simplified or r.lower != self._default_value_range().lower: + bounds.append(str(r.lower)) + bounds.append(source_ref(sources[0])) + if r.upper not in (sympy.oo, int_oo): + if any(is_dim(source) for source in sources): + self.dim_constraints.add(sympy.Le(symbol, r.upper)) + # nontrivial upper bound is always interesting + bounds.append(str(r.upper)) + if len(bounds) > 1: + exprs.append(" <= ".join(bounds)) + + # Check constraints + constraints = symbol_to_constraints[symbol] + for c in constraints: + if isinstance(c, StrictMinMaxConstraint): + # TODO: With int_oo, I think this condition is a noop + # now + if not (c.vr & self._default_value_range()).issubset(r): + source = sources[0] + + expr = sympy.And(sympy.Le(r.lower, symbol), sympy.Le(symbol, r.upper)) + guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr) + var_with_range = self._render_range_for_constraint_violation(source, c) + msg = ( + f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}" + ) + record_constraint_violation( + c.warn_only, + self._debug_name(source), + msg, + ) + # We NaN specialize, which means similar to 0/1 specialization we + # should assume that the float is NOT nan. This is load bearing + # if you have something like an equality guard, nan will play + # merry hell with the reasoning. + if symbol_is_type(symbol, SymT.FLOAT): + exprs.append(f"not __math_isnan({source_ref(sources[0])})") + + if constraint_violations: + warn_msgs = [] + error_msgs = [] + debug_names = set() + for warn_only, debug_name, msg in constraint_violations: + if warn_only: + msg = f" {len(warn_msgs) + 1}. {msg()}" + warn_msgs.append(msg) + else: + msg = f" - {msg()}" + error_msgs.append(msg) + debug_names.add(debug_name) + if len(error_msgs) > 0: + debug_names = ', '.join(sorted(debug_names)) + err = '\n'.join(error_msgs) + raise ConstraintViolationError( + f"Constraints violated ({debug_names})! " + 'For more information, run with TORCH_LOGS="+dynamic".\n' + f"{err}" + ) + elif len(warn_msgs) > 0: + log.debug("%s Warning only constraints violated", len(warn_msgs)) + + signpost_event( + "dynamic", + "produce_guards", + { + **self.co_fields, + **self.counter, + "num_guards": len(exprs), + "free_symbols": sum(1 for v in symbol_to_source.values() if v), + # The keys are meaningless from an aggregate perspective, so + # don't include them. Biggest first. + "symbol_guard_counts": sorted(self.symbol_guard_counter.values(), reverse=True), + }, + ) + + if self._translation_validation_enabled: + from torch.fx.experimental.validator import PopulateValidator + + # Add all deferred runtime assertions; these are not technically + # handled by produce_guards but we need to put them in the target + # set + for ras in self.deferred_runtime_asserts.values(): + for ra in ras: + self._add_target_expr(ra.expr) + + # Add value range bound guards for all symbols with no trivial bounds. + # Reason: '_maybe_evaluate_static' may eliminate guards based on the + # refined value ranges. + for sym, vr in self.var_to_range.items(): + if vr.lower not in (-sympy.oo, -int_oo): + self._add_target_expr(sympy.Le(vr.lower, sym)) + if vr.upper not in (sympy.oo, int_oo): + self._add_target_expr(sympy.Le(sym, vr.upper)) + + # Before validating, populate the input of the validator with the + # built FX graph. + with fx_traceback.preserve_node_meta(): + PopulateValidator(self.graph, self.validator).run() + + # Only run translation validation when we are not passing custom guards + if guards is None: + self._check_translation_validate() + return exprs + + def produce_guards_expression( + self, + placeholders, + *, + guards: Optional[List[ShapeGuard]] = None, + ignore_static=True + ): + """ + Expected to be used with evaluate_guards_expression(). Produces the guards + for the given placeholders and returns a string expression to be evaluated + by evaluate_guards_expression given concrete values for the placeholders. + """ + from torch._dynamo.source import LocalSource + arg_names = [f"t{i}" for i in range(len(placeholders))] + produced_guards = self.produce_guards( + placeholders, + [LocalSource(a) for a in arg_names], + guards=guards, + ignore_static=ignore_static, + ) + if produced_guards: + return " and ".join(produced_guards) + return None + + def evaluate_symexpr(self, code): + """ + To be used by compile_fx to evaluate symexprs + """ + args = {str(e): val for e, val in self.var_to_val.items()} + return eval(code, SYMPY_INTERP, args) + + def evaluate_guards_expression(self, code, args): + """ + Expected to be used with produce_guards_expression(). Evaluates an expression + generated by produce_guards_expression for the given concrete args. + """ + arg_names = [f"t{i}" for i in range(len(args))] + return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))}) + + def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True): + """Generate guards for a graph's placeholder values and evaluate the guards with args + """ + code = self.produce_guards_expression(placeholders, ignore_static=ignore_static) + if code: + return self.evaluate_guards_expression(code, args) + return True + + def get_pruned_guards(self, symints): + """ + Get a list of guards, but pruned so it only provides guards that + reference symints from the passed in input + """ + symints = {s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)} + guards = [] + for g in self.guards: + if all(s in symints for s in g.expr.free_symbols): + guards.append(g) + return guards + + def bind_symbols(self, placeholders, args): + """ + Given a paired list of placeholders (fake tensors with + symbolic sizes) and concrete arguments (regular tensors + with real sizes), returns a dictionary mapping each + symbol to its real value. So for example, if you + have a placeholder with size (s0, s1), binding + (2, 4) to it will give you {s0: 2, s1: 4}. This is + not guaranteed to bind ALL symbols in the ShapeEnv; + we can't bind a symbol if it doesn't occur in any placeholder, + and symbols that already have replacements won't get bindings. + + This is a little duplicative with evaluate_guards but + it's different enough that it seemed cleanest to make + another copy. This assumes the guards are already checked, + though if it's cheap we'll check for shenanigans + """ + bindings: Dict[sympy.Symbol, int] = {} + + def bind_symint(arg, val): + if isinstance(val, SymInt): + s = val.node.expr + + if isinstance(s, sympy.Symbol): + if s in bindings: + assert bindings[s] == arg, f"{bindings[s]} != {arg}" + else: + bindings[s] = arg + elif isinstance(-s, sympy.Symbol): + if -s in bindings: + assert bindings[-s] == -arg, f"{bindings[-s]} != {-arg}" + else: + bindings[-s] = -arg + + for t, arg in zip(placeholders, args): + if t is None: + continue + if isinstance(t, SymInt): + bind_symint(arg, t) + continue + assert isinstance(t, torch.Tensor) + for i, s in enumerate(t.size()): + bind_symint(arg.size(i), s) + for i, s in enumerate(t.stride()): + bind_symint(arg.stride(i), s) + bind_symint(arg.storage_offset(), t.storage_offset()) + + return bindings + + def get_nontrivial_guards(self): + """Returns a list of guard expressions that aren't statically known (i.e. not trivial)""" + return [self.simplify(guard.expr) for guard in self.guards if self._maybe_evaluate_static(guard.expr, axioms=()) is None] + + def format_guards(self, verbose=False): + """Format this shape env's guard expressions with optional traceback info if verbose""" + def format_tb(tb): + if not verbose: + return "" + return f"\n Guarded at:\n{''.join(' ' + l for l in tb.format())}" + + return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards) + + def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRanges: + """Given a sympy expression, computes a ValueRanges bound for what values it can be""" + var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols} + if size_oblivious: + # Clamp values of size-like variables + # NB: discarding the old upper bound in intentional, per + # https://github.com/pytorch/pytorch/pull/123675 + for x in self.size_like & var_to_range.keys(): + if var_to_range[x] is not None: + # NB: do NOT set upper to 2 ** 48, we're using this solely + # to determine if we can do size-like replacement, the + # upper bound is irrelevant here + var_to_range[x] = ValueRanges(2, int_oo) + assert var_to_range[x].is_int + return bound_sympy(expr, var_to_range) + + @_lru_cache + def get_axioms(self, symbols: Optional[Tuple["sympy.Symbol"]] = None, compute_hint: bool = False) -> Tuple["sympy.Expr"]: + """ + Given the symbols in an expression, it returns all the runtime asserts that have those symbols + concatenated with all the guards. + If symbols is None, it returns all the runtime asserts (and all the guards) + """ + if symbols is None: + runtime_asserts = (r.expr + for rs in self.deferred_runtime_asserts.values() + for r in rs) + else: + runtime_asserts = (r.expr + for s in symbols if s not in self.var_to_val + for r in self.deferred_runtime_asserts.get(s, ())) + guards = (g.expr for g in self.guards) + axioms = itertools.chain(guards, runtime_asserts) + if compute_hint: + axioms = (canonicalize_bool_expr(a.xreplace(self.var_to_val)) for a in axioms) + return tuple(dict.fromkeys(axioms).keys()) + + @lru_cache(None) + def get_implications(self, + e: "sympy.Expr") -> Tuple[Tuple["sympy.Expr", 'sympy.logic.boolalg.BooleanAtom']]: + """ Given a expression, it returns a list of predicates that follow from it """ + equiv = {} + + def add_expr(expr): + expr = canonicalize_bool_expr(expr) + if isinstance(expr, (sympy.Eq, sympy.Ne)): + # No need to canonicalize + # TODO We could further canonicalize Eq ordering the lhs and rhs somehow + # With this, we could remove the need for the commutativity part + opposite = sympy.Eq if isinstance(expr, sympy.Ne) else sympy.Ne + # Commutativity of == and != + equiv[type(expr)(expr.lhs, expr.rhs)] = sympy.true + equiv[type(expr)(expr.rhs, expr.lhs)] = sympy.true + equiv[opposite(expr.lhs, expr.rhs)] = sympy.false + equiv[opposite(expr.rhs, expr.lhs)] = sympy.false + else: + # Expr and negation + equiv[expr] = sympy.true + equiv[canonicalize_bool_expr(sympy.Not(expr))] = sympy.false + + add_expr(e) + # Other relational expressions this expression implies + if isinstance(e, sympy.Eq): + add_expr(sympy.Le(e.lhs, e.rhs)) + add_expr(sympy.Ge(e.lhs, e.rhs)) + elif isinstance(e, sympy.Lt): + add_expr(sympy.Le(e.lhs, e.rhs)) + add_expr(sympy.Ne(e.lhs, e.rhs)) + if e.lhs.is_integer and e.rhs.is_integer: + add_expr(sympy.Le(e.lhs, e.rhs - 1)) + elif isinstance(e, sympy.Le): + add_expr(sympy.Lt(e.lhs, e.rhs + 1)) + return tuple(equiv.items()) + + @_lru_cache + def _maybe_evaluate_static( + self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False, + size_oblivious: bool = False, axioms: Optional[Tuple[sympy.Expr]] = None, + var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges]]] = None + ) -> "Optional[sympy.Expr]": + """ + Tries to evaluate expr without introducing guards + + If unbacked_only == True, then we only do substitutions on + unbacked SymInts (leaving regular hinted integers alone). This could + result in an expression that still contains backed SymInts, which you + could then potentially guard on. + + Use compute_hint == True if you are trying to compute a non-binding + hint for the particular hint values of backed SymInts, e.g., if + s0 happens to be 3 this run, compute_hint will subsitute s0 with 3. + """ + + # axioms with compute hint NYE + assert not compute_hint or not axioms + + if var_to_range is None: + var_ranges = self.var_to_range + else: + var_ranges = dict(var_to_range) + + expr = self.simplify(expr) + + if compute_hint: + expr = expr.xreplace(self.var_to_val) + + expr = canonicalize_bool_expr(expr) + + # Pattern matching + symbols = tuple(expr.free_symbols) + if axioms is None: + axioms = self.get_axioms(symbols, compute_hint=compute_hint) + subst = {} + for e in axioms: + if e.free_symbols.issubset(expr.free_symbols): + subst.update(dict(self.get_implications(e))) + + expr = expr.xreplace(subst) + + symbols = tuple(expr.free_symbols) + + # Simplify making use of value range lower bound + new_shape_env = {} + new_range_env = {} + for idx, k in enumerate(symbols): + if isinstance(self.var_to_val.get(k, None), SingletonInt): + # Skip var_ranges logic for SingletonInt which is only used + # for jagged layout NestedTensors today + continue + vr = var_ranges[k] + if size_oblivious and k in self.size_like: + lower = max(2, vr.lower) + # Clamping size-oblivious to some quantity below sys.maxsize + # helps us determine that f(u0) != sys.maxsize, which is a + # test that is looking for sys.maxsize as a sentinel, but you + # don't really want to worry about it for unbacked SymInts. + # This is similar to the flavor where size oblivious omits + # 0/1, it changes semantics but in a benign way. + upper = min(2 ** 48, vr.upper) + # This is a bit dodgy: what this means is that there was a + # size-like unbacked symbol whose upper bound < 2. This + # causes... problems. + if lower <= upper: + vr = ValueRanges(lower, upper) + else: + lower = vr.lower + # Don't do anything if we don't have a nontrivial lower bound + # Also don't do anything if we asked only to simplify unbacked + # SymInt + if ( + lower is -int_oo or + (unbacked_only and k in self.var_to_val) or + not vr.is_int + ): + new_range_env[k] = vr + continue + # The goal is to take our symbols which have various lower bounds + # and reallocate them into new symbols which are exactly positive; + # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in + # [1, inf], where s0 = ess0 + 1. This gives the most information + # to sympy for subsequent simplifications. + # + # Positive means >= 1 + # Positive - 1 means >= 0 + # Positive + lower - 1 means >= lower + # The new symbol 's' is "too low", so when we substitute it in + # we have to increase it by offset (and conversely, the new + # variables have to have their value range bounds adjusted as + # well) + s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True) + + # Note: + # Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers. + # Sympy might give unexepected results when comparing an integer with a non-integer + # Therefore, we cast offset to int here. + # For example: + # shape_0 = sympy.Symbol("shape_0", positive=True, integer=True) + # expr = sympy.Eq(shape_0 - 1/3, 4) + # expr.xreplace({}) # False + offset = int(lower - 1) + new_shape_env[k] = s + offset + new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset) + + try: + new_expr = expr.xreplace(new_shape_env) + except RecursionError: + log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env) + self.counter["sympy_recursion_error"] += 1 + return None + + # We need to canonicalize, as after expand we may have something like `a + b = a` and + # sympy will not simplify the a. The two appeareances of the a will then make value ranges + # analysis give lose bounds + new_expr = canonicalize_bool_expr(safe_expand(new_expr)) + if new_expr.is_number: + return new_expr + + # This is bad to do, the replacement with division leaves us with + # rationals when atom.args[0] is addition, e.g., sympy will happily + # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication! + """ + floor_div_replace = {} + for atom in new_expr.atoms(FloorDiv): + floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) + new_expr = safe_expand(new_expr.xreplace(floor_div_replace)) + # TODO: when unbacked_only, can sometimes early return even when there + # are still free symbols + if new_expr.is_number: + return new_expr + """ + + # Check if the range can solve it statically + out = bound_sympy(new_expr, new_range_env) + if out.is_singleton(): + return out.lower + + return new_expr if unbacked_only else None + + @_lru_cache + def replace(self, expr: "sympy.Expr") -> "sympy.Expr": + """Apply symbol replacements to any symbols in the given expression + """ + replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols} + return safe_expand(expr.xreplace(replacements)) + + @_lru_cache + def _update_divisible(self): + new_divisible = set() + for k in self.divisible: + res = self.replace(k) + if not res.is_number: + new_divisible.add(k) + + self.divisible = new_divisible + self._update_version_counter() + + @_lru_cache + def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": + """Use known constraints and replacements to simplify the given expr + """ + expr = self.replace(expr) + # TODO it would seem that this pass is not necessary given the + # below replacement of // with /, but for nested FloorDivs + # the non-recursive replacement doesn't work, and + # recursive makes it hard to look up divisibility, + # because existing divisibility info has FloorDiv in it, not / + # for now just do a separate pass to catch common nested case + if expr.has(FloorDiv): + self._update_divisible() + div_replacements = {} + for atom in expr.atoms(FloorDiv): + base, divisor = atom.args + if isinstance(divisor, FloorDiv): + base1, divisor1 = divisor.args + if self.replace(Mod(base, divisor)) in self.divisible and \ + base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible: + div_replacements[atom] = divisor1 + expr = expr.xreplace(div_replacements) + expr = safe_expand(expr) + if expr.has(FloorDiv): + div_replacements = {} + pows = expr.atoms(sympy.Pow) + rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer)) + for fd in expr.atoms(FloorDiv): + base, divisor = fd.args + if self.replace(Mod(base, divisor)) in self.divisible: + div_replacements[fd] = CleanDiv(base, divisor) + new_expr = expr.xreplace(div_replacements) + new_expr = safe_expand(new_expr) + new_pows = new_expr.atoms(sympy.Pow) + new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer)) + # divisions simplified away + if new_pows.issubset(pows) and new_rationals.issubset(rationals): + expr = new_expr + return expr + + @lru_cache(256) + def size_hint(self, expr: "sympy.Expr", *, allow_none=False): + """ + Gets a size hint for a given expression from the underlying shapes we had. + Does not introduce a guard, so only use this when you can guarantee that + your code is still valid for arbitrary shapes (such as optimization decisions) + """ + result_expr = safe_expand(expr).xreplace(self.var_to_val) + if not result_expr.is_number: + + from torch.utils._sympy.singleton_int import SingletonInt + + if isinstance(result_expr, SingletonInt): + return None + r = self._maybe_evaluate_static(result_expr, compute_hint=True) + if r is not None: + return r + if allow_none: + return None + + if self.unbacked_var_to_val: + unsound_expr = result_expr.xreplace(self.unbacked_var_to_val) + if not unsound_expr.free_symbols: + log.warning("propagate_real_tensors size_hint(%s) -> %s", expr, unsound_expr) + trace_structured( + "propagate_real_tensors", + metadata_fn=lambda: { + "expr": repr(expr), + "result": repr(unsound_expr), + "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), + }, + ) + self.defer_runtime_assert( + sympy.Eq(result_expr, unsound_expr), + f"propagate_real_tensors: {result_expr} == {unsound_expr}" + ) + return unsound_expr + + raise self._make_data_dependent_error(result_expr, expr) + return result_expr + + # NB: keep in sync with size_hint + @lru_cache(256) + def has_hint(self, expr: "sympy.Expr"): + result_expr = safe_expand(expr).xreplace(self.var_to_val) + return result_expr.is_number or self._maybe_evaluate_static(result_expr) is not None + + def _make_data_dependent_error(self, expr, unhinted_expr, *, size_oblivious_result: Optional[bool] = None): + # TODO: in a Dynamo context, having user code, and having the + # name of the local, will be much better + size_like_symbols = [] + for s in expr.free_symbols: + stacktrace = ''.join(self.var_to_stack[s].format()) + self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, stacktrace) + if s in self.size_like: + size_like_symbols.append(s) + size_oblivious_result_msg = "" + if size_oblivious_result is not None: + size_oblivious_result_msg = ( + f"ATTENTION: guard_size_oblivious would fix the error, evaluating expression to {size_oblivious_result}.\n" + "Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\n" + ) + fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(True) + if expr.is_integer: + desc = "Could not extract specialized integer from data-dependent expression" + else: + desc = "Could not guard on data-dependent expression" + msg = ( + f"{desc} {expr} (unhinted: {unhinted_expr}). " + f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n" + f"{size_oblivious_result_msg}" + "Potential framework code culprit (scroll up for full backtrace):\n" + f"{''.join(traceback.StackSummary.from_list([fsummary]).format())}\n" + 'For more information, run with TORCH_LOGS="dynamic"\n' + "For extended logs when we create symbols, also add " + f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n" + "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" + "For more debugging help, see " + "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" + + maybe_extra_debug + # TODO: Help text about how to use our runtime tests to fix this + # problem + ) + return GuardOnDataDependentSymNode(expr, msg) + + def _update_var_to_range(self, symbol, vr): + lower, upper = vr.lower, vr.upper + + # If we have a size-like unbacked SymInt, refuse to refine the range to be + # less than two. This is because when we intersect this range + # with [2, inf] for size oblivious tests, the range would be + # unsatisfiable. In other words, once you have a size-like + # unbacked SymInt, we can never learn that it is exactly zero or one, + # because we would now give inconsistent results for all size + # oblivous tests! + if upper < 2 and symbol in self.size_like: + upper = 2 + + # Updates the range and the guards corresponding to each bound of the symbol. + if symbol not in self.var_to_range: + r = ValueRanges(lower, upper) + self.log.debug("_update_var_to_range %s = %s (new)", symbol, r) + self.var_to_range[symbol] = r + else: + old = self.var_to_range[symbol] + new = old & ValueRanges(lower, upper) + if new != old: + self.var_to_range[symbol] = new + self.log.debug("_update_var_to_range %s = %s (update)", symbol, new) + + if (v := self.var_to_val.get(symbol)) is not None: + r = self.var_to_range[symbol] + assert v in r, f"{v} not in {r}" + + def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None: + """ + Adds or updates a replacement for a symbol. + Use this instead of `self.replacements[a] = tgt`. + """ + + if tgt == self.replacements.get(a, None): + return + + # Precondition: a == tgt + assert isinstance(a, sympy.Symbol) + + if self.allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence(tgt): + return # continuing leads to placeholder shapes having complex expressions that we can't resolve + + # Handles nested tensor symbolic variables which don't have + # var_to_range bounds + tgt_bound = None + if a in self.var_to_range: + src_bound = self.var_to_range[a] + + # First, refine the value range of a based on the computed value range + # of tgt. This is always OK to do, even if we decide not to do the + # substitution in the end. This might be a no-op, if a already has + # a tighter bound + tgt_bound = self.bound_sympy(tgt) + self._update_var_to_range(a, tgt_bound) + + # Next, check if we can update the range of free symbols in tgt + # based on the range in a. But only do it if: + # - the source bound non-trivially improves over what we get out of + # the existing bounds. + # - the replacement is univariate and we can invert the tgt expression + if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1: + b = next(iter(tgt.free_symbols)) + # Try to invert the equality + r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) + if r is not None: + self.log.debug("set_replacement: solve for %s in %s == %s gives %s", b, a, tgt, r) + # The solution here can be non-integral, for example, if + # we have s0 = 2*s1, then s1 = s0/2. What we would like + # to do is calculated the bounds in arbitrary precision, + # and then requantize the bound to integers when we are + # done. + rat_b_bound = self.bound_sympy(r[1]) + b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)) + self._update_var_to_range(b, b_bound) + tgt_bound = self.bound_sympy(tgt) + assert tgt_bound.issubset(src_bound) + + # TODO: Should we propagate size-like-ness? + # + # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1 + # to become size-like. + # + # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T + # propagate in this case, because what if u0 == 0, then u1 is negative + # and clearly isn't a size. So, at minimum, any f(x) whose value + # range isn't [0, inf] given x in [0, inf] cannot propagate + # size-like-ness. But there are many situations where you could + # imagine u1 is going to be size-like and actually you just didn't + # have a refined enough value range on u0. Since even innocuous + # looking arithmetic operations can destroy size-like-ness, it's + # best to not propagate it at all and force the user to annotate it + # as necessary. + # + # Compromise: we preserve size-like-ness only for exact equality + # and nothing else. + if a in self.size_like and isinstance(tgt, sympy.Symbol): + self.size_like.add(tgt) + elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like: + self.size_like.add(a) + + # Now, decide if we will do the substitution. + # + # - If the source has a non-trivial range, only substitute if + # we preserve this range. Note that we may have propagated + # the src_range to free variables in tgt when tgt is univariate + # and we could find an inverse, which helps us achieve this. + # This ensures we never "forget" about user defined ranges, + # even if they end up being defined on composite formulas + # like s0 + s1. + # + # - If the variable is unbacked, only substitute if the substitution + # would preserve the bounds also under size-like-ness conditions. + + if not tgt_bound.issubset(src_bound): + self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound) + return + elif a in self.size_like: + tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True) + src_bound_so = self.bound_sympy(a, size_oblivious=True) + if not tgt_bound_so.issubset(src_bound_so): + self.log.debug("skipped set_replacement %s = %s (%s) " + "[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so) + return + + if isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected (unless + # you specified dynamic=True) + + user_tb = TracingContext.extract_stack() + trace_structured( + "symbolic_shape_specialization", + metadata_fn=lambda: { + "symbol": repr(a), + "sources": [s.name() for s in self.var_to_sources.get(a, [])], + "value": repr(tgt), + "reason": msg, + "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), + "user_stack": structured.from_traceback(user_tb) if user_tb else None, + } + ) + + if config.print_specializations: + self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt) + self.log.debug("SPECIALIZATION", stack_info=True) + log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) + self.replacements[a] = tgt + self._update_version_counter() + + # When specializing 'a == tgt', the equality should be also conveyed to + # Z3, in case an expression uses 'a'. + self._add_target_expr(sympy.Eq(a, tgt)) + + def _add_divisible(self, expr: "sympy.Expr"): + self.divisible.add(expr) + self._update_version_counter() + + @_lru_cache + @record_shapeenv_event() + def _find(self, a: "sympy.Symbol") -> "sympy.Expr": + """ + Implements a DSU-like algorithm to find the variable that represents a + Also handles transitive non-identity replacements. + + a: b + c + c: d + """ + if a not in self.replacements: + return a + res = self.replacements[a] + cur_replace = {s: self._find(s) for s in res.free_symbols} + replaced, changed = self.replacements[a]._xreplace(cur_replace) + if changed: + self._set_replacement(a, replaced, "find") + return self.replacements[a] + + @lru_cache(256) + def _maybe_guard_rel(self, expr: "sympy.Rel") -> None: + """ + The relational guard is guarded to be true. Use this information to + simplify shapes (i.e. a == b or a % 5 == 0) + """ + assert isinstance(expr, sympy.Rel) + + # A good example of what goes wrong if you don't do this is + # python test/functorch/test_aotdispatch.py -k + # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32 + if isinstance(expr, sympy.Ne): + return + + free = list(expr.free_symbols) + + assert len(free) > 0, f"The expression should not be static by this point: {expr}" + # In case of really gnarly expression, we don't blow up + if len(free) > 5: + return + + # Prioritize unbacked symints for solving by ordering them last. + # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3). + # (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols) + # Prefer to simplify out symbols with ephemeral sources. + def _smart_symbol_sort(x): + has_only_ephemeral_sources = ( + x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x]) + ) + # NB: size_hint is int, not sympy.Expr, do not use int_oo here + size = self.size_hint(x, allow_none=True) or sys.maxsize + name = x.name + # 1 puts ephemeral sourced symbols first when sorting in reverse + return (1 if has_only_ephemeral_sources else 0, size, name) + + free = sorted(free, key=_smart_symbol_sort, reverse=True) # type: ignore[attr-defined] + lhs = expr.lhs + rhs = expr.rhs + + self._refine_ranges(expr) + + # The rest of this stuff is for equality only + if not isinstance(expr, sympy.Eq): + return + + if not expr.has(Mod): + try: + floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv)) + if len(floor_div_atoms) > 0 and any(a.divisor != 1 for a in floor_div_atoms): + raise NotImplementedError + + # Never replace unbacked symbols with other unbacked symbols. + # This is error prone because you can cause references to + # unbacked symbols to time travel backwards. E.g., + # + # u1 = x.item() + # ... use of u1 ... + # u2 = y.item() + # u3 = z.item() + # torch._check(u1 == u2 + u3) + # + # If you replace u1 with u2 + u3, then the use of u1 now + # references u2 and u3 prior to them actually being bound at + # runtime. It's pretty inconvenient to setup control + # dependencies for substitutions, so ban it entirely. + def trivial_solve(lhs, rhs): + if isinstance(lhs, sympy.Symbol): + if free_unbacked_symbols(lhs) and not free_unbacked_symbols(rhs): + return True + if symbol_is_type(lhs, SymT.FLOAT): + return True + # TODO: Maybe trivial solutions for int should also be + # done? + return False + + # short-circuit when no solving is needed + if trivial_solve(lhs, rhs): + self._set_replacement(lhs, self._find(rhs), "trivial_lhs") + elif trivial_solve(rhs, lhs): + self._set_replacement(rhs, self._find(lhs), "trivial_rhs") + else: + r = try_solve(expr, free[0], floordiv_inequality=False) + if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])): + new_var = self._find(r[1]) + ok = len(free_unbacked_symbols(new_var)) == 0 + if ok: + self._set_replacement(cast(sympy.Symbol, free[0]), new_var, "solve") + except NotImplementedError: + pass + if expr.has(Mod): + mod_expr = next(iter(expr.atoms(Mod))) + try: + r = try_solve(expr, mod_expr, floordiv_inequality=False) + if r is not None and r[1] == 0: + self._add_divisible(mod_expr) + # This is a little bit of extra logic to make things like + # torch.empty(i0, q).view(c, -1, q) work out + p, q = mod_expr.args + if isinstance(q, sympy.Number) and isinstance(p, sympy.Mul) and len(p.args) == 2: + c, i0 = p.args + # Given Mod(c * i0, q) == 0 + if ( + isinstance(c, sympy.Number) and + isinstance(i0, sympy.Symbol) and + self.is_unbacked_symint(i0) + ): + # We have Mod(i0, q / c) == 0, which means we can + # rewrite i0 as (q / gcd(q, c)) * i1 + d = q / sympy.gcd(q, c) # TODO: CleanDiv? + i1 = self.create_unbacked_symint().node.expr + # Propagate the value ranges. It doesn't really + # matter if we use truediv or floordiv, because we + # have established divisibility. + self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv( + self.var_to_range[i0], ValueRanges.wrap(d) + )) + # Propagate size-like-ness + if i0 in self.size_like: + self.size_like.add(i1) + self._set_replacement(i0, d * i1, "divisibility") + + except NotImplementedError: + pass + return + + # See: Note - On 0/1 specialization + def _default_value_range(self) -> ValueRanges: + lower = 2 if self.specialize_zero_one else 0 + return ValueRanges(lower, int_oo) + + def _default_unspecified_value_range(self) -> ValueRanges: + return ValueRanges(-int_oo, int_oo) + + @_lru_cache + def _simplify_floor_div(self, expr): + floor_divs = tuple(expr.atoms(FloorDiv)) + # we expect floor_divs to be exact, + # and thus add the guards for the exact floordivs, + # even if tracing doesn't require them otherwise + for fd in reversed(floor_divs): + base, divisor = fd.args + mod_expr = Mod(base, divisor) + eq_expr = sympy.Eq(mod_expr, 0) + # add necessary mod guards + self.evaluate_expr(eq_expr) + return self.simplify(expr) + + # We're about to add a guard/runtime assert, check if the ShapeEnv is frozen + # and if so issue a warning + def _check_frozen(self, expr, concrete_val): + if self.frozen: + self.counter["ignored_backward_guard"] += 1 + signpost_event( + "dynamic", + "evaluate_expr_frozen", + { + **self.co_fields, + "ignored_guard": f"{expr} == {concrete_val}", + # no version = original state (this signpost is expected) + # version 2 = dynamic backwards is eagerly compiled + "version": 2, + }, + ) + log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val, stack_info=True) + + + def _get_stack_summary(self, is_debug: bool = False): + fsummary = None + frame = inspect.currentframe() + try: + while frame is not None: + if frame.f_code.co_filename not in uninteresting_files(): + fsummary = traceback.FrameSummary( + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + break + frame = frame.f_back + finally: + del frame + + # NB: this stack is truncated, but it's fine because the main + # stack_info will give you the rest of the info you need + maybe_user_loc = "" + user_tb = TracingContext.extract_stack() + if user_tb: + maybe_user_loc = " at " + format_frame(user_tb[-1]) + + maybe_extra_debug = "" + if is_debug and user_tb: + maybe_extra_debug = ( + '\nUser Stack (most recent call last):\n' + + ' (snipped, see stack below for prefix)\n' + + ''.join(traceback.format_list(user_tb)) + ) + if is_debug and config.extended_debug_cpp: + cpp_stack = CapturedTraceback.extract(cpp=True) + maybe_extra_debug += "\nC++ stack trace:\n" + ''.join(cpp_stack.format()) + elif is_debug: + maybe_extra_debug += ( + "\nFor C++ stack trace, run with " + "TORCHDYNAMO_EXTENDED_DEBUG_CPP=1" + ) + + return fsummary, maybe_user_loc, maybe_extra_debug + + def _log_guard(self, prefix: str, g, forcing_spec: bool): + if self.log.isEnabledFor(logging.INFO): + str_g = str(g) + is_debug = config.extended_debug_guard_added is not None and str_g == config.extended_debug_guard_added + fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug) + maybe_more_info = "" + if not is_debug: + maybe_more_info = ( + ", for more info run with " + f'TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="{str_g}"' + ) + self.log.info( + "%s %s [guard added]%s (%s)%s%s", + prefix if not forcing_spec else f"{prefix} (forcing_spec)", + str_g, + maybe_user_loc, + format_frame(fsummary), + maybe_more_info, + maybe_extra_debug, + stack_info=is_debug, + ) + + @lru_cache(256) + @record_shapeenv_event(save_tracked_fakes=True) + def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None, + size_oblivious: bool = False, *, forcing_spec: bool = False): + try: + return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec) + except Exception: + self.log.warning( + "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s", + orig_expr, hint, size_oblivious, forcing_spec + ) + raise + + def _evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None, + size_oblivious: bool = False, *, forcing_spec: bool = False): + """ + Given an expression, evaluates it, adding guards if necessary + """ + + # TODO: split conjunctions and evaluate them separately + + # Don't track this one + @functools.lru_cache(None) + def compute_concrete_val(): + if hint is None: + return self.size_hint(orig_expr) + else: + return sympy.sympify(hint) + + # Check if: + # 1. 'translation_validation' is set + # 2. the corresponding 'fx_node' is not 'None' + # 3. the guard should not be suppressed + # + # If all of the above check, we create an FX node representing the + # actual expression to be guarded. + node = None + fresh = False + if ( + self._translation_validation_enabled + and fx_node is not None + and not self._suppress_guards_tls() + and not size_oblivious + ): + concrete_val = compute_concrete_val() + if concrete_val is sympy.true: + node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) + elif concrete_val is sympy.false: + neg, _ = self._create_fx_call_function(operator.not_, (fx_node,)) + node, fresh = self._create_fx_call_function(torch._assert, (neg,)) + else: + eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val)) + node, fresh = self._create_fx_call_function(torch._assert, (eql,)) + + assert node is not None + # If this is a fresh node, we have to remember the event index that + # corresponds to this assertion node. + # Reason: so that, given an assertion node, we can replay the ShapeEnv + # events until the point where this assertion node was freshly created. + if fresh: + self._add_fx_node_metadata(node) + + # After creating the FX node corresponding to orig_expr, we must make sure that + # no error will be raised until the end of this function. + # + # Reason: the translation validation may become invalid otherwise. + # + # If an error is raised before the end of this function, we remove the FX node + # inserted, and re-raise the error. + guard = None + tb = None + + try: + if orig_expr.is_number: + self.log.debug("eval %s [trivial]", orig_expr) + if hint is not None: + assert orig_expr == hint, f"{orig_expr} != {hint}" + return orig_expr + + expr = orig_expr + + static_expr = self._maybe_evaluate_static(expr, + size_oblivious=size_oblivious) + if static_expr is not None: + self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr) + if hint is not None: + assert static_expr == hint, f"{static_expr} != {hint}" + return static_expr + + transmute_into_runtime_assert = False + + concrete_val = None + if not (expr.free_symbols <= self.var_to_val.keys()): + # TODO: dedupe this with _maybe_evaluate_static + # Attempt to eliminate the unbacked SymInt + new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + if not (new_expr.free_symbols <= self.var_to_val.keys()): + size_oblivious_result = None + if not size_oblivious: + size_oblivious_result = self._maybe_evaluate_static( + expr, + size_oblivious=True + ) + + # Last ditch + if ( + self.unbacked_var_to_val and + not (unsound_result := orig_expr.xreplace(self.unbacked_var_to_val)).free_symbols + ): + log.warning("propagate_real_tensors evaluate_expr(%s) -> %s", orig_expr, unsound_result) + trace_structured( + "propagate_real_tensors", + metadata_fn=lambda: { + "expr": repr(orig_expr), + "result": repr(unsound_result), + "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), + }, + ) + transmute_into_runtime_assert = True + concrete_val = unsound_result + else: + raise self._make_data_dependent_error( + expr.xreplace(self.var_to_val), + expr, + size_oblivious_result=size_oblivious_result + ) + else: + expr = new_expr + + if concrete_val is None: + concrete_val = compute_concrete_val() + self._check_frozen(expr, concrete_val) + + if ( + config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY + and isinstance(hint, bool) + and isinstance(expr, (sympy.Eq, sympy.Ne)) + ): + expr = sympy.Not(expr) + + # Turn this into a boolean expression, no longer need to consult + # concrete_val + if concrete_val is sympy.true: + g = expr + elif concrete_val is sympy.false: + g = sympy.Not(expr) + else: + g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type] + + if transmute_into_runtime_assert: + self.defer_runtime_assert( + g, + f"propagate_real_tensors: {orig_expr} == {unsound_result}" + ) + return concrete_val + + if not self._suppress_guards_tls(): + if isinstance(g, sympy.Rel): + # TODO: If we successfully eliminate a symbol via equality, it + # is not actually necessary to save a guard for the equality, + # as we will implicitly generate a guard when we match that + # input against the symbol. Probably the easiest way to + # implement this is to have maybe_guard_rel return a bool + # saying if it "subsumed" the guard (and therefore the guard + # is no longer necessary) + self._maybe_guard_rel(g) + + if not self.allow_complex_guards_as_runtime_asserts: + # at this point, we've evaluated the concrete expr value, and have + # flipped/negated the guard if necessary. Now we know what to guard + # or defer to runtime assert on. + stack = CapturedTraceback.extract(skip=1) + guard = ShapeGuard(g, stack) + self.guards.append(guard) + else: + # it's fine to defer simple guards here without checking, + # the _maybe_guard_rel() call above will set replacements if possible, + # and so the result here will be statically known + self.defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") + + except Exception: + if fresh: + self._remove_fx_node(node) + raise + else: + if not self._suppress_guards_tls(): + if guard is not None: # we might have deferred this to runtime assert + self._log_guard("eval", g, forcing_spec=forcing_spec) + + for s in g.free_symbols: + self.symbol_guard_counter[s] += 1 + # Forcing_spec to avoid infinite recursion + if ( + not forcing_spec and + config.symbol_guard_limit_before_specialize is not None and + self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize + ): + # Force specialization + self.log.info( + "symbol_guard_limit_before_specialize=%s exceeded on %s", + config.symbol_guard_limit_before_specialize, + s + ) + self.evaluate_expr(s, forcing_spec=True) + else: + self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec) + + return concrete_val + + def cleanup(self): + """ + Break reference cycles. + + This destroys the stacks. If you really want to keep them, we + just need some way to break references on code objects. + """ + for g in self.guards: + g.stack.cleanup() + for s in self.var_to_stack.values(): + s.cleanup() + for ras in self.deferred_runtime_asserts.values(): + for ra in ras: + ra.stack.cleanup() + + @record_shapeenv_event(save_tracked_fakes=True) + def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None): + """Create an assert that is checked at runtime + + Args: + orig_expr (sympy.Expr): Boolean expression to assert is true + msg (str): Message to display on assertion failure + fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding + to the expression, if applicable + + """ + expr = orig_expr + + # TODO: split conjunctions and evaluate them separately + + static_expr = self._maybe_evaluate_static(expr) + if static_expr is not None: + self.log.debug("runtime_assert %s == %s [statically known]", orig_expr, static_expr) + return static_expr + + # Attempt to eliminate the unbacked SymInt + new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + if not self.prefer_deferred_runtime_asserts_over_guards and new_expr.free_symbols <= self.var_to_val.keys(): + # Do a normal guard + return self.evaluate_expr(new_expr, fx_node=fx_node) + # NB: Don't use new_expr as expr; it could contain gunk like shape0 + # which we don't want to guard on + + # OK, we're definitely doing a runtime assert now + if ( + self._translation_validation_enabled + and fx_node is not None + and not self._suppress_guards_tls() + ): + node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) + assert node is not None + if fresh: + self._add_fx_node_metadata(node) + + if not self._suppress_guards_tls(): + # If you're here because of this assert, read Note [Backwards runtime asserts] + # in torch/_inductor/graph.py + assert not self.runtime_asserts_frozen, expr + + self._check_frozen(expr, sympy.true) + + # eliminate symbols on equality tests / refine ranges + if isinstance(expr, sympy.Rel): + self._maybe_guard_rel(expr) + + # canonicalise to remove equations that are trivially equal + orig_expr = expr + expr = canonicalize_bool_expr(expr) + stack = CapturedTraceback.extract(skip=1) + ra = RuntimeAssert(expr, msg, stack) + # TODO: Do this in a way that is less janky than int(s.name[1:]) + cands = sorted((s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)), key=lambda s: int(s.name[1:])) + # Is None when prefer_deferred_runtime_asserts_over_guards=True + # and the guard in question has no unbacked SymInts in front + ix = cands[-1] if cands else None + self.deferred_runtime_asserts.setdefault(ix, []).append(ra) + self.num_deferred_runtime_asserts += 1 + self._update_version_counter() + self._log_guard("runtime_assert", orig_expr, forcing_spec=False) + else: + self._log_guard("runtime_assert [guard suppressed]", orig_expr, forcing_spec=False) + + return True + + # Refines the ranges of the variables present in 'guard'. + # + # This function tries to refine the range of the variables inside + # 'guard' by reasoning about it. Specifically, when 'guard' is a + # 'sympy.Relational' operation. + # + # It does mainly 3 things: + # 1. Tries to isolate a variable in the left-hand side + # 2. Compute the value range of the right-hand side + # 3. Update the value range of the variable, if better + def _refine_ranges(self, expr: sympy.Expr) -> None: + expr = self.simplify(expr) + + for symbol in expr.free_symbols: + assert isinstance(symbol, sympy.Symbol) + + if isinstance(self.var_to_val.get(symbol, None), SingletonInt): + # Skip var_to_range logic for SingletonInt which is only used + # for jagged layout NestedTensors today + continue + + r = try_solve(expr, symbol) + + if r is None or not (symbol.is_integer and r[1].is_integer): + # Range refinement only supports integer symbols for now. + # There are lots of SymPy bugs when it comes to comparing + # reals and integers, so we skip that for now. + continue + + r_expr, rhs = r + vr = self.var_to_range[symbol] + lower, upper = vr.lower, vr.upper + + rhs_vr = bound_sympy(rhs, self.var_to_range) + + # Let's suppose that we have a preexisting range for x [0, 100]. + # Now, we issue a guard x > y, where the range for y is [50, 150]. + # Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen, + # refining x to [51, 100], since x must be greater than y, but the lowest + # y could be is 50. + # + # sympy.Eq may update both lower and upper bounds. + # sympy.G{t,e} may update the lower bound, only. + # sympy.L{t,e} may update the upper bound, only. + if lower < rhs_vr.lower and isinstance(r_expr, (sympy.Eq, sympy.Ge, sympy.Gt)): + # Strictly greater relations allow us to refine a bit more, since + # x < y implies that the lower bound for x is: y + 1. + lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt)) + if upper > rhs_vr.upper and isinstance(r_expr, (sympy.Eq, sympy.Le, sympy.Lt)): + upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt)) + + # Do nothing if the new value range is no better than what we already have. + if vr == ValueRanges(lower, upper): + continue + + # Updates the range and the guards corresponding to each bound of the symbol. + self._update_var_to_range(symbol, ValueRanges(lower, upper)) + # If the range is refined to singleton, set replacement + if self.var_to_range[symbol].is_singleton(): + self._set_replacement(symbol, self.var_to_range[symbol].lower, "range_refined_to_singleton") + + # Clears the cache, since this update can change the result. + self._maybe_evaluate_static.cache_clear() + + @lru_cache(maxsize=None) + @record_shapeenv_event() + def constrain_symbol_range(self, s: sympy.Symbol, compiler_min: int, compiler_max: int): + upd_vr = ValueRanges(compiler_min, compiler_max) + old_vr = self.var_to_range.get(s, ValueRanges.unknown()) + self._update_var_to_range(s, upd_vr) + if (new_vr := self.var_to_range[s]) != old_vr: + log.info("constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper) + + +def _is_int(expr): + return isinstance(expr, SymInt) and expr.node.expr.is_number + +# WARNING: This is legacy, DO NOT USE +def _is_dim_dynamic(t, d): + return hasattr(t, "_dynamo_dynamic_indices") and d in t._dynamo_dynamic_indices + +class PropagateUnbackedSymInts(torch.fx.Interpreter): + def run_node(self, n: torch.fx.Node): + """ + Run an FX node, propagating unbacked Symbol bindings to the new fake tensor + """ + from torch._guards import detect_fake_mode + + result = super().run_node(n) + rebind_unbacked(detect_fake_mode().shape_env, n, result) + return result + + +def _find_user_code_frame(): + frame = inspect.currentframe() + while frame is not None: + if not frame.f_code.co_filename.startswith( + os.path.dirname(inspect.getfile(torch)) + os.path.sep + ): + break + frame = frame.f_back + return frame + + +def _blame_user_code(e, frame): + frame_summary = traceback.FrameSummary( + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + msg = e.args[0] + msg += ( + '\n\nThe following call raised this error:\n' + + ''.join(traceback.StackSummary.from_list([frame_summary]).format()) + ) + e.args = (msg,) + + +class _PythonPrinter(sympy.printing.str.StrPrinter): + """ + Util printer that replaces sympy symbols with their source-level names + and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline + (i.e., as ==, !=, >, <). + """ + + def __init__(self, src_map): + super().__init__() + self.src_map = src_map + + def _print_Symbol(self, sym): + return self.src_map[sym.name][0] + + def _print_Relational(self, expr): + lhs = self.parenthesize(expr.lhs, sympy.printing.precedence.precedence(expr)) + rel_op = expr.rel_op + rhs = self.parenthesize(expr.rhs, sympy.printing.precedence.precedence(expr)) + return f"{lhs} {rel_op} {rhs}" + + +def _suggest_torch_checks(e, src_map): + # extract the unresolved condition on unbacked symints in the error + cond = e.cond + diff = ", ".join(s.name for s in cond.free_symbols if s.name not in src_map) + if diff: + log.warning("Unable to find user code corresponding to {%s}", diff) + return + printer = _PythonPrinter(src_map) + msg = e.args[0] + msg += "\nTo fix the error, insert one of the following checks before this call:" + # suggested fixes to resolve `cond`` are to tell the compiler to assume + # either `cond` or its negation (the user will need to select which) + suggested_fixes = [ + f"torch._check({printer.doprint(cond)})", + f"torch._check({printer.doprint(sympy.Not(cond))})", + ] + for i, fix in enumerate(suggested_fixes): + msg += f"\n {i+1}. {fix}" + src_mapped = ', '.join( + f"`{s}` with {' or '.join(src_map[s])}" + for s in sorted(s.name for s in cond.free_symbols) + ) + msg += f"\n\n(These suggested fixes were derived by replacing {src_mapped} in {cond} and its negation.)" + e.args = (msg,) + + +def _suggest_fixes_for_data_dependent_error_non_strict(e): + """ + Given a raised data-dependent error, add the following to the error message: + 1. the closest user code location that raised the error; + 2. suggested fixes for the error in terms of live variables at that location. + """ + + # walk the stack up from the data-dependent error until a non-torch frame is found + frame = _find_user_code_frame() + if frame is not None: + # add frame info to error message + _blame_user_code(e, frame) + + # map symbol names reachable via frame locals to their source-level names + src_map = defaultdict(list) + for var, val in frame.f_locals.items(): + # figure out how to access any symbol inside `val` through `var` + for path, leaf in pytree.tree_leaves_with_path(val): + name = var + pytree.keystr(path) + if isinstance(leaf, torch.SymInt): + src_map[str(leaf.node.expr)].append(name) + elif isinstance(leaf, torch.Tensor): + for i, dim in enumerate(leaf.shape): + if isinstance(dim, torch.SymInt): + src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]") + + # add suggested torch.check()s based on `src_map` to the error message + # replacing unbacked symints in the unresolved condition in the error + _suggest_torch_checks(e, src_map) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unify_refinements.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unify_refinements.py new file mode 100644 index 0000000000000000000000000000000000000000..cad0a33425bf894540fb3a84f0127a58b43ff237 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unify_refinements.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.graph_gradual_typechecker import Refine +from torch.fx.tensor_type import TensorType +from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined] + + +def infer_symbolic_types_single_pass(traced): + """ + Calls our symbolic inferencer once. + """ + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + +def infer_symbolic_types(traced): + """ + Calls our symbolic inferencer twice. + This is useful when one pass is not enough + to infer all the information such as the case + for braodcasting. + """ + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r.symbolic_relations() + +def convert_eq(list_of_eq): + """ + Convert equality constraints in the right format + to be used by unification library. + """ + lhs = [] + rhs = [] + for eq in list_of_eq: + lhs.append(eq.lhs) + rhs.append(eq.rhs) + return tuple(lhs), tuple(rhs) + + +def unify_eq(list_of_eq): + """ + Apply unification to a set of + equality constraints + """ + lhs, rhs = convert_eq(list_of_eq) + return unify(lhs, rhs) + + +def substitute_solution_one_type(mapping, t): + """ + Apply the most general unifier to a type + """ + if isinstance(t, Var): + if t in mapping.keys(): + return mapping[t] + else: + return t + + elif isinstance(t, TensorType): + new_type = [] + for typ in t.__args__: + if typ in mapping.keys(): + new_type.append(mapping[typ]) + else: + new_type.append(typ) + return TensorType(tuple(new_type)) + + elif isinstance(t, list): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return new_type + + elif isinstance(t, tuple): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return tuple(new_type) + + else: + return t + + +def substitute_all_types(graph, mapping): + """ + Apply the most general unifier to all types in a graph + till reaching a fixed point. If the input and output graph + are the same, we converge. + """ + flag = True + while flag: + flag = False + for k in mapping: + old_mapping_val = mapping[k] + if mapping[k] in mapping.keys(): + new_key = mapping[k] + mapping[k] = mapping[new_key] + if old_mapping_val != mapping[k]: + flag = True + + for n in graph.nodes: + n.type = substitute_solution_one_type(mapping, n.type) + +def check_for_type_equality(g1, g2): + """ + A check equality to be used in fixed points. + We do not use graph equality but instead type + equality. + """ + for n, m in zip(g1.nodes, g2.nodes): + if n.type != m.type: + return False + return True diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/validator.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..163479d3cd4c6fba17d469d6f9c04c21b4818953 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/validator.py @@ -0,0 +1,787 @@ +# mypy: allow-untyped-defs +import functools +import logging +import math +import operator +import sympy +import builtins + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import torch +import torch.fx +import torch.fx.traceback as fx_traceback + +from torch._dynamo.exc import TorchDynamoException +from torch.fx.node import Argument, Target +from torch.utils._sympy.interp import sympy_interp +from torch._dynamo.utils import dynamo_timed + +log = logging.getLogger(__name__) + +try: + import z3 # type: ignore[import] + + # Translation Validation for Dynamo guards + # ======================================== + # + # Checks whether optimizations applied to the collected guards are + # valid. In other words, whether the guard function we actually run + # does not have false positives (unsound). + # + # In order to do so, we build the guards using 2 different information + # attached to each 'SymNode': + # 1. SymPy expressions + # 2. FX nodes + # + # SymPy expressions have implicit optimizations baked within itself, + # which may have a few bugs. On the other hand, we build the FX graph + # manually, with no optimizations enabled. This gives us access to + # the "ground truth". + # + # We then convert into Z3 expressions both the SymPy expressions + # (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function + # and the FX nodes (see [Note: PopulateValidator]) that go through + # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation. + # (see [Note: TranslationValidator]) + + # Better Z3 to string implementation (for a small fraction of Z3). + # + # Here are the things we clean before showing the Z3 expression: + # - Rename a few ops (e.g. "Distinct" ==> "!=") + # + # - Ignore ToInt and ToReal operations: + # usually they don't really matter + # + # - Transform (ToInt (/ ...)) into (idiv ...): + # this is the pattern for floor division + # + # - Collect a chain of the same operations into one + def z3str(e: z3.ExprRef) -> str: + assert z3.is_expr(e), f"unsupported expression type: {e}" + + def get_args_str(e: z3.ExprRef) -> List[str]: + return [z3str(e.arg(i)) for i in range(e.num_args())] + + # First, we simplify the given expression. + # This is done using rewriting rules, so shouldn't take long. + e = z3.simplify(e) + + + # Only support function applications. + # Even Z3 "variables" are, in fact, function applications. + if not z3.is_app(e): + raise ValueError(f"can't print Z3 expression: {e}") + + if z3.is_int_value(e) or z3.is_rational_value(e): + return e.as_string() # type: ignore[attr-defined] + + decl = e.decl() + kind = decl.kind() + op = str(decl) + args = get_args_str(e) + + if kind == z3.Z3_OP_POWER: + op = "pow" + + elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL): + # Collect the arguments of chains of ADD and MUL. + # This is safe, since they are associative. + + def collect_str_args(e): + if not (z3.is_app(e) and e.decl().kind() == kind): + return [z3str(e)] + else: + return [ + x + for i in range(e.num_args()) + for x in collect_str_args(e.arg(i)) + ] + + args = collect_str_args(e) + + elif kind == z3.Z3_OP_NOT: + # Revert some conversions that z3.simplify applies: + # - a != b ==> (Not (== a b)) ==> (!= a b) + # - a < b ==> (Not (<= b a)) ==> (> b a) + # - a > b ==> (Not (<= a b)) ==> (> a b) + + assert e.num_args() == 1 + arg = e.arg(0) + + assert z3.is_app(arg) + argkind = arg.decl().kind() + + logic_inverse = { + z3.Z3_OP_EQ: "!=", + z3.Z3_OP_LE: ">", + z3.Z3_OP_GE: "<", + } + + if argkind in logic_inverse: + op = logic_inverse[argkind] + args = get_args_str(arg) + + elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL): + assert e.num_args() == 1 + argstr = z3str(e.arg(0)) + + # Check if it's the floor division pattern. + if argstr.startswith("(/"): + return "(idiv" + argstr[2:] + + # Otherwise, just ignore it. + return argstr + + elif kind == z3.Z3_OP_UNINTERPRETED: + assert e.num_args() == 0 + return str(decl) + + string = op + " " + " ".join(args) + return f"({string.rstrip()})" + + # Implementation of Python semantics as Z3 expressions. + # + # Z3 Real-Int theory has operators with semantics that differ that of + # Python. Therefore, in order to get it right, we need to implement + # the (Python) semantics we are relying on in Z3. + @dataclass + class _Z3Ops: + # Validator used for adding assertions as needed. + # e.g. div(a, b) requires b != 0. + validator: "TranslationValidator" + + # The 2 functions below are used for conditionally casting between + # integer and reals. + # + # Returns a real expression from 'x'. + @staticmethod + def to_real(x: z3.ArithRef) -> z3.ArithRef: + return x if x.is_real() else z3.ToReal(x) + + # Returns an integer expression from 'x'. + @staticmethod + def to_int(x: z3.ArithRef) -> z3.ArithRef: + return x if x.is_int() else z3.ToInt(x) + + # Implements Python division semantics. + def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + self.validator.add_assertion(denominator != 0) # type: ignore[arg-type] + return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator) + + def floor(self, number: z3.ArithRef) -> z3.ArithRef: + # Z3 ToInt function rounds a real number towards negative infinity. + return _Z3Ops.to_int(number) + + # Python semantics for 'FloorDiv' states that before applying the floor + # function, the operands are converted to their common type. + def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + cast_result_to_real = numerator.is_real() or denominator.is_real() + result = _Z3Ops.to_int(self.div(numerator, denominator)) + # Since the 'result' is already an integer, we just have to check + # whether we should cast it to real. + return _Z3Ops.to_real(result) if cast_result_to_real else result + + def ceil(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.If( + self.floor(number) < number, + self.floor(number + 1), + number + ) # type: ignore[return-value] + + def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: + return z3.If(a > b, a, b) # type: ignore[return-value] + + def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: + return z3.If(a < b, a, b) # type: ignore[return-value] + + # Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q + # It should work with both integer and reals. + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: + return p - self.floordiv(p, q) * q + + def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + # Z3 can't handle complex numbers very well. + self.validator.add_assertion(z3.Or(base != 0, exp > 0)) # type: ignore[arg-type] + return base ** exp + + def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: + # Square-root: + # 1. Only work with reals + number = _Z3Ops.to_real(number) + # 2. The number should be positive or zero. + # Otherwise, Z3 returns 'unknown'. + self.validator.add_assertion(number >= 0) + return number ** 0.5 + + def abs(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.Abs(number) + + def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: + # Pythons builtin 'round' implements the 'round half to even' strategy + # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even + # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to + # floating point numbers, which is different from real numbers that we are dealing with here. + # Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and + # 'round half down' (ceil(x - 0.5)). + # Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ... + # to round down, i.e. use the 'round half down' strategy + return z3.If( + self.mod(number, z3.IntVal(2)) == 0.5, + self.ceil(number - 0.5), + self.floor(number + 0.5), + ) + + # Lifts a callable to be used in Z3. + # + # This function replaces the given 'op' by a function that: + # + # 1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3) + # + # 2. Calls an operation that corresponds to 'op', but works with Z3 + # inhabitants (left as is if it works as is) + def z3op(op: Callable, validator: "TranslationValidator") -> Callable: + # Operations that have booleans as their argument. + # This is needed because the argument of some FX nodes were + # literal integers, instead of booleans. So, whenever this flag + # is set, we also convert ints to booleans. + boolean_ops = {operator.not_, operator.and_, operator.or_} + as_bool = op in boolean_ops + + # Lifts the function into 'z3.ExprRef' domain. + def lift(func): + def wrap(a) -> z3.ExprRef: + if isinstance(a, (z3.ArithRef, z3.BoolRef)): + return a + # Convert it into a Z3 value, if it is some of the supported + # types below. + if isinstance(a, bool) or (as_bool and isinstance(a, int)): + return z3.BoolVal(bool(a)) + if isinstance(a, (int, sympy.Integer)): + return z3.IntVal(int(a)) + if isinstance(a, (float, sympy.Float)): + return z3.RealVal(float(a)) + raise ValueError(f"can't lift type: {type(a)}") + + @functools.wraps(func) + def wrapper(*args): + # Lifts the arguments into a list of Z3 inhabitants. + wrapped_args = (wrap(a) for a in args) + # Run the function on the Z3 expressions. + return func(*wrapped_args) + + return wrapper + + ops = _Z3Ops(validator) + replacement_map = { + # Operator module. + operator.not_: lift(z3.Not), + operator.and_: lift(z3.And), + operator.or_: lift(z3.Or), + operator.floordiv: lift(ops.floordiv), + operator.truediv: lift(ops.div), + operator.mod: lift(ops.mod), + operator.abs: lift(ops.abs), + builtins.round: lift(ops.round_to_int), + + # Math module. + math.ceil: lift(ops.ceil), + math.floor: lift(ops.floor), + + # Torch module. + torch.sym_float: lift(ops.to_real), + torch.sym_max: lift(ops.max), + torch.sym_min: lift(ops.min), + torch.sym_ite: lift(lambda b, t, f: t if b else f), + torch._sym_sqrt: lift(ops.sqrt), # type: ignore[attr-defined] + # Not lifted because we only use this function as a + # marker for adding the expression as validator input. + torch._assert: torch._assert, + } + return replacement_map[op] if op in replacement_map else lift(op) + + # Processes an FX graph, populating the given validator. + # + # [Note: PopulateValidator] + # This class walks through each node in the FX graph, translating + # them into the Z3 world. + # + # Then, whenever it finds an 'torch._assert' call_function operation, + # it adds the Z3 expression corresponding to the argument as validator + # input. + class PopulateValidator(torch.fx.Interpreter): + def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"): + # Reference to the translation validator. + self.validator = validator + + # Build the graph module and call `Interpreter` constructor. + module = torch.fx.GraphModule(root={}, graph=graph) + super().__init__(module, garbage_collect_values=True) + + def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + symbol = fx_traceback.get_current_meta()["symbol"] + return self.validator.z3var(symbol) + + def call_function(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + if target != torch._assert: + # Lift and runs the node target function + return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type] + # Adds the Z3 expression corresponding to the first argument + # as a validator input. + assert len(args) == 1, f"expected 1 argument on assertion. Got: {len(args)} " + self.validator.add_source_expr(args[0]) # type: ignore[arg-type] + + # Translates SymPy expressions into Z3 expressions. + # + # [Note: SympyToZ3] + # At the time of the translation, all free variables present in the + # SymPy expression being translated must be already mapped to a Z3 + # integer variable. + class SympyToZ3: + OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"} + + def __init__( + self, + validator: "TranslationValidator", + ) -> None: + self._validator = validator + self._ops = _Z3Ops(self._validator) + + def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: + # TODO: Probably OK to relax this and allow lower precision + if dtype is torch.int64: + return z3.IntVal(int(value)) + if dtype is torch.double: + return z3.RealVal(float(value)) + if dtype is torch.bool: + return z3.BoolVal(bool(value)) + raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") + + def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + if dtype == torch.float64: + return z3.ToReal(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return z3.ToInt(x) + + def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.round_to_int(x) + + def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + + def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + + def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.floordiv(numerator, denominator) + + def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.floordiv(numerator, denominator) + + def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + + def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: + return self._ops.mod(p, q) + + def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.ceil(x) + + def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.floor(x) + + def __getattr__(self, name: str) -> Any: + REPLACEMENT = { + "and_": z3.And, + "or_": z3.Or, + "not_": z3.Not, + "floor": self._ops.floor, + "ceil": self._ops.ceil, + "minimum": self._ops.min, + "maximum": self._ops.max, + } + + if name in REPLACEMENT: + return REPLACEMENT[name] + if name in self.OPERATOR_HANDLES: + return getattr(operator, name) + raise AttributeError(f"unhandled operator: {name}") + + def run(self, expr: sympy.Basic) -> z3.ExprRef: + return sympy_interp(self, self._validator.symbols, expr) # type: ignore[arg-type] + + # Dynamo guards translation validator. + # + # [Note: TranslationValidator] + # Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound. + # That is: whether those (target) guards only yield TRUE whenever the original, + # unoptimized, (source) guards yield TRUE. + # + # More concretely, given 'source' and 'target' guard expressions, we wish to + # check whether the following expression holds: + # + # Not(And(source)) AND And(target) + # + # i.e. whether there is an assignment of the free variables where the opposite + # happens: target is TRUE, but source is FALSE. + class TranslationValidator: + def __init__(self) -> None: + log.debug("new instance") + + # Mapping of SymPy symbols to Z3 variables. + self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {} + + # Set of source Z3 expressions. + # They represent the generated guards without any kind of + # simplification or transformation. + self._source_exprs: Set[z3.BoolRef] = set() + + # Set of target Z3 expressions. + # They represent the actual checked guards at runtime. They might + # be simplified or transformed versions of the source guards. + self._target_exprs: Set[z3.BoolRef] = set() + + # Set of Z3 expressions representing assertions over both the + # source and target expressions. + self._assertions: Set[z3.BoolRef] = set() + + # Retrieves the corresponding Z3 variable. + def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef: + assert symbol in self.symbols, f"Z3 variable not found for: {symbol}" + return self.symbols[symbol] + + # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exists. + def add_var(self, symbol: sympy.Symbol, type: Type) -> z3.ExprRef: + if symbol in self.symbols: + return self.symbols[symbol] + + log.debug("new variable: %s (%s)", symbol.name, type.__name__) + + if type is int: + var = z3.Int(symbol.name) + + # If 'symbol' is positive (SymPy assumption), we have to + # convey it to Z3 as well. + if symbol.is_positive: # type: ignore[attr-defined] + self._target_exprs.add(var > 0) + elif type is float: + var = z3.Real(symbol.name) + elif type is bool: + var = z3.Bool(symbol.name) + else: + raise RuntimeError(f"unsupported type for Z3 variable: {type}") + + self.symbols[symbol] = var + return var + + # Checks whether all symbols were already added. + def _check_freesymbols(self, e: sympy.Basic) -> None: + for s in e.free_symbols: + assert isinstance(s, sympy.Symbol) + # Call 'z3var' just to check whether there's already a + # Z3 variable corresponding to 's'. + self.z3var(s) + + + def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef: + z3expr = SympyToZ3(self).run(e) + assert isinstance(z3expr, z3.BoolRef), f"expected boolean expression. Got: {z3expr}" + return z3expr + + def add_source_expr(self, e: z3.BoolRef) -> None: + if e not in self._source_exprs: + log.debug("add source guard: %s", z3str(e)) + self._source_exprs.add(e) + + def add_target_expr(self, e: sympy.Expr) -> None: + self._check_freesymbols(e) + z3expr = self.to_z3_boolean_expr(e) + if e not in self._target_exprs: + log.debug("add target guard: %s", z3str(z3expr)) + self._target_exprs.add(z3expr) + + def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None: + if isinstance(e, sympy.Basic): + self._check_freesymbols(e) + ref = self.to_z3_boolean_expr(e) + else: + ref = e + assert isinstance(ref, z3.BoolRef) + if ref not in self._assertions: + log.debug("add assertion: %s", z3str(ref)) + self._assertions.add(ref) + + def validate(self) -> None: + with dynamo_timed("TranslationValidator.validate"): + return self._validate() + + def _validate(self) -> None: + if len(self._source_exprs) == 0 or len(self._target_exprs) == 0: + # If there are no source/target expressions, there's nothing we really + # wish to prove. So, we just return. + return None + + # Here, we use "QF_NRA" logic for the solver: + # "Quantifier-free Non-linear Real Arithmetic". + # + # Most of the guards expressions have: + # 1. arithmetic between integer and reals + # 2. no quantifiers + # 3. potentially non-linear. + # + # Although there's also "QF_NIRA" (mixed integer-real arithmetic), + # "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'. + solver = z3.SolverFor("QF_NRA") + # Set a timeout for finding a solution. + solver.set(timeout=translation_validation_timeout()) + + # Add all the assertions to the solver. + for assertion in self._assertions: + solver.add(assertion) + + # "Is there any case where it's TRUE for the target expressions, + # but FALSE for the source expressions?" + solver.add(z3.Not(z3.And(*self._source_exprs))) + solver.add(*self._target_exprs) + + log.debug("translation validation: start") + r = solver.check() + if r == z3.sat: + # Target expressions are unsound. + # Log the found model and the source expressions that failed. + model = solver.model() + raise ValidationException( + model, self._assertions, self._target_exprs, + failed_source_exprs=[ + inp for inp in self._source_exprs if not model.evaluate(inp) + ] + ) + else: + if r == z3.unknown: + # Could not find a solution. It didn't fail, but it also + # didn't succeed. Canceling the validation execution (keyboard + # interrupt) also gets to this branch. + log.warning("translation validation: could not validate: got z3.unknown") + else: + # Target expressions are sound. + assert r == z3.unsat + log.debug("translation validation: success") + +except ImportError: + _HAS_Z3 = False + + __all__ = [ + "translation_validation_enabled", "translation_validation_timeout", + "ValidationException", "BisectValidationException", + ] + +else: + _HAS_Z3 = True + + __all__ = [ + "z3str", "z3op", "PopulateValidator", "SympyToZ3", "TranslationValidator", + "translation_validation_enabled", "translation_validation_timeout", + "ValidationException", "BisectValidationException", + ] + +from torch.fx.experimental import _config as config + +def translation_validation_enabled() -> bool: + # Checks everytime this function is called, in case the Dynamo + # option is set, but Z3 is not installed. + _assert_z3_installed_if_tv_set() + return _HAS_Z3 and config.translation_validation + + +def translation_validation_timeout() -> int: + return config.translation_validation_timeout + + +def _assert_z3_installed_if_tv_set(): + assert _HAS_Z3 or not config.translation_validation, ( + "translation validation requires Z3 package. Please, either install " + "z3-solver or disable translation validation." + ) + + +class ValidationException(TorchDynamoException): + def __init__(self, model, assertions, target_exprs, failed_source_exprs): + assert _HAS_Z3 + + def symbolstr(sym) -> str: + return f"{sym}: {model[sym]}" + + def joinlines(xs) -> str: + return "\n".join(f" ==> {x}" for x in xs) + + model_str = joinlines(sorted(map(symbolstr, model))) + assertions_str = joinlines(sorted(map(z3str, assertions))) + target_exprs_str = joinlines(sorted(map(z3str, target_exprs))) + failed_source_exprs_str = joinlines(sorted(map(z3str, failed_source_exprs))) + + self.msg = "translation validation failed." + self.details = f"""\ +Model: +{model_str} + +Assertions: +{assertions_str} + +Target Expressions: +{target_exprs_str} + +Failed Source Expressions: +{failed_source_exprs_str}""" + + def __str__(self): + return f"{self.msg}\n\n{self.details}" + + +class BisectValidationException(TorchDynamoException): + def __init__(self, validation_exc, expr, failed_action, traced_node): + self.msg = f"translation validation failed when {failed_action}: {expr}" + self.details = f"""\ +Failure occurred while running node: + {traced_node.format_node()} + +{validation_exc.details}""" + + def __str__(self): + return f"{self.msg}\n\n{self.details}" + +# Checks when this module is loaded. +_assert_z3_installed_if_tv_set() + +# Translation validation bisection. +# +# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise +# the earliest ValidationException. +# +# As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors +# might be silently happening. This function tries to nail down exactly at which +# point things went wrong from a validation perspective. +def bisect(shape_env): + from torch.fx.experimental.symbolic_shapes import ShapeEnv, SHAPEENV_EVENT_KEY, CURRENT_NODE_KEY + from torch.fx.experimental.recording import FakeTensorMeta, ShapeEnvEvent, replay_shape_env_events + + events = shape_env.events + + # Retrieves the ShapeEnvEvent associated with node. + def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent: + assert SHAPEENV_EVENT_KEY in node.meta + return events[node.meta[SHAPEENV_EVENT_KEY]] + + # Creates a new instance of fake, but updating every symbolic value's ShapeEnv + # reference to the one given as argument. + # + # This is needed so as not to simplify a symbolic expression using a ShapeEnv + # "from the future", where it may have a different set of replacements. + def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any: + if isinstance(fake, int): + return fake + if isinstance(fake, torch.SymInt): + return torch.SymInt(fake.node.with_shape_env(shape_env)) + assert isinstance(fake, FakeTensorMeta) + return FakeTensorMeta( + tuple(new_with_shape_env(shape_env, s) for s in fake.size()), + tuple(new_with_shape_env(shape_env, s) for s in fake.stride()), + new_with_shape_env(shape_env, fake.storage_offset()), + fake.is_nested, + ) + + # Checks whether the given shape_env fails when produce_guards is called. + def check_shapeenv_fails(shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]) -> Optional[ValidationException]: + assert tracked_fakes is not None + try: + # This produce_guards call is a best-effort replication, since we + # don't populate EqualityConstraint list. Reason: we would also have + # to save OutputGraph.tracked_fakes_id_to_source. + shape_env.produce_guards( + [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes], + [a.source for a in tracked_fakes], + input_contexts=[a.symbolic_context for a in tracked_fakes], + ) + return None + except ValidationException as e: + return e + + # Checks whether the ShapeEnv reconstructed by replaying the events until + # node is created fails when produce_guards is called. + def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]: + number = node.meta[SHAPEENV_EVENT_KEY] + # Reconstruct shape_env until the event at event_number. + shape_env = replay_shape_env_events(events[:number + 1]) + shape_env.graph.lint() + return check_shapeenv_fails(shape_env, events[number].tracked_fakes) + + last_exception = check_shapeenv_fails(shape_env, shape_env._snapshot_tracked_fakes()) + + if not last_exception: + # We don't actually fail due to a produce_guards call. + # Stop and don't bisect. + log.info("translation validation succeeded: no errors found.") + return + + if not shape_env.should_record_events or config.translation_validation_no_bisect: + # Bisection is off. + # Return the last ValidationException we got. + raise last_exception + + # Cache the raised exception (if any) at each bisection point. + exception = {} + + # Bisection happens on the assertion nodes of the recorded FX graph for + # dynamic shapes. + assert_nodes = [node for node in shape_env.graph.nodes if node.target == torch._assert] + + # Preparing the indices for binary search. + left, mid, right = 0, 0, len(assert_nodes) - 1 + + while left < right: + mid = (left + right) // 2 + + node = assert_nodes[mid] + log.debug("bisecting at %s: %s", mid, get_node_event(node)) + + # Check whether the new shape_env raises a ValidationException or not. + exception[mid] = check_node_fails(node) + + if exception[mid]: + right = mid + else: + left = mid + 1 + + assert left in exception and isinstance(exception[left], ValidationException) + + node = assert_nodes[left] + event = get_node_event(node) + + if event.is_evaluate_expr(): + failed_action = "evaluating" + else: + assert event.is_defer_runtime_assert(), f"unexpected event type: {event}" + failed_action = "adding runtime assert" + + args = event.args + assert args is not None + assert len(args) >= 2, ( + f"bisecting expects {event.name} to have at least 2 positional arguments. " + f"Got: {len(args)}" + ) + assert isinstance(args[1], sympy.Basic), ( + f"bisecting expects {event.name} to have a SymPy expression as its second argument. " + f"Got: {type(args[1])}" + ) + + raise BisectValidationException( + exception[left], + expr=args[1], + failed_action=failed_action, + traced_node=node.meta[CURRENT_NODE_KEY], + ) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/graph.py b/.venv/lib/python3.11/site-packages/torch/fx/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..e4fdd79fcbb280236c0dd139fa2baffa5c53a061 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/graph.py @@ -0,0 +1,1796 @@ +# mypy: allow-untyped-defs +from collections import defaultdict +from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name +import torch.utils._pytree as pytree +from . import _pytree as fx_pytree +from ._compatibility import compatibility +from torch._C import _NodeIter + +import os +import contextlib +from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type, Iterable +from dataclasses import dataclass +from contextlib import contextmanager +import copy +import enum +import torch +import keyword +import re +import builtins +import math +import warnings +import inspect + +__all__ = ["PythonCode", "CodeGen", "Graph"] + +if TYPE_CHECKING: + from .graph_module import GraphModule # noqa: F401 + from ._symbolic_trace import Tracer # noqa: F401 + + +# Mapping of builtins to their `typing` equivalent. +_origin_type_map = { + list: List, + dict: Dict, + set: Set, + frozenset: FrozenSet, + tuple: Tuple, +} + + +# Signature for functions thattransforms the body (`list[str]`) of the +# generated code +TransformCodeFunc = Callable[[List[str]], List[str]] + + +class _CustomBuiltin(NamedTuple): + """Additional objs that we add to every graph's globals. + + The repr() for some standard library objects is not valid Python code without + an import. For common objects of this sort, we bundle them in the globals of + every FX graph. + """ + # How to import this object from the standard library. + import_str: str + # The actual object, produced from that import string. + obj: Any + +_custom_builtins: Dict[str, _CustomBuiltin] = {} + + +def _register_custom_builtin(name: str, import_str: str, obj: Any): + _custom_builtins[name] = _CustomBuiltin(import_str, obj) + + +_register_custom_builtin('inf', 'from math import inf', math.inf) +_register_custom_builtin('nan', 'from math import nan', math.nan) +_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None)) +_register_custom_builtin('torch', 'import torch', torch) +_register_custom_builtin('device', 'from torch import device', torch.device) +_register_custom_builtin('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree) +_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree) + + +def _is_magic(x: str) -> bool: + return x.startswith('__') and x.endswith('__') + + +def _snake_case(s: str) -> str: + """ + Transforms the given string ``s`` to a Python-style variable name + + Examples: + ``mod.snake_case`` -> ``mod.snake_case`` + ``mod.pascalCase``-> ``mod.pascal_case`` + ``mod.ALL_CAPS`` -> ``mod.all_caps`` + """ + chars = [] + prev_lower = False + for c in s: + if prev_lower and c.isupper(): + chars.append('_') + chars.append(c.lower()) + prev_lower = c.islower() + return ''.join(chars) + + +def _is_from_torch(obj: Any) -> bool: + module_name = getattr(obj, '__module__', None) + if module_name is not None: + base_module = module_name.partition('.')[0] + return ( + base_module == 'torch' and + not module_name.startswith("torch._dynamo.") and + not module_name.startswith("torch._inductor.") + ) + + name = getattr(obj, '__name__', None) + # exclude torch because torch.torch.torch.torch works. idk mang + if name is not None and name != 'torch': + for guess in [torch, torch.nn.functional]: + if getattr(guess, name, None) is obj: + return True + + return False + + +class _Namespace: + """A context for associating names uniquely with objects. + + The following invariants are enforced: + - Each object gets a single name. + - Each name is unique within a given namespace. + - Names generated do not shadow builtins, unless the object is indeed that builtin. + """ + def __init__(self): + self._obj_to_name: Dict[Any, str] = {} + self._unassociated_names = set() + self._used_names: Set[str] = set() + self._base_count: Dict[str, int] = defaultdict(int) + + self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+') + self._name_suffix_regex = re.compile(r"(.*)_(\d+)$") + + def create_name(self, candidate: str, obj: Optional[Any]) -> str: + """Create a unique name. + + Arguments: + candidate: used as the basis for the unique name, relevant to the user. + obj: If not None, an object that will be associated with the unique name. + """ + if obj is not None and obj in self._obj_to_name: + return self._obj_to_name[obj] + + # delete all characters that are illegal in a Python identifier + candidate = self._illegal_char_regex.sub('_', candidate) + + if not candidate: + candidate = '_unnamed' + + if candidate[0].isdigit(): + candidate = f'_{candidate}' + + match = self._name_suffix_regex.match(candidate) + if match is None: + base = candidate + num = None + else: + base, num_str = match.group(1, 2) + num = int(num_str) + + candidate = base if num is None else f'{base}_{num}' + if not num: + num = self._base_count[base] + + while candidate in self._used_names or self._is_illegal_name(candidate, obj): + num += 1 + candidate = f'{base}_{num}' + + self._used_names.add(candidate) + self._base_count[base] = num + if obj is None: + self._unassociated_names.add(candidate) + else: + self._obj_to_name[obj] = candidate + return candidate + + def associate_name_with_obj(self, name: str, obj: Any): + """Associate a unique name with an object. + + Neither `name` nor `obj` should be associated already. + """ + assert obj not in self._obj_to_name + assert name in self._unassociated_names + self._obj_to_name[obj] = name + self._unassociated_names.remove(name) + + def _is_illegal_name(self, name: str, obj: Any) -> bool: + # 1. keywords are never allowed as names. + if name in keyword.kwlist: + return True + + # 2. Can't shadow a builtin name, unless you *are* that builtin. + if name in builtins.__dict__: + return obj is not builtins.__dict__[name] + + # 3. Can't shadow our custom builtins either + if name in _custom_builtins: + return obj is not _custom_builtins[name].obj + + return False + + def _rename_object(self, obj: Any, name: str): + assert obj in self._obj_to_name + self._obj_to_name[obj] = name + self._used_names.add(name) + +dtype_abbrs = { + torch.bfloat16: 'bf16', + torch.float64: 'f64', + torch.float32: 'f32', + torch.float16: 'f16', + torch.float8_e4m3fn: 'f8e4m3fn', + torch.float8_e5m2: 'f8e5m2', + torch.float8_e4m3fnuz: 'f8e4m3fnuz', + torch.float8_e5m2fnuz: 'f8e5m2fnuz', + torch.complex32: 'c32', + torch.complex64: 'c64', + torch.complex128: 'c128', + torch.int8: 'i8', + torch.int16: 'i16', + torch.int32: 'i32', + torch.int64: 'i64', + torch.bool: 'b8', + torch.uint8: 'u8', + torch.uint16: 'u16', + torch.uint32: 'u32', + torch.uint64: 'u64', + torch.bits16: 'b16', +} + +@compatibility(is_backward_compatible=True) +@dataclass +class PythonCode: + """ + Represents all the information necessary to exec or save a graph as Python code. + """ + # Python source code for the forward function definition. + src: str + # Values in global scope during execution of `src_def`. + globals: Dict[str, Any] + # Optional mapping from the forward function's line number to + # node index. + _lineno_map: Optional[Dict[int, Optional[int]]] + + +def _format_target(base: str, target: str) -> str: + elems = target.split('.') + r = base + for e in elems: + if not e.isidentifier(): + r = f'getattr({r}, "{e}")' + else: + r = f'{r}.{e}' + return r + +class _InsertPoint: + def __init__(self, graph, new_insert): + self.graph = graph + self.orig_insert, graph._insert = graph._insert, new_insert + + def __enter__(self): + pass + + def __exit__(self, type, value, tb): + self.graph._insert = self.orig_insert + +class _node_list: + def __init__(self, graph: 'Graph', direction: str = '_next'): + assert direction in ['_next', '_prev'] + self.graph = graph + self.direction = direction + + def __len__(self): + return self.graph._len + + def __iter__(self): + assert self.direction == "_prev" or self.direction == "_next" + yield from _NodeIter(self.graph._root, self.direction == "_prev") + + def __reversed__(self): + return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') + +class _PyTreeInfo(NamedTuple): + """ + Contains extra info stored when we're using Pytrees + """ + orig_args: List[str] + in_spec: pytree.TreeSpec + out_spec: Optional[pytree.TreeSpec] + +@dataclass(frozen=True) +class _ParsedStackTrace: + """ + Represents the top-most frame of a parsed stack trace + """ + file: str + lineno: str + name: str + code: str + + def get_summary_str(self): + return f'File: {self.file}:{self.lineno} in {self.name}, code: {self.code}' + +# get File:lineno code from stack_trace +def _parse_stack_trace(stack_trace: str): + if stack_trace is None: + return None + pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") + lines = stack_trace.strip().split('\n') + # stacktrace should have innermost frame last, so we + # iterate backwards to find the first line that starts + # with 'File ' + summary_str = "" + for idx in range(len(lines) - 2, -1, -1): + line = lines[idx].strip() + matches = pattern.match(line) + if matches: + file = matches.group(1) + lineno = matches.group(2) + name = matches.group(3) + # next line should be the code + code = lines[idx + 1].strip() + return _ParsedStackTrace(file, lineno, name, code) + return None + +@compatibility(is_backward_compatible=False) +class CodeGen: + def __init__(self): + self._body_transformer: Optional[TransformCodeFunc] = None + self._func_name: str = "forward" + + def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str: + """ + Given the free variables and a return annotation, generates the beginning of the FX function. + By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'` + """ + # If the original function didn't have self as its first argument, we + # would have added it. + if len(free_vars) == 0 or free_vars[0] != 'self': + free_vars.insert(0, 'self') + return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" + + def generate_output(self, output_args: Argument) -> str: + """ + Given the output arguments, generates the return statement of the FX function. + Note: The returned statement should not be indented. + """ + return f'return {repr(output_args)}' + + def process_inputs(self, *args: Any) -> Any: + """ + Transforms the inputs so that the graph can take them as arguments, as + non-default codegen may result in the inputs to the function being + different from the inputs to the graph. + + If the graph was directly runnable, this invariant should hold true + `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)` + """ + return args + + def process_outputs(self, outputs: Any) -> Any: + """ + Transforms the outputs of the graph to be identical to the codegen. + + See ``process_inputs`` for more details. + """ + return outputs + + def additional_globals(self) -> List[Tuple[str, Any]]: + """ + If your codegen uses extra global values, add tuples of (identifier,reference to the value) here. + For example, return ['List', typing.List] if you need ``List`` in the global context. + """ + return [] + + def _gen_python_code( + self, nodes, root_module: str, namespace: _Namespace, *, + verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False + ) -> PythonCode: + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation : List[str] = [''] + include_stride = include_stride or (os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1") + include_device = include_device or (os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1") + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + + We call this for names that reference objects external to the + Graph, like functions or types. + + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj + return global_name + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o : Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return '()' + + typename = _type_repr(o) + + if hasattr(o, '__origin__'): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) + + if hasattr(o, '__args__'): + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] + + if len(args) == 0: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python < 3.9 + return origin_typename + + return f'{origin_typename}[{",".join(args)}]' + else: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + codes = { + "yellow": "\033[33m", + "cyan": "\033[36m", + "green": "\033[32m", + "blue": "\033[34m", + "red": "\033[31m", + "dim": "\033[2m", + "dim_blue": "\033[2m\033[34m", + "dim_green": "\033[2m\033[32m", + "reset": "\033[0m", + } + + def make_wrapper_func(name): + def f(s): + if colored: + return f"{codes[name]}{s}{codes['reset']}" + return s + return f + + yellow = make_wrapper_func("yellow") + cyan = make_wrapper_func("cyan") + red = make_wrapper_func("red") + green = make_wrapper_func("green") + dim_green = make_wrapper_func("dim_green") + dim = make_wrapper_func("dim") + dim_blue = make_wrapper_func("dim_blue") + blue = make_wrapper_func("blue") + + def _get_repr(arg: Any) -> str: + # Handle NamedTuples (if it has `_fields`) via add_global. + if isinstance(arg, tuple) and hasattr(arg, '_fields'): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + qualified_name = _get_qualified_name(arg) + global_name = add_global(qualified_name, arg) + return f"{global_name}" + elif isinstance(arg, enum.Enum): + cls = arg.__class__ + clsname = add_global(cls.__name__, cls) + return f"{clsname}.{arg.name}" + elif isinstance(arg, Node): + return repr(arg) + elif isinstance(arg, torch.Tensor): + size = list(arg.size()) + dtype = str(arg.dtype).split(".")[-1] + return f"torch.Tensor(size={size}, dtype={dtype})" + else: + return blue(repr(arg)) + + + def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: + args_s = ', '.join(_get_repr(a) for a in args) + kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + if args_s and kwargs_s: + return f'{args_s}, {kwargs_s}' + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use : Dict[Node, Node] = {} + user_to_last_uses : Dict[Node, List[Node]] = {} + + def register_last_uses(n : Node, user : Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + def delete_unused_values(user : Node): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == 'placeholder': + return + if user.op == 'output': + body.append('\n') + return + nodes_to_delete = user_to_last_uses.get(user, []) + + if len(user.users.keys()) == 0: + # This node is not used by any others. however it's also not + # removed by DCE since side-effect. We want to free it's outputs + # right after its execution done to save memory. + nodes_to_delete.append(user) + + if len(nodes_to_delete): + to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) + body.append(f'; {dim(to_delete_str)}\n') + else: + body.append('\n') + + prev_stacktrace = None + + def append_stacktrace_summary(node : Node): + """ + Append a summary of the stacktrace to the generated code. This is + useful for debugging. + """ + nonlocal prev_stacktrace + + if node.op not in {'placeholder', 'output'}: + if node.stack_trace: + if node.stack_trace != prev_stacktrace: + prev_stacktrace = node.stack_trace + summary_str = "" + + if parsed_stack_trace := _parse_stack_trace(node.stack_trace): + summary_str = parsed_stack_trace.get_summary_str() + + body.append(f'\n {dim("# " + summary_str)}\n') + elif prev_stacktrace != "": + prev_stacktrace = "" + no_stacktrace_msg = "# No stacktrace found for following nodes" + body.append(f'\n{dim(no_stacktrace_msg)}\n') + + def stringify_shape(shape : Iterable) -> str: + return f"[{', '.join(str(x) for x in shape)}]" + + def emit_node(node : Node): + maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + + if verbose: + # override annotation with more detailed information + from torch.fx.experimental.proxy_tensor import py_sym_types + from torch.fx.passes.shape_prop import TensorMetadata + + meta_val = node.meta.get('val', node.meta.get('tensor_meta', node.meta.get('example_value', None))) + # use string as annotation, to make it valid python code + + if isinstance(meta_val, torch.Tensor): + stride_annotation = f"{stringify_shape(meta_val.stride())}" if include_stride else "" + device_annotation = f"{meta_val.device}" if include_device else "" + maybe_type_annotation = \ + f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' \ + f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"' + elif isinstance(meta_val, py_sym_types): + maybe_type_annotation = f': "Sym({meta_val})"' + elif isinstance(meta_val, TensorMetadata): + maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' + + if node.op == 'placeholder': + assert isinstance(node.target, str) + maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}' + free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') + raw_name = node.target.replace('*', '') + if raw_name != repr(node): + body.append(f'{repr(node)} = {raw_name}\n') + return + elif node.op == 'call_method': + assert isinstance(node.target, str) + body.append( + f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}' + f'({_format_args(node.args[1:], node.kwargs)})') + return + elif node.op == 'call_function': + assert callable(node.target) + # pretty print operators + if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods: + assert isinstance(node.args, tuple) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}') + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods: + body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; ' + f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}') + return + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if global_name == 'getattr' and \ + isinstance(node.args, tuple) and \ + isinstance(node.args[1], str) and \ + node.args[1].isidentifier() and \ + len(node.args) == 2: + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}') + return + body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') + if node.meta.get('is_wrapped', False): + wrapped_fns.setdefault(global_name) + return + elif node.op == 'call_module': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = ' + f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + return + elif node.op == 'get_attr': + assert isinstance(node.target, str) + body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + return + elif node.op == 'output': + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f'node: {node.op} {node.target}') + + for i, node in enumerate(nodes): + # NOTE: emit_node does not emit a string with newline. It depends + # on delete_unused_values to append one + if verbose: + append_stacktrace_summary(node) + # emit a counter comment to keep track of + # node index, which will be deleted later + # after going through _body_transformer + body.append(f"# COUNTER: {i}\n") + emit_node(node) + delete_unused_values(node) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append('pass\n') + + + + if len(wrapped_fns) > 0: + wrap_name = add_global('wrap', torch.fx.wrap) + wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = '' + + if self._body_transformer: + body = self._body_transformer(body) + + for name, value in self.additional_globals(): + add_global(name, value) + + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + + # remove counter and generate lineno to node index mapping + lineno_map: Dict[int, Optional[int]] = {} + prologue_len = prologue.count('\n') + 1 + new_lines: List[str] = [] + cur_idx = None + for line in ''.join(body).split('\n'): + counter = re.search(r"# COUNTER: (\d+)", line) + if counter and counter.group(1) is not None: + cur_idx = int(counter.group(1)) + else: + lineno_map[len(new_lines) + prologue_len] = cur_idx + new_lines.append(line) + + code = "\n".join(new_lines).lstrip('\n') + code = '\n'.join(' ' + line for line in code.split('\n')) + + fn_code = f""" +{wrap_stmts} + +{prologue} +{code}""" + return PythonCode(fn_code, globals_, _lineno_map=lineno_map) + + +# Ideally, we'd like to refactor all of the pytree logic into this codegen +# class. Unfortunately, there are 3 areas we currently need extra logic in FX. +# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`. +# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec. +# Since we can't access .graph within the FX forward, we need to copy the attribute to the module. +# 3. We currently can't register the pytree imports with `add_global` - not sure why. +class _PyTreeCodeGen(CodeGen): + def __init__(self, pytree_info: _PyTreeInfo): + super().__init__() + self.pytree_info: _PyTreeInfo = pytree_info + + def process_inputs(self, *inputs: Any) -> Any: + flat_args = pytree.arg_tree_leaves(*inputs) + return flat_args + + def process_outputs(self, out: Any) -> Any: + if self.pytree_info is None or self.pytree_info.out_spec is None: + return out + if not isinstance(out, (list, tuple)): + out = [out] + assert self.pytree_info.out_spec is not None + return pytree.tree_unflatten(out, self.pytree_info.out_spec) + + def gen_fn_def(self, free_vars, maybe_return_annotation): + # Given a user function/model: + # myargs = [myargs0, myargs1] + # mykwargs = {'mykwargs0': ..., 'mykwargs1': ...} + # def forward(self, mypos, *myargs, mykey=None, **mykwargs): + # + # The generated code flattens all keywords into positional arguments for `forward()` + # e.g forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1): + # + # Within `forward`, `tree_flatten_spec``still parses args and kwargs separately + # e.g. tree_flatten_spec(([mypos, myargs0, myargs1], + # {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}), + # self._in_spec) + # + # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec + # e.g. tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec) + if self.pytree_info is None: + return super().gen_fn_def(free_vars, maybe_return_annotation) + + fn_args = self.pytree_info.orig_args + has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False + if has_orig_self: + free_vars.insert(0, 'self') + fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation) + + if len(free_vars) > 0: # pytree has placeholders in it + # when kwargs is present, in_spec is tuple(args, kwargs) + has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \ + self.pytree_info.in_spec.num_children == 2 and \ + self.pytree_info.in_spec.children_specs[0].type == tuple and \ + self.pytree_info.in_spec.children_specs[1].type == dict + fn_kwargs = '{}' + fn_signature = f"[{', '.join(fn_args)}], self._in_spec" + if has_args_kwargs_tuple: + count_args = self.pytree_info.in_spec.children_specs[0].num_children + fn_args = self.pytree_info.orig_args[:count_args] + fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip( + self.pytree_info.in_spec.children_specs[1].context, + self.pytree_info.orig_args[count_args:])) + '}' + fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec" + + # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid. + # we need to split it to two lines: + # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon) + # one for code: `var1, var2, = function_call()` + without_annotation = [x.split(":")[0] for x in free_vars] + has_annotation = [x + "; " for x in free_vars if ":" in x] + if len(has_annotation) > 0: + fn_definition += "\n " + "".join(has_annotation) + "\n" + fn_definition += f""" + {', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" + return fn_definition + + def generate_output(self, output_args): + if self.pytree_info and self.pytree_info.out_spec: + return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)' + else: + return super().generate_output(output_args) + +class _FindNodesLookupTable: + """ + Side table for the graph for the purpose of doing fast queries + """ + def __init__(self): + self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict(dict) + + def _key(self, node) -> Tuple[str, Optional[Target]]: + return (node.op, node.target if node.op == "call_function" else None) + + def __contains__(self, node) -> bool: + return node in self.table[self._key(node)] + + def insert(self, node: Node) -> None: + self.table[self._key(node)][node] = None + + def remove(self, node: Node) -> None: + self.table[self._key(node)].pop(node) + + def find_nodes(self, *, op: str, target: Optional['Target'] = None): + if op == "call_function": + assert target is not None + return dict(self.table[(op, target)]).keys() + + if target is None: + return dict(self.table[(op, None)]).keys() + + # op is call_method, get_attr, call_module + return [node for node in self.table[(op, None)].keys() if node.target == target] + +@compatibility(is_backward_compatible=True) +class Graph: + """ + ``Graph`` is the main data structure used in the FX Intermediate Representation. + It consists of a series of ``Node`` s, each representing callsites (or other + syntactic constructs). The list of ``Node`` s, taken together, constitute a + valid Python function. + + For example, the following code + + .. code-block:: python + + import torch + import torch.fx + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) + + m = MyModule() + gm = torch.fx.symbolic_trace(m) + + Will produce the following Graph:: + + print(gm.graph) + + .. code-block:: text + + graph(x): + %linear_weight : [num_users=1] = self.linear.weight + %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) + %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) + %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) + %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) + %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) + return topk_1 + + For the semantics of operations represented in the ``Graph``, please see :class:`Node`. + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None, + tracer_extras: Optional[Dict[str, Any]] = None): + """ + Construct an empty Graph. + """ + self._root : Node = Node(self, '', 'root', '', (), {}) + self._used_names : Dict[str, int] = {} # base name -> number + self._insert = self._root.prepend + self._len = 0 + self._graph_namespace = _Namespace() + self._owning_module = owning_module + self._tracer_cls = tracer_cls + self._tracer_extras = tracer_extras + self._codegen = CodeGen() + self._co_fields : Dict[str, Any] = {} + self._find_nodes_lookup_table = _FindNodesLookupTable() + + @property + def owning_module(self): + return self._owning_module + + @owning_module.setter + def owning_module(self, mod: Optional["GraphModule"]): + self._owning_module = mod + + @property + def nodes(self) -> _node_list: + """ + Get the list of Nodes that constitute this Graph. + + Note that this ``Node`` list representation is a doubly-linked list. Mutations + during iteration (e.g. delete a Node, add a Node) are safe. + + Returns: + + A doubly-linked list of Nodes. Note that ``reversed`` can be called on + this list to switch iteration order. + """ + return _node_list(self) + + @compatibility(is_backward_compatible=False) + def find_nodes(self, *, op: str, target: Optional['Target'] = None, sort: bool = True): + """ + Allows for fast query of nodes + + Args: + + op (str): the name of the operation + + target (Optional[Target]): the target of the node. For call_function, + the target is required. For other ops, the target is optional. + + sort (bool): whether to return nodes in the order they appear on + on the graph. + + Returns: + + Iteratable of nodes with the requested op and target. + """ + node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target) + if sort: + return sorted(node_list) + return node_list + + @compatibility(is_backward_compatible=True) + def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]': + """ + Copy all nodes from a given graph into ``self``. + + Args: + + g (Graph): The source graph from which to copy Nodes. + + val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping + from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed + in with values in it already to override copying of certain values. + + Returns: + + The value in ``self`` that is now equivalent to the output value in ``g``, + if ``g`` had an ``output`` node. ``None`` otherwise. + """ + for node in g.nodes: + if node in val_map: + continue + if node.op == 'output': + rv = map_arg(node.args[0], lambda n: val_map[n]) + return rv if not return_output_node else (rv, node) + val_map[node] = self.node_copy(node, lambda n : val_map[n]) + return None + + def __deepcopy__(self, memo=None) -> 'Graph': + """ + Explicitly implement __deepcopy__ to prevent excessive recursion depth + from the default implementation. This uses graph_copy to copy the nodes + in an iterative way, rather than recursive. It also populates the + memoization table to prevent unnecessary copies (e.g. references to + nodes or other parts of the Graph from a custom GraphModule implementation. + """ + memo = memo if memo else {} + g = Graph(tracer_cls=self._tracer_cls) + output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) + g._codegen = copy.deepcopy(self._codegen) + assert isinstance(output_vals, tuple) + output_val, old_output_node = output_vals + new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None)) + new_output_node.meta = copy.copy(old_output_node.meta) + return g + + @compatibility(is_backward_compatible=True) + def create_node(self, op: str, target: 'Target', + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + name: Optional[str] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Create a ``Node`` and add it to the ``Graph`` at the current insert-point. + Note that the current insert-point can be set via :meth:`Graph.inserting_before` + and :meth:`Graph.inserting_after`. + + Args: + op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', + 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are + described in the ``Graph`` docstring. + + args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. + + kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node + + name (Optional[str]): an optional string name for the ``Node``. + This will influence the name of the value assigned to in the + Python generated code. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly-created and inserted node. + """ + assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') + args = () if args is None else args + kwargs = {} if kwargs is None else kwargs + assert isinstance(args, tuple), "args must be a tuple" + assert isinstance(kwargs, dict), "kwargs must be a dict" + + candidate = name if name is not None else self._target_to_str(target) + name = self._graph_namespace.create_name(candidate, None) + n = Node(self, name, op, target, args, kwargs, type_expr) + + if self.owning_module is not None and getattr(self.owning_module, "_create_node_hooks", None) is not None: + for f in self.owning_module._create_node_hooks: + f(n) + + self._graph_namespace.associate_name_with_obj(name, n) + + self._insert(n) + self._find_nodes_lookup_table.insert(n) + self._len += 1 + return n + + @compatibility(is_backward_compatible=False) + def process_inputs(self, *args): + """ + Processes args so that they can be passed to the FX graph. + """ + return self._codegen.process_inputs(*args) + + @compatibility(is_backward_compatible=False) + def process_outputs(self, out): + return self._codegen.process_outputs(out) + + + @compatibility(is_backward_compatible=True) + def erase_node(self, to_erase : Node) -> None: + """ + Erases a ``Node`` from the ``Graph``. Throws an exception if + there are still users of that node in the ``Graph``. + + Args: + + to_erase (Node): The ``Node`` to erase from the ``Graph``. + """ + if len(to_erase.users) > 0: + raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' + f'users in the graph: {to_erase.users}!') + if to_erase.graph != self: + raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!") + if to_erase._erased: + warnings.warn(f"erase_node({to_erase}) on an already erased node") + return + + if self.owning_module is not None and getattr(self.owning_module, "_erase_node_hooks", None) is not None: + for f in self.owning_module._erase_node_hooks: + f(to_erase) + + self._find_nodes_lookup_table.remove(to_erase) + to_erase._remove_from_list() + to_erase._erased = True # iterators may retain handles to erased nodes + self._len -= 1 + + # Null out this Node's argument nodes so that the Nodes referred to + # can update their ``users`` accordingly + new_args = map_arg(to_erase.args, lambda n: None) + assert isinstance(new_args, tuple) + to_erase.args = new_args + new_kwargs = map_arg(to_erase.kwargs, lambda n: None) + assert isinstance(new_kwargs, dict) + to_erase.kwargs = new_kwargs + + @compatibility(is_backward_compatible=True) + def inserting_before(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. + When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_before(n): + ... # inserting before node n + ... # insert point restored to what it was previously + g.inserting_before(n) # set the insert point permanently + + Args: + + n (Optional[Node]): The node before which to insert. If None this will insert before + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_after(self._root) + assert n.graph == self, "Node to insert before is not in graph." + return _InsertPoint(self, n.prepend) + + @compatibility(is_backward_compatible=True) + def inserting_after(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. + When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_after(n): + ... # inserting after node n + ... # insert point restored to what it was previously + g.inserting_after(n) # set the insert point permanently + + Args: + + n (Optional[Node]): The node before which to insert. If None this will insert after + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_before(self._root) + assert n.graph == self, "Node to insert after is not in graph." + return _InsertPoint(self, n.append) + + @compatibility(is_backward_compatible=True) + def placeholder(self, name: str, type_expr: Optional[Any] = None, + default_value : Any = inspect.Signature.empty) -> Node: + """ + Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents + a function input. + + Args: + + name (str): A name for the input value. This corresponds to the name + of the positional argument to the function this ``Graph`` represents. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. This is needed in some + cases for proper code generation (e.g. when the function is used + subsequently in TorchScript compilation). + + default_value (Any): The default value this function argument should take + on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty` + should be passed as this argument to specify that the parameter does _not_ + have a default value. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + args = () if default_value is inspect.Signature.empty else (default_value,) + return self.create_node('placeholder', name, args=args, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the + fetch of an attribute from the ``Module`` hierarchy. + + Args: + + qualified_name (str): the fully-qualified name of the attribute to be retrieved. + For example, if the traced Module has a submodule named ``foo``, which has a + submodule named ``bar``, which has an attribute named ``baz``, the qualified + name ``foo.bar.baz`` should be passed as ``qualified_name``. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + + Returns: + + The newly-created and inserted ``get_attr`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool: + module_path, _, name = qualified_name.rpartition(".") + + try: + submod: torch.nn.Module = mod.get_submodule(module_path) + except AttributeError: + warnings.warn(f"Failed to fetch module {module_path}!") + return False + + if not hasattr(submod, name): + return False + + res = getattr(submod, name) + + if (not isinstance(res, torch.nn.Module) + and not isinstance(res, torch.nn.Parameter) + and name not in submod._buffers): + return False + + return True + + if (self.owning_module and + not _get_attr_reference_exists(self.owning_module, qualified_name)): + warnings.warn("Attempted to insert a get_attr Node with no " + "underlying reference in the owning " + "GraphModule! Call " + "GraphModule.add_submodule to add the " + "necessary submodule, " + "GraphModule.add_parameter to add the " + "necessary Parameter, or " + "nn.Module.register_buffer to add the " + "necessary buffer", stacklevel=2) + return self.create_node('get_attr', qualified_name, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def call_module(self, + module_name: str, + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node + represents a call to the forward() function of a ``Module`` in the ``Module`` + hierarchy. + + Args: + + module_name (str): The qualified name of the ``Module`` in the ``Module`` + hierarchy to be called. For example, if the traced ``Module`` has a + submodule named ``foo``, which has a submodule named ``bar``, the + qualified name ``foo.bar`` should be passed as ``module_name`` to + call that module. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this should *not* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly-created and inserted ``call_module`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + if (self.owning_module and + self.owning_module.get_submodule(module_name) is None): + warnings.warn("Attempted to insert a call_module Node with " + "no underlying reference in the owning " + "GraphModule! Call " + "GraphModule.add_submodule to add the " + "necessary submodule") + return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def call_method(self, + method_name: str, + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node + represents a call to a given method on the 0th element of ``args``. + + Args: + + method_name (str): The name of the method to apply to the self argument. + For example, if args[0] is a ``Node`` representing a ``Tensor``, + then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this *should* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly created and inserted ``call_method`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def call_function(self, + the_function: Callable[..., Any], + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, + type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node + represents a call to a Python callable, specified by ``the_function``. + + Args: + + the_function (Callable[..., Any]): The function to be called. Can be any PyTorch + operator, Python function, or member of the ``builtins`` or ``operator`` + namespaces. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called function. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called function + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly created and inserted ``call_function`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: + """ + Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from + the graph of node to the graph of self. Example:: + + # Copying all the nodes in `g` into `new_graph` + g : torch.fx.Graph = ... + new_graph = torch.fx.graph() + value_remap = {} + for node in g.nodes: + value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) + + Args: + + node (Node): The node to copy into ``self``. + + arg_transform (Callable[[Node], Argument]): A function that transforms + ``Node`` arguments in node's ``args`` and ``kwargs`` into the + equivalent argument in ``self``. In the simplest case, this should + retrieve a value out of a table mapping Nodes in the original + graph to ``self``. + """ + args = map_arg(node.args, arg_transform) + kwargs = map_arg(node.kwargs, arg_transform) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type) + result_node.meta = copy.copy(node.meta) + return result_node + + @compatibility(is_backward_compatible=True) + def output(self, result: 'Argument', type_expr: Optional[Any] = None): + """ + Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents + a ``return`` statement in Python code. ``result`` is the value that should + be returned. + + Args: + + result (Argument): The value to be returned. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + .. note:: + + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) + + def _target_to_str(self, target : Target) -> str: + if callable(target): + op = target.__name__ + else: + assert isinstance(target, str) + op = target + if _is_magic(op): + op = op[2:-2] + op = _snake_case(op) + return op + + @compatibility(is_backward_compatible=True) + def python_code( + self, root_module: str, *, + verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False + ) -> PythonCode: + """ + Turn this ``Graph`` into valid Python code. + + Args: + + root_module (str): The name of the root module on which to look-up + qualified name targets. This is usually 'self'. + + Returns: + + A PythonCode object, consisting of two fields: + src: the Python source code representing the object + globals: a dictionary of global names in `src` -> the objects that they reference. + """ + # NOTE: [Graph Namespaces] + # + # There are two types of symbols in generated Python source code: + # locals and globals. + # Locals are locally defined by the output of a node in the Graph. + # Globals are references to external objects, like functions or types. + # + # When generating Python code, we need to make sure to name things + # appropriately. In particular: + # - All names should be unique, to avoid weird shadowing bugs. + # - These names need to be consistent, e.g. a object should always be + # referenced by the same name. + # + # To do this, we create a new namespace just for this source. All names + # that get printed must come from this namespace. + # + # Why can't we re-use node.name? Because it was generated within the + # namespace `self._graph_namespace`. In order to provide uniqueness + # over both locals (node.name) *and* globals, we create a completely + # new namespace to put all identifiers in. + namespace = _Namespace() + + # Override Node's repr to generate a valid name within our namespace. + # Since repr() is designed to produce a valid Python expression, it + # makes sense to re-use it. This way, it's easy to print something like + # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is + # implemented cooperatively to allow this. + def node_repr(n: Node): + return namespace.create_name(n.name, n) + + @contextmanager + def override_node_repr(graph: Graph): + orig_repr_fns = {} + for node in graph.nodes: + orig_repr_fns[node] = node._repr_fn + node._repr_fn = node_repr + try: + yield None + finally: + # restore the original repr functions + for node in graph.nodes: + node._repr_fn = orig_repr_fns[node] + + with override_node_repr(self): + return self._python_code( + root_module, namespace, + verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored + ) + + def _python_code( + self, root_module: str, namespace: _Namespace, *, + verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, + ) -> PythonCode: + return self._codegen._gen_python_code( + self.nodes, root_module, namespace, + verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored + ) + + + def __str__(self) -> str: + """ + Return a human-readable (not machine-readable) string representation + of this Graph + """ + placeholder_names : List[str] = [] + # This is a one-element array just so ``format_node`` can modify the closed + # over value + maybe_return_typename : List[str] = [''] + + node_strs = [node.format_node(placeholder_names) for node in self.nodes] + param_str = ', '.join(placeholder_names) + s = f'graph({param_str}){maybe_return_typename[0]}:' + for node_str in node_strs: + if node_str: + s += '\n ' + node_str + return s + + @compatibility(is_backward_compatible=True) + def print_tabular(self): + """ + Prints the intermediate representation of the graph in tabular + format. Note that this API requires the ``tabulate`` module to be + installed. + """ + try: + from tabulate import tabulate + except ImportError: + print("`print_tabular` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library.") + raise + + node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] + for n in self.nodes] + print(tabulate(node_specs, + headers=['opcode', 'name', 'target', 'args', 'kwargs'])) + + @compatibility(is_backward_compatible=True) + def lint(self): + """ + Runs various checks on this Graph to make sure it is well-formed. In + particular: + - Checks Nodes have correct ownership (owned by this graph) + - Checks Nodes appear in topological order + - If this Graph has an owning GraphModule, checks that targets + exist in that GraphModule + """ + + # Check topo order + def check_arg(arg : Node, n : Optional[Node] = None) -> None: + context_str = f' of Node \'{n}\' ' if n else ' ' + if arg.graph is not self: + raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, ' + f'but was used as an argument! If you are copying nodes from another graph, make ' + f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}') + if arg not in seen_values: + raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been ' + f'defined! Please check that Nodes in the graph are topologically ordered\n{self}') + + seen_names : Set[str] = set() + seen_values : Set[Node] = set() + for node in self.nodes: + if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']: + raise RuntimeError(f'Node {node} had unknown opcode {node.op}!') + if node.graph is not self: + raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') + if node not in self._find_nodes_lookup_table: + raise RuntimeError(f"Node '{node}' is not added to the side table") + map_arg(node.args, lambda arg: check_arg(arg, node)) + map_arg(node.kwargs, lambda arg: check_arg(arg, node)) + seen_values.add(node) + + if node.name in seen_names: + raise RuntimeError(f'Node redefined name {node.name}!') + seen_names.add(node.name) + + # Check targets are legit + if self.owning_module: + num_warnings = 0 + MAX_WARNINGS = 5 + for node in self.nodes: + if node.op == 'call_function': + if not callable(node.target): + raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' + 'a Callable is expected') + else: + if not isinstance(node.target, str): + raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' + 'a str is expected') + if node.op in ['get_attr', 'call_module']: + target_atoms = node.target.split('.') + m_itr = self.owning_module + for i, atom in enumerate(target_atoms): + new_m_itr = getattr(m_itr, atom, None) + seen_qualname = '.'.join(target_atoms[:i]) + if new_m_itr is None: + raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute ' + f'{atom} of {seen_qualname}') + if (node.op == "call_module" + and not isinstance(new_m_itr, torch.nn.Module)): + raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' + 'not reference an nn.Module') + elif (node.op == "get_attr" + and not isinstance(new_m_itr, torch.nn.Module) + and not isinstance(new_m_itr, torch.nn.Parameter) + and atom not in m_itr._buffers): + if num_warnings < MAX_WARNINGS: + # Don't emit this warning too frequently, + # for very large graphs this can become very expensive + # from a performance perspective. + warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' + 'not reference an nn.Module, nn.Parameter, or buffer, which is ' + 'what \'get_attr\' Nodes typically target') + num_warnings += 1 + else: + m_itr = new_m_itr + if num_warnings > MAX_WARNINGS: + warnings.warn( + f'Additional {num_warnings - MAX_WARNINGS} warnings ' + 'suppressed about get_attr references' + ) + + @compatibility(is_backward_compatible=True) + def eliminate_dead_code(self, is_impure_node: Optional[Callable[[Node], bool]] = None): + """ + Remove all dead code from the graph, based on each node's number of + users, and whether the nodes have any side effects. The graph must be + topologically sorted before calling. + + Args: + is_impure_node (Optional[Callable[[Node], bool]]): A function that returns + whether a node is impure. If this is None, then the default behavior is to + use Node.is_impure. + + Returns: + bool: Whether the graph was changed as a result of the pass. + + Example: + + Before dead code is eliminated, `a` from `a = x + 1` below has no users + and thus can be eliminated from the graph without having an effect. + + .. code-block:: python + + def forward(self, x): + a = x + 1 + return x + self.attr_1 + + After dead code is eliminated, `a = x + 1` has been removed, and the rest + of `forward` remains. + + .. code-block:: python + + def forward(self, x): + return x + self.attr_1 + + .. warning:: + + Dead code elimination has some heuristics to avoid removing + side-effectful nodes (see Node.is_impure) but in general coverage + is very bad, so you should assume that this method is not sound + to call unless you know that your FX graph consists entirely + of functional operations or you supply your own custom + function for detecting side-effectful nodes. + """ + # Lint the graph first to make sure its topologically sorted, otherwise + # DCE below will not behave as expected. + self.lint() + + def has_side_effect(node): + if is_impure_node is not None: + return is_impure_node(node) + return node.is_impure() + + # Reverse iterate so that when we remove a node, any nodes used as an + # input to that node have an updated user count that no longer reflects + # the removed node. + changed = False + for node in reversed(self.nodes): + if not has_side_effect(node) and len(node.users) == 0: + self.erase_node(node) + changed = True + + return changed + + @compatibility(is_backward_compatible=False) + def set_codegen(self, codegen: CodeGen): + self._codegen = codegen + + @compatibility(is_backward_compatible=False) + def on_generate_code( + self, + make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc] + ): + """Register a transformer function when python code is generated + + Args: + make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]): + a function that returns a code transformer to be registered. + This function is called by `on_generate_code` to obtain the + code transformer. + + This function is also given as its input the currently + registered code transformer (or None if nothing is registered), + in case it is not desirable to overwrite it. This is useful to + chain code transformers together. + + Returns: + a context manager that when used in a `with` statement, to automatically + restore the previously registered code transformer. + + Example: + + .. code-block:: python + + + gm: fx.GraphModule = ... + + # This is a code transformer we want to register. This code + # transformer prepends a pdb import and trace statement at the very + # beginning of the generated torch.fx code to allow for manual + # debugging with the PDB library. + def insert_pdb(body): + return ["import pdb; pdb.set_trace()\\n", *body] + + # Registers `insert_pdb`, and overwrites the current registered + # code transformer (given by `_` to the lambda): + gm.graph.on_generate_code( + lambda _: insert_pdb + ) + + # Or alternatively, registers a code transformer which first + # runs `body` through existing registered transformer, then + # through `insert_pdb`: + gm.graph.on_generate_code( + lambda current_trans: ( + lambda body: insert_pdb( + current_trans(body) if current_trans + else body + ) + ) + ) + + gm.recompile() + gm(*inputs) # drops into pdb + + + This function can also be used as a context manager, with the benefit to + automatically restores the previously registered code transformer: + + .. code-block:: python + + # ... continue from previous example + + with gm.graph.on_generate_code(lambda _: insert_pdb): + # do more stuff with `gm`... + gm.recompile() + gm(*inputs) # drops into pdb + + # now previous code transformer is restored (but `gm`'s code with pdb + # remains - that means you can run `gm` with pdb here too, until you + # run next `recompile()`). + """ + on_gen_code_old = self._codegen._body_transformer + self._codegen._body_transformer = make_transformer(on_gen_code_old) + + @contextlib.contextmanager + def on_generate_code_context_manager(): + try: + yield + finally: + self._codegen._body_transformer = on_gen_code_old + + return on_generate_code_context_manager() + + +reflectable_magic_methods = { + 'add': '{} + {}', + 'sub': '{} - {}', + 'mul': '{} * {}', + 'floordiv': '{} // {}', + 'truediv': '{} / {}', + 'div': '{} / {}', + 'mod': '{} % {}', + 'pow': '{} ** {}', + 'lshift': '{} << {}', + 'rshift': '{} >> {}', + 'and_': '{} & {}', + 'or_': '{} | {}', + 'xor': '{} ^ {}', + 'getitem': '{}[{}]', + 'matmul': '{} @ {}', +} + +magic_methods = dict({ + 'eq': '{} == {}', + 'ne': '{} != {}', + 'lt': '{} < {}', + 'gt': '{} > {}', + 'le': '{} <= {}', + 'ge': '{} >= {}', + 'pos': '+{}', + 'neg': '-{}', + 'invert': '~{}'}, **reflectable_magic_methods) + +inplace_methods = { + 'iadd': '{} += {}', + 'iand': '{} &= {}', + 'ifloordiv': '{} //= {}', + 'ilshift': '{} <<= {}', + 'imod': '{} %= {}', + 'imul': '{} *= {}', + 'imatmul': '{} @= {}', + 'ior': '{} |= {}', + 'ipow': '{} **= {}', + 'irshift': '{} >>= {}', + 'isub': '{} -= {}', + 'itruediv': '{} /= {}', + 'ixor': '{} ^= {}', + 'setitem': '{}[{}] = {}', +} diff --git a/.venv/lib/python3.11/site-packages/torch/fx/graph_module.py b/.venv/lib/python3.11/site-packages/torch/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..76dac29512bdccd7cb93a5afe8107e124f8961b7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/graph_module.py @@ -0,0 +1,955 @@ +# mypy: allow-untyped-defs +import contextlib +import copy +import itertools +import linecache +import os +import sys +import traceback +import warnings +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set, Type, Union + +import torch +import torch.nn as nn +import torch.overrides +from torch.nn.modules.module import _addindent +from torch.package import Importer, PackageExporter, PackageImporter, sys_importer + +from ._compatibility import compatibility +from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode + +__all__ = [ + "reduce_graph_module", + "reduce_package_graph_module", + "reduce_deploy_graph_module", + "GraphModule", +] + +_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes" + +# Normal exec loses the source code, however we can work with +# the linecache module to recover it. +# Using _exec_with_source will add it to our local cache +# and then tools like TorchScript will be able to get source info. +class _EvalCacheLoader: + def __init__(self): + self.eval_cache = {} + self.next_id = 0 + + def cache(self, src: str, globals: Dict[str, Any], co_fields=None): + """Store the source in a private cache, and add a lazy entry in linecache + that allows the source to be retrieved by 'filename'. + + Args: + src (str): The module source to cache + globals (dict): The module globals + + Returns: + str: The cache key (and dummy filename) generated for src. + """ + + key = self._get_key() + if co_fields: + key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" + self.eval_cache[key] = src + + # Don't mutate globals so that this loader is only used + # to populate linecache, and doesn't interact with other modules + # that might check `__loader__` + globals_copy = globals.copy() + globals_copy["__file__"] = key + globals_copy["__name__"] = key + globals_copy["__loader__"] = self + linecache.lazycache(key, globals_copy) + + return key + + # Part of the loader protocol (PEP 302) + # linecache will use this method when trying to find source code + def get_source(self, module_name) -> Optional[str]: + if module_name in self.eval_cache: + return self.eval_cache[module_name] + return None + + def _get_key(self): + key = f".{self.next_id}" + self.next_id += 1 + return key + + +_loader = _EvalCacheLoader() + + +def _exec_with_source(src: str, globals: Dict[str, Any], co_fields=None): + key = _loader.cache(src, globals, co_fields) + exec(compile(src, key, "exec"), globals) + + +def _forward_from_src(src: str, globals: Dict[str, Any], co_fields=None): + return _method_from_src( + method_name="forward", src=src, globals=globals, co_fields=co_fields + ) + + +def _method_from_src( + method_name: str, src: str, globals: Dict[str, Any], co_fields=None +) -> Callable: + # avoid mutating the passed in dict + globals_copy = globals.copy() + _exec_with_source(src, globals_copy, co_fields) + fn = globals_copy[method_name] + del globals_copy[method_name] + return fn + + +def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: + if name in _custom_builtins: + return _custom_builtins[name].import_str + if _is_from_torch(name): + return "import torch" + module_name, attr_name = importer.get_name(obj) + return f"from {module_name} import {attr_name} as {name}" + + +def _format_import_block(globals: Dict[str, Any], importer: Importer): + import_strs: Set[str] = {_format_import_statement(name, obj, importer) for name, obj in globals.items()} + # Sort the imports so we have a stable import block that allows us to + # hash the graph module and get a consistent key for use in a cache. + return "\n".join(sorted(import_strs)) + + +@compatibility(is_backward_compatible=True) +def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module: + # BC: attribute name was changed from `code` to `_code` to facilitate + # making `code` into a property and adding a docstring to it + fn_src = body.get("_code") or body["code"] + forward = _forward_from_src(import_block + fn_src, {}) + return _deserialize_graph_module(forward, body) + + +@compatibility(is_backward_compatible=True) +def reduce_package_graph_module( + importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str +) -> torch.nn.Module: + forward = importer.import_module(generated_module_name).forward + return _deserialize_graph_module(forward, body) + + +@compatibility(is_backward_compatible=True) +def reduce_deploy_graph_module( + importer: PackageImporter, body: Dict[Any, Any], import_block: str +) -> torch.nn.Module: + ns = {} + ns["__builtins__"] = importer.patched_builtins + fn_src = body.get("_code") + assert fn_src is not None + forward = _forward_from_src(import_block + fn_src, ns) + return _deserialize_graph_module(forward, body) + + +# We create a dummy class here because symbolic_trace pulls the forward() +# function off of the class, rather than the instance. This class is used +# in _deserialize_graph_module() below. +class _CodeOnlyModule(torch.nn.Module): + def __init__(self, body): + super().__init__() + self.__dict__ = body + + +def _deserialize_graph_module(forward, body: Dict[Any, Any], graph_module_cls=None) -> torch.nn.Module: + """ + Deserialize a GraphModule given the dictionary of the original module, + using the code to reconstruct the graph. We delete the actual graph before + saving the dictionary so that changes to the in-memory graph format do not + get serialized. + """ + + # Try to retrieve the forward source in a backward-compatible way + _CodeOnlyModule.forward = forward + + tracer_cls = body.get("_tracer_cls") + if tracer_cls is None: + from ._symbolic_trace import Tracer + + tracer_cls = Tracer + + graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule") + + # This is a workaround for a mypy linter issue related to + # passing base class as an argument - https://github.com/python/mypy/issues/5865. + cls_tracer: Any = tracer_cls + + class KeepModules(cls_tracer): + # we shouldn't trace into any of the submodules, + # because they were not traced in the original GraphModule + def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: + return True + + com = _CodeOnlyModule(body) + + tracer_extras = body.get("_tracer_extras", {}) + graph = KeepModules().trace(com, **tracer_extras) + + # Manually set Tracer class on the reconstructed Graph, to avoid + # referencing the private local subclass KeepModules. + graph._tracer_cls = tracer_cls + from ._lazy_graph_module import _make_graph_module + gm = _make_graph_module(com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls) + + # The GraphModule constructor only retains attributes referenced by the graph. + # In this case, our goal is return a GraphModule as close to identical as the one + # put into the package. If any additional attributes were present in body, + # we should keep them. + for k, v in body.items(): + if not hasattr(gm, k): + setattr(gm, k, v) + return gm + + +# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' +# This installs empty Modules where none exist yet if they are subpaths of target +def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): + *prefix, field = target.split(".") + for item in prefix: + f = getattr(from_module, item) + t = getattr(to_module, item, None) + if f is t: + # we have already installed one of its parents + # (e.g. target = root.linear.weight, but we have already installed root.linear) + # once we install a parent, we no longer need to copy the children + # since all the needed properties will already be present + return + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + from_module, to_module = f, t + + orig = getattr(from_module, field) + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. + if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): + to_module.register_buffer(field, orig) + else: + setattr(to_module, field, orig) + + +# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module +# This installs empty Modules where none exist yet if they are subpaths of target +def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): + *prefix, field = target.split(".") + for item in prefix: + t = getattr(to_module, item, None) + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + to_module = t + + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. + if isinstance(from_obj, torch.Tensor) and not isinstance( + from_obj, torch.nn.Parameter + ): + to_module.register_buffer(field, from_obj) + else: + setattr(to_module, field, from_obj) + + +def _print_readable( + module, + module_name, + print_output=True, + include_stride=False, + include_device=False, + colored=False, +): + graph = module.graph + assert graph is not None and isinstance(graph, torch.fx.Graph), "print_readable must be used on a module with a graph" + + verbose_python_code = graph.python_code( + root_module="self", + verbose=True, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + module_code = verbose_python_code.src + module_code = module_code.lstrip("\n") + module_code = f"class {module_name}(torch.nn.Module):\n" + module_code + module_code = _addindent(module_code, 4) + + submodule_code_list = [""] + for submodule_name, submodule in module.named_children(): + if hasattr(submodule, "graph"): + submodule_code_list.append( + _print_readable( + submodule, + submodule_name, + print_output=False, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + ) + submodule_code = "\n".join(submodule_code_list) + submodule_code = _addindent(submodule_code, 4) + + output = module_code + submodule_code + if print_output: + print(module_code + submodule_code) + return output + + +class _WrappedCall: + def __init__(self, cls, cls_call): + self.cls = cls + self.cls_call = cls_call + + # Previously, if an error occurred when valid + # symbolically-traced code was run with an invalid input, the + # user would see the source of the error as coming from + # `File "`, where N is some number. We use + # this function to generate a more informative error message. We + # return the traceback itself, a message explaining that the + # error occurred in a traced Module's generated forward + # function, and five lines of context surrounding the faulty + # line + @staticmethod + def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: + # auxiliary variables (for readability) + err_lineno = frame_summary.lineno + assert err_lineno is not None + line = frame_summary.line + assert line is not None + err_line_len = len(line) + all_src_lines = linecache.getlines(frame_summary.filename) + + # constituent substrings of the error message + tb_repr = torch._dynamo.disable(traceback.format_exc)() + custom_msg = ( + "Call using an FX-traced Module, " + f"line {err_lineno} of the traced Module's " + "generated forward function:" + ) + before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) + marker = "~" * err_line_len + "~~~ <--- HERE" + err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) + + # joined message + return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) + + def __call__(self, obj, *args, **kwargs): + try: + if self.cls_call is not None: + return self.cls_call(obj, *args, **kwargs) + else: + return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] + except Exception as e: + assert e.__traceback__ + topmost_framesummary: traceback.FrameSummary = ( + traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] + ) # type: ignore[arg-type] + if "eval_with_key" in topmost_framesummary.filename: + print( + _WrappedCall._generate_error_message(topmost_framesummary), + file=sys.stderr, + ) + raise e.with_traceback(None) # noqa: B904 + else: + raise e + +@compatibility(is_backward_compatible=True) +class GraphModule(torch.nn.Module): + """ + GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a + ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated + from that ``graph``. + + .. warning:: + + When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically + regenerated. However, if you edit the contents of the ``graph`` without reassigning + the ``graph`` attribute itself, you must call ``recompile()`` to update the generated + code. + """ + + def __new__(cls: "Type[GraphModule]", *args, **kwargs): + # each instance of a graph module needs its own forward method + # so create a new singleton class for each instance. + # it is a subclass of the user-defined class, the only difference + # is an extra layer to install the forward method + + # address issue described at https://github.com/pytorch/pytorch/issues/63883 + # in other words, traverse class hierarchy to fix the redundant class definition problem + for t in cls.__mro__: + c = t.__qualname__.split(".")[-1] + if c != "GraphModuleImpl": + cls = t + break + + class GraphModuleImpl(cls): # type: ignore[misc, valid-type] + pass + + return super().__new__(GraphModuleImpl) + + @compatibility(is_backward_compatible=True) + def __init__( + self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: Graph, + class_name: str = "GraphModule", + ): + """ + Construct a GraphModule. + + Args: + + root (Union[torch.nn.Module, Dict[str, Any]): + ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type. + In the case that ``root`` is a Module, any references to Module-based objects (via qualified + name) in the Graph's Nodes' ``target`` field will be copied over from the respective place + within ``root``'s Module hierarchy into the GraphModule's module hierarchy. + In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be + looked up directly in the dict's keys. The object mapped to by the Dict will be copied + over into the appropriate place within the GraphModule's module hierarchy. + + graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation + + class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all + error messages will report as originating from ``GraphModule``. It may be helpful to set this + to ``root``'s original name or a name that makes sense within the context of your transform. + """ + super().__init__() + self.__class__.__name__ = class_name + if isinstance(root, torch.nn.Module): + if hasattr(root, "training"): + self.training = root.training + + # When we pickle/unpickle graph module, we don't want to drop any module or attributes. + if isinstance(root, _CodeOnlyModule): + for k, _ in root.named_children(): + _copy_attr(root, self, k) + + for k, _ in root.named_buffers(): + _copy_attr(root, self, k) + + for k, _ in root.named_parameters(): + _copy_attr(root, self, k) + + for node in graph.nodes: + if node.op in ["get_attr", "call_module"]: + assert isinstance(node.target, str) + _copy_attr(root, self, node.target) + elif isinstance(root, dict): + targets_to_copy = [] + for node in graph.nodes: + if node.op in ["get_attr", "call_module"]: + assert isinstance(node.target, str) + if node.target not in root: + raise RuntimeError( + "Node " + + str(node) + + " referenced target " + + node.target + + " but that target was not provided in ``root``!" + ) + targets_to_copy.append(node.target) + # Sort targets in ascending order of the # of atoms. + # This will ensure that less deeply nested attributes are assigned + # before more deeply nested attributes. For example, foo.bar + # will be assigned before foo.bar.baz. Otherwise, we might assign + # the user-provided ``foo.bar`` and wipe out the previously-assigned + # ``foo.bar.baz`` + targets_to_copy.sort(key=lambda t: t.count(".")) + for target_to_copy in targets_to_copy: + _assign_attr(root[target_to_copy], self, target_to_copy) + else: + raise RuntimeError("Unsupported type " + str(root) + " passed for root!") + + self.graph = graph + + # Store the Tracer class responsible for creating a Graph separately as part of the + # GraphModule state, except when the Tracer is defined in a local namespace. + # Locally defined Tracers are not pickleable. This is needed because torch.package will + # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer + # to re-create the Graph during deserialization. + self._tracer_cls = None + if ( + self.graph._tracer_cls + and "" not in self.graph._tracer_cls.__qualname__ + ): + self._tracer_cls = self.graph._tracer_cls + + self._tracer_extras = {} + if self.graph._tracer_extras: + self._tracer_extras = self.graph._tracer_extras + + # Dictionary to store metadata + self.meta: Dict[str, Any] = {} + self._replace_hook = None + self._create_node_hooks: List[Callable] = [] + self._erase_node_hooks: List[Callable] = [] + + # TorchScript breaks trying to compile the graph setter because of the + # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 + # + # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway + __jit_unused_properties__ = ["graph"] + + @property + def graph(self) -> Graph: + """ + Return the ``Graph`` underlying this ``GraphModule`` + """ + return self._graph + + @graph.setter + def graph(self, g: Graph) -> None: + """ + Set the underlying ``Graph`` for this ``GraphModule``. This will internally + recompile the ``GraphModule`` so that the generated ``forward()`` function + corresponds to ``g`` + """ + assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}" + self._graph = g + g.owning_module = self + self.recompile() + + @compatibility(is_backward_compatible=False) + def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + torch.save(self.state_dict(), folder / "state_dict.pt") + tab = " " * 4 + custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()]) + model_str = f""" +import torch +{custom_builtins} + +from torch.nn import * +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__() +""" + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + ] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in self.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f"{module_name}.pt" + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") + # weights_only=False as this is legacy code that saves the model + module_str = f"torch.load(r'{module_file}', weights_only=False) # {module_repr}" + model_str += f"{tab*2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in self._buffers.items(): + if buffer is None: + continue + model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" + + for param_name, param in self._parameters.items(): + if param is None: + continue + model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" + + model_str += ( + f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + ) + model_str += f"{_addindent(self.code, 4)}\n" + + module_file = folder / "module.py" + module_file.write_text(model_str) + + init_file = folder / "__init__.py" + init_file.write_text("from .module import *") + + if len(blobified_modules) > 0: + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) + + @compatibility(is_backward_compatible=True) + def add_submodule(self, target: str, m: torch.nn.Module) -> bool: + """ + Adds the given submodule to ``self``. + + This installs empty Modules where none exist yet if they are + subpaths of ``target``. + + Args: + target: The fully-qualified string name of the new submodule + (See example in ``nn.Module.get_submodule`` for how to + specify a fully-qualified string.) + m: The submodule itself; the actual object we want to + install in the current Module + + Return: + bool: Whether or not the submodule could be inserted. For + this method to return True, each object in the chain + denoted by ``target`` must either a) not exist yet, + or b) reference an ``nn.Module`` (not a parameter or + other attribute) + """ + *prefix, field = target.split(".") + mod: torch.nn.Module = self + + for item in prefix: + + submod = getattr(mod, item, None) + + if submod is None: + submod = torch.nn.Module() + setattr(mod, item, submod) + + if not isinstance(submod, torch.nn.Module): + return False + + mod = submod + + mod.add_module(field, m) + return True + + @compatibility(is_backward_compatible=True) + def delete_submodule(self, target: str) -> bool: + """ + Deletes the given submodule from ``self``. + + The module will not be deleted if ``target`` is not a valid + target. + + Args: + target: The fully-qualified string name of the new submodule + (See example in ``nn.Module.get_submodule`` for how to + specify a fully-qualified string.) + + Returns: + bool: Whether or not the target string referenced a + submodule we want to delete. A return value of ``False`` + means that the ``target`` was not a valid reference to + a submodule. + """ + atoms = target.split(".") + path, target_submod = atoms[:-1], atoms[-1] + mod: torch.nn.Module = self + + # Get the parent module + for item in path: + + if not hasattr(mod, item): + return False + + mod = getattr(mod, item) + + if not isinstance(mod, torch.nn.Module): + return False + + if not hasattr(mod, target_submod): + return False + + if not isinstance(getattr(mod, target_submod), torch.nn.Module): + return False + + delattr(mod, target_submod) + return True + + @compatibility(is_backward_compatible=True) + def delete_all_unused_submodules(self) -> None: + """ + Deletes all unused submodules from ``self``. + + A Module is considered "used" if any one of the following is + true: + 1. It has children that are used + 2. Its forward is called directly via a ``call_module`` node + 3. It has a non-Module attribute that is used from a + ``get_attr`` node + + This method can be called to clean up an ``nn.Module`` without + manually calling ``delete_submodule`` on each unused submodule. + """ + used: List[str] = [] + + for node in self.graph.nodes: + + if node.op == "call_module" or node.op == "get_attr": + + # A list of strings representing the different parts + # of the path. For example, `foo.bar.baz` gives us + # ["foo", "bar", "baz"] + fullpath = node.target.split(".") + + # If we're looking at multiple parts of a path, join + # join them with a dot. Otherwise, return that single + # element without doing anything to it. + def join_fn(x: str, y: str) -> str: + return ".".join([x, y] if y else [x]) + + # Progressively collect all the names of intermediate + # modules. For example, if we have the target + # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and + # `foo.bar.baz` to the list. + used.extend(itertools.accumulate(fullpath, join_fn)) + + # For a `call_module` node, also register all recursive submodules + # as used + if node.op == "call_module": + try: + submod = self.get_submodule(node.target) + + for submod_name, _ in submod.named_modules(): + if submod_name != "": + used.append(".".join([node.target, submod_name])) + except AttributeError: + # Node referenced nonexistent submodule, don't need to + # worry about GCing anything + pass + + to_delete = [name for name, _ in self.named_modules() if name not in used] + + for name in to_delete: + self.delete_submodule(name) + + @property + def code(self) -> str: + """ + Return the Python code generated from the ``Graph`` underlying this + ``GraphModule``. + """ + if not hasattr(self, "_code"): + raise RuntimeError( + "Code has not been generated! Please report a bug to PyTorch" + ) + return self._code + + @compatibility(is_backward_compatible=True) + def recompile(self) -> PythonCode: + """ + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. + """ + if isinstance(self._graph._codegen, _PyTreeCodeGen): + self._in_spec = self._graph._codegen.pytree_info.in_spec + self._out_spec = self._graph._codegen.pytree_info.out_spec + python_code = self._graph.python_code(root_module="self") + self._code = python_code.src + self._lineno_map = python_code._lineno_map + + cls = type(self) + co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) + + # Determine whether this class explicitly defines a __call__ implementation + # to wrap. If it does, save it in order to have wrapped_call invoke it. + # If it does not, wrapped_call can use a dynamic call to super() instead. + # In most cases, super().__call__ should be torch.nn.Module.__call__. + # We do not want to hold a reference to Module.__call__ here; doing so will + # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. + cls_call = cls.__call__ if "__call__" in vars(cls) else None + + if "_wrapped_call" not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + + def call_wrapped(self, *args, **kwargs): + return self._wrapped_call(self, *args, **kwargs) + + cls.__call__ = call_wrapped # type: ignore[method-assign] + + return python_code + + # Passing Tracer as argument allows subclasses extending fx.GraphModule + # define their own Tracer (extending fx.Tracer). + def __reduce_deploy__(self, importer: Importer): + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, importer) + return (reduce_deploy_graph_module, (dict_without_graph, import_block)) + + def __reduce_package__(self, exporter: PackageExporter): + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + generated_module_name = f"fx-generated._{exporter.get_unique_id()}" + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, exporter.importer) + module_code = import_block + self.code + exporter.save_source_string(generated_module_name, module_code) + return ( + reduce_package_graph_module, + (dict_without_graph, generated_module_name), + ) + + def __reduce__(self): + """ + Serialization of GraphModule. We serialize only the generated code, not + the underlying ``Graph``. This is because ``Graph`` does not have on-disk + backward-compatibility guarantees, whereas Python source code does. + On the deserialization side, we symbolically trace through the generated + code to regenerate the underlying ``Graph`` + """ + dict_without_graph = self.__dict__.copy() + + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, sys_importer) + del dict_without_graph["_graph"] + return (reduce_graph_module, (dict_without_graph, import_block)) + + def _deepcopy_init(self): + return GraphModule.__init__ + + # because __reduce__ is defined for serialization, + # we need to define deepcopy otherwise it will call __reduce__ + # and cause symbolic tracing to occur every time we try to copy the object + def __deepcopy__(self, memo): + res = type(self).__new__(type(self)) + memo[id(self)] = res + fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo)) + self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"]) + # hooks are lost during `GraphModule.__init__`, so we need to copy over + # them explicitly, note right now we are only copying state_dict related + # hooks, to reduce bc-related issues, we can copy forward/backward related + # hooks in the future as well if needed + extra_preserved_attrs = [ + "_state_dict_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", + "_replace_hook", + "_create_node_hooks", + "_erase_node_hooks" + ] + for attr in extra_preserved_attrs: + if attr in self.__dict__: + setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo)) + res.meta = copy.deepcopy(getattr(self, "meta", {}), memo) + if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta: + for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): + setattr(res, attr_name, attr) + return res + + def __copy__(self): + from ._lazy_graph_module import _make_graph_module + res = _make_graph_module(self, self.graph) + res.meta = getattr(self, "meta", {}) + return res + + @compatibility(is_backward_compatible=False) + def print_readable(self, print_output=True, include_stride=False, include_device=False, colored=False): + """ + Return the Python code generated for current GraphModule and its children GraphModules + """ + return _print_readable( + self, + self._get_name(), + print_output, + include_stride, + include_device, + colored, + ) + + def __str__(self) -> str: + orig_str = super().__str__() + print_readable_reminder = ( + "# To see more debug info, please use `graph_module.print_readable()`" + ) + return "\n".join([orig_str, self._code, print_readable_reminder]) + + def _replicate_for_data_parallel(self): + new_gm = self.__copy__() + new_gm._is_replica = True + return new_gm + + @contextlib.contextmanager + def _set_replace_hook(self, f): + """ + Takes a callable which will be called everytime when we replace a node + to a new node, or change the node's name. Callable takes three arguments: + the old node we're changing, and NAME of the new node, followed by the + user node which consumes the old node to be replaced. + """ + assert callable(f), "Replace hook must be a callable." + prev, self._replace_hook = self._replace_hook, f + try: + yield + finally: + self._replace_hook = prev + + def _register_create_node_hook(self, f): + """ + Takes a callable which will be called after we create a new node. The + callable takes the newly created node as input and returns None. + """ + assert callable(f), "create_node hook must be a callable." + self._create_node_hooks.append(f) + + def _unregister_create_node_hook(self, f): + """ + Takes a callable which was previously registered to be called after we create a node. + This function will unregister that callable so it is no longer invoked on node creation. + """ + assert callable(f), "create_node hook must be a callable." + self._create_node_hooks.remove(f) + + def _register_erase_node_hook(self, f): + """ + Takes a callable which will be called after we erase a node. The + callable takes the node that is being erased as input and returns None. + """ + assert callable(f), "erase_node hook must be a callable." + self._erase_node_hooks.append(f) + + def _unregister_erase_node_hook(self, f): + """ + Takes a callable which was previously registered to be called after we erase a node. + This function will unregister that callable so it is no longer invoked on node erasure. + """ + assert callable(f), "erase_node hook must be a callable." + self._erase_node_hooks.remove(f) + +# workarounds for issues in __torch_function__ + +# WAR for __torch_function__ not handling tensor lists, +# fix is in https://github.com/pytorch/pytorch/pull/34725 +# orig_cat = torch.cat +# def patched_cat(*args, **kwargs): +# tensors = args[0] +# for t in tensors: +# if isinstance(t, Proxy): +# return t.__torch_function__(patched_cat, (), args, kwargs) +# return orig_cat(*args, **kwargs) +# patched_cat.__module__ = 'torch' +# patched_cat.__name__ = 'cat' +# torch.cat = patched_cat diff --git a/.venv/lib/python3.11/site-packages/torch/fx/immutable_collections.py b/.venv/lib/python3.11/site-packages/torch/fx/immutable_collections.py new file mode 100644 index 0000000000000000000000000000000000000000..2ff29cba474dcb613cd529b881193e20e1325856 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/immutable_collections.py @@ -0,0 +1,117 @@ +# mypy: allow-untyped-defs +from typing import Any, Dict, Iterable, List, Tuple + +from torch.utils._pytree import ( + _dict_flatten, + _dict_flatten_with_keys, + _dict_unflatten, + _list_flatten, + _list_flatten_with_keys, + _list_unflatten, + Context, + register_pytree_node, +) + +from ._compatibility import compatibility + + +__all__ = ["immutable_list", "immutable_dict"] + +_help_mutation = """\ +If you are attempting to modify the kwargs or args of a torch.fx.Node object, +instead create a new copy of it and assign the copy to the node: + new_args = ... # copy and mutate args + node.args = new_args +""" + + +def _no_mutation(self, *args, **kwargs): + raise NotImplementedError( + f"'{type(self).__name__}' object does not support mutation. {_help_mutation}", + ) + + +def _create_immutable_container(base, mutable_functions): + container = type("immutable_" + base.__name__, (base,), {}) + for attr in mutable_functions: + setattr(container, attr, _no_mutation) + return container + + +immutable_list = _create_immutable_container( + list, + ( + "__delitem__", + "__iadd__", + "__imul__", + "__setitem__", + "append", + "clear", + "extend", + "insert", + "pop", + "remove", + "reverse", + "sort", + ), +) +immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),)) +immutable_list.__hash__ = lambda self: hash(tuple(self)) + +compatibility(is_backward_compatible=True)(immutable_list) + +immutable_dict = _create_immutable_container( + dict, + ( + "__delitem__", + "__ior__", + "__setitem__", + "clear", + "pop", + "popitem", + "setdefault", + "update", + ), +) +immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),)) +immutable_dict.__hash__ = lambda self: hash(tuple(self.items())) +compatibility(is_backward_compatible=True)(immutable_dict) + + +# Register immutable collections for PyTree operations +def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: + return _dict_flatten(d) + + +def _immutable_dict_unflatten( + values: Iterable[Any], + context: Context, +) -> Dict[Any, Any]: + return immutable_dict(_dict_unflatten(values, context)) + + +def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: + return _list_flatten(d) + + +def _immutable_list_unflatten( + values: Iterable[Any], + context: Context, +) -> List[Any]: + return immutable_list(_list_unflatten(values, context)) + + +register_pytree_node( + immutable_dict, + _immutable_dict_flatten, + _immutable_dict_unflatten, + serialized_type_name="torch.fx.immutable_collections.immutable_dict", + flatten_with_keys_fn=_dict_flatten_with_keys, +) +register_pytree_node( + immutable_list, + _immutable_list_flatten, + _immutable_list_unflatten, + serialized_type_name="torch.fx.immutable_collections.immutable_list", + flatten_with_keys_fn=_list_flatten_with_keys, +) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py b/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..c75407583137d30e377a6b4e8c734858b7b5063b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/interpreter.py @@ -0,0 +1,520 @@ +# mypy: allow-untyped-defs +from .graph_module import GraphModule +from ._lazy_graph_module import _make_graph_module +from .graph import Graph +from .node import Argument, Node, Target, map_arg, map_aggregate +from .proxy import Proxy +from ._symbolic_trace import Tracer +from ._compatibility import compatibility +from . import config +import torch.fx.traceback as fx_traceback +import torch +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +import inspect +from contextlib import contextmanager +from torch.hub import tqdm + +__all__ = ['Interpreter', 'Transformer'] + +@compatibility(is_backward_compatible=True) +class Interpreter: + """ + An Interpreter executes an FX graph Node-by-Node. This pattern + can be useful for many things, including writing code + transformations as well as analysis passes. + + Methods in the Interpreter class can be overridden to customize + the behavior of execution. The map of overrideable methods + in terms of call hierarchy:: + + run() + +-- run_node + +-- placeholder() + +-- get_attr() + +-- call_function() + +-- call_method() + +-- call_module() + +-- output() + + Example: + + Suppose we want to swap all instances of ``torch.neg`` with + ``torch.sigmoid`` and vice versa (including their ``Tensor`` + method equivalents). We could subclass Interpreter like so:: + + class NegSigmSwapInterpreter(Interpreter): + def call_function(self, target : Target, + args : Tuple, kwargs : Dict) -> Any: + if target == torch.sigmoid: + return torch.neg(*args, **kwargs) + return super().call_function(n) + + def call_method(self, target : Target, + args : Tuple, kwargs : Dict) -> Any: + if target == 'neg': + call_self, *args_tail = args + return call_self.sigmoid(*args_tail, **kwargs) + return super().call_method(n) + + def fn(x): + return torch.sigmoid(x).neg() + + gm = torch.fx.symbolic_trace(fn) + input = torch.randn(3, 4) + result = NegSigmSwapInterpreter(gm).run(input) + torch.testing.assert_close(result, torch.neg(input).sigmoid()) + + Args: + module (torch.nn.Module): The module to be executed + garbage_collect_values (bool): Whether to delete values after their last + use within the Module's execution. This ensures optimal memory usage during + execution. This can be disabled to, for example, examine all of the intermediate + values in the execution by looking at the ``Interpreter.env`` attribute. + graph (Optional[Graph]): If passed, the interpreter will execute this + graph instead of `module.graph`, using the provided `module` + argument to satisfy any requests for state. + """ + @compatibility(is_backward_compatible=True) + def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None): + self.module = module + self.submodules = dict(self.module.named_modules()) + if graph is not None: + self.graph = graph + else: + self.graph = self.module.graph + self.env : Dict[Node, Any] = {} + self.name = "Interpreter" + self.garbage_collect_values = garbage_collect_values + self.extra_traceback = True + + if self.garbage_collect_values: + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use : Dict[Node, Node] = {} + self.user_to_last_uses : Dict[Node, List[Node]] = {} + + def register_last_uses(n : Node, user : Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + self.user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(self.graph.nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + @compatibility(is_backward_compatible=True) + def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any: + """ + Run `module` via interpretation and return the result. + + Args: + *args: The arguments to the Module to run, in positional order + initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. + This is a dict mapping `Node` to any value. This can be used, for example, to + pre-populate results for certain `Nodes` so as to do only partial evaluation within + the interpreter. + enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and + process_outputs function first before using them. + + Returns: + Any: The value returned from executing the Module + """ + self.env = initial_env if initial_env is not None else {} + + # Positional function args are consumed left-to-right by + # `placeholder` nodes. Use an iterator to keep track of + # position and extract those values. + if enable_io_processing: + args = self.graph.process_inputs(*args) + self.args_iter : Iterator[Any] = iter(args) + pbar = tqdm(total=len(self.graph.nodes), + desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}", + initial=0, position=0, leave=True, disable=config.disable_progress, delay=0) + + for node in self.graph.nodes: + pbar.update(1) + if node in self.env: + # Short circuit if we have this value. This could + # be used, for example, for partial evaluation + # where the caller has pre-populated `env` with + # values for a subset of the program. + continue + + try: + self.env[node] = self.run_node(node) + except Exception as e: + if self.extra_traceback: + msg = f"While executing {node.format_node()}" + msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg) + msg += f"\nOriginal traceback:\n{node.stack_trace}" + e.args = (msg,) + e.args[1:] + if isinstance(e, KeyError): + raise RuntimeError(*e.args) from e + raise + + if self.garbage_collect_values: + for to_delete in self.user_to_last_uses.get(node, []): + del self.env[to_delete] + + if node.op == 'output': + output_val = self.env[node] + return self.graph.process_outputs(output_val) if enable_io_processing else output_val + + @compatibility(is_backward_compatible=True) + def boxed_run(self, args_list): + """ + Run `module` via interpretation and return the result. This uses the "boxed" + calling convention, where you pass a list of arguments, which will be cleared + by the interpreter. This ensures that input tensors are promptly deallocated. + """ + args_iter = iter(args_list) + env = {} + for n in self.graph.nodes: + if n.op == "placeholder": + env[n] = next(args_iter) + args_list.clear() + return self.run(initial_env=env) + + @contextmanager + def _set_current_node(self, node): + with fx_traceback.set_current_meta(node): + yield + + @compatibility(is_backward_compatible=True) + def run_node(self, n : Node) -> Any: + """ + Run a specific node ``n`` and return the result. + Calls into placeholder, get_attr, call_function, + call_method, call_module, or output depending + on ``node.op`` + + Args: + n (Node): The Node to execute + + Returns: + Any: The result of executing ``n`` + """ + with self._set_current_node(n): + args, kwargs = self.fetch_args_kwargs_from_env(n) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + return getattr(self, n.op)(n.target, args, kwargs) + + # Main Node running APIs + @compatibility(is_backward_compatible=True) + def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``placeholder`` node. Note that this is stateful: + ``Interpreter`` maintains an internal iterator over + arguments passed to ``run`` and this method returns + next() on that iterator. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Returns: + Any: The argument value that was retrieved. + """ + assert isinstance(target, str) + if target.startswith('*'): + # For a starred parameter e.g. `*args`, retrieve all + # remaining values from the args list. + return list(self.args_iter) + else: + try: + return next(self.args_iter) + except StopIteration as si: + if len(args) > 0: + return args[0] + else: + raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si + + @compatibility(is_backward_compatible=True) + def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``get_attr`` node. Will retrieve an attribute + value from the ``Module`` hierarchy of ``self.module``. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + Any: The value of the attribute that was retrieved + """ + assert isinstance(target, str) + return self.fetch_attr(target) + + @compatibility(is_backward_compatible=True) + def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``call_function`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the function invocation + """ + assert not isinstance(target, str) + + # Execute the function and return the result + return target(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``call_method`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the method invocation + """ + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + # Execute the method and return the result + assert isinstance(target, str) + return getattr(self_obj, target)(*args_tail, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute a ``call_module`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the module invocation + """ + # Retrieve executed args and kwargs values from the environment + + # Execute the method and return the result + assert isinstance(target, str) + submod = self.fetch_attr(target) + + return submod(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + """ + Execute an ``output`` node. This really just retrieves + the value referenced by the ``output`` node and returns it. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + Any: The return value referenced by the output node + """ + return args[0] + + # Helper methods + @compatibility(is_backward_compatible=True) + def fetch_attr(self, target : str): + """ + Fetch an attribute from the ``Module`` hierarchy of ``self.module``. + + Args: + target (str): The fully-qualified name of the attribute to fetch + + Return: + Any: The value of the attribute. + """ + target_atoms = target.split('.') + attr_itr = self.module + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i+1])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr + + @compatibility(is_backward_compatible=True) + def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: + """ + Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` + from the current execution environment. + + Args: + n (Node): The node for which ``args`` and ``kwargs`` should be fetched. + + Return: + Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``. + """ + args = self.map_nodes_to_values(n.args, n) + assert isinstance(args, tuple) + kwargs = self.map_nodes_to_values(n.kwargs, n) + assert isinstance(kwargs, dict) + return args, kwargs + + @compatibility(is_backward_compatible=True) + def map_nodes_to_values(self, args : Argument, n : Node) -> Argument: + """ + Recursively descend through ``args`` and look up the concrete value + for each ``Node`` in the current execution environment. + + Args: + args (Argument): Data structure within which to look up concrete values + + n (Node): Node to which ``args`` belongs. This is only used for error reporting. + """ + def load_arg(n_arg : Node) -> Any: + if n_arg not in self.env: + raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() ' + f'to diagnose such issues') + return self.env[n_arg] + return map_arg(args, load_arg) + +@compatibility(is_backward_compatible=True) +class Transformer(Interpreter): + """ + ``Transformer`` is a special type of interpreter that produces a + new ``Module``. It exposes a ``transform()`` method that returns + the transformed ``Module``. ``Transformer`` does not require + arguments to run, as ``Interpreter`` does. ``Transformer`` works + entirely symbolically. + + Example: + + Suppose we want to swap all instances of ``torch.neg`` with + ``torch.sigmoid`` and vice versa (including their ``Tensor`` + method equivalents). We could subclass ``Transformer`` like so:: + + class NegSigmSwapXformer(Transformer): + def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + if target == torch.sigmoid: + return torch.neg(*args, **kwargs) + return super().call_function(n) + + def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + if target == 'neg': + call_self, *args_tail = args + return call_self.sigmoid(*args_tail, **kwargs) + return super().call_method(n) + + def fn(x): + return torch.sigmoid(x).neg() + + gm = torch.fx.symbolic_trace(fn) + + transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() + input = torch.randn(3, 4) + torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid()) + + Args: + module (GraphModule): The ``Module`` to be transformed. + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, module): + super().__init__(module) + self.new_graph = Graph() + self.new_graph.set_codegen(module.graph._codegen) + + class TransformerTracer(Tracer): + def __init__(self, graph: Graph): + super().__init__() + self.graph = graph + self.tensor_attrs: Dict[torch.Tensor, str] = {} # type: ignore[assignment] + + def is_leaf_module(self, _, __) -> bool: + return True + + self.tracer = TransformerTracer(self.new_graph) + self.tracer.root = module + + @compatibility(is_backward_compatible=True) + def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: + """ + Execute a ``placeholder`` node. In ``Transformer``, this is + overridden to insert a new ``placeholder`` into the output + graph. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + default_value = next(iter(args)) if args else inspect.Signature.empty + return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer) + + @compatibility(is_backward_compatible=True) + def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: + """ + Execute a ``get_attr`` node. In ``Transformer``, this is + overridden to insert a new ``get_attr`` node into the output + graph. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + return self.tracer.create_proxy("get_attr", target, args, kwargs) + + @compatibility(is_backward_compatible=True) + def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + # Override so that the leaf module policy from `self.tracer` is respected. + assert isinstance(target, str) + submod = self.fetch_attr(target) + return self.tracer.call_module(submod, submod.forward, args, kwargs) + + @compatibility(is_backward_compatible=True) + def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + # Override so that functions that were wrapped are still wrapped. + return self.tracer.create_proxy('call_function', target, args, kwargs) + + @compatibility(is_backward_compatible=True) + def transform(self) -> GraphModule: + """ + Transform ``self.module`` and return the transformed + ``GraphModule``. + """ + with fx_traceback.preserve_node_meta(): + result = super().run(enable_io_processing=False) + if result is not None: + def strip_proxy(a : Union[Argument, Proxy]) -> Any: + return a.node if isinstance(a, Proxy) else a + new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy)) + # also preserve the metadata from the old output node, if it exists + old_output_node = list(self.graph.nodes)[-1] + assert old_output_node.op == "output" + for k, v in old_output_node.meta.items(): + new_output_node.meta[k] = v + + + return _make_graph_module(self.module, self.new_graph) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/node.py b/.venv/lib/python3.11/site-packages/torch/fx/node.py new file mode 100644 index 0000000000000000000000000000000000000000..f84b23e29ddf855f799b3de6ce1bd4ccf3da3dd4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/node.py @@ -0,0 +1,788 @@ +# Nodes represent a definition of a value in our graph of operators. +from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set +from ._compatibility import compatibility +from .immutable_collections import immutable_dict, immutable_list +import torch +import builtins +import types +import inspect +import warnings +from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair +from .._ops import ops as _ops +from torch._C import _NodeBase + +if TYPE_CHECKING: + from .graph import Graph + +__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"] + +BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, + torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, + torch.SymInt, torch.SymBool, torch.SymFloat] +base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] + +Target = Union[Callable[..., Any], str] + +Argument = Optional[Union[ + Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types + List[Any], # actually Argument + Dict[str, Any], # actually Argument + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + range, + 'Node', + BaseArgumentTypes +]] + +_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']) + +_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = { + torch._C._set_grad_enabled, + torch.amp._enter_autocast, + torch.amp._exit_autocast, +} + +# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs, +# or add logic to correctly mark all inplace ops as side effectful. +_side_effectful_functions: Set[Callable] = { + torch._assert, + torch._assert_async, + _ops.aten._assert_async.msg, + _ops.aten._assert_scalar.default, + _ops.aten.sym_constrain_range.default, + _ops.aten.sym_constrain_range_for_size.default, + _ops.profiler._record_function_enter, + _ops.profiler._record_function_enter_new, + _ops.profiler._record_function_exit, + _ops.inductor.accumulate_grad_.default, +} | _side_effectful_need_to_be_preserved_pre_dispatch +if hasattr(_ops.inductor, "resize_storage_bytes_"): + _side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default) + + +@compatibility(is_backward_compatible=False) +def has_side_effect(fn: Callable) -> Callable: + _side_effectful_functions.add(fn) + return fn + + +# this is fixed on master, WAR for 1.5 +def _find_module_of_method(orig_method: Callable[..., Any]) -> str: + name = orig_method.__name__ + module = orig_method.__module__ + if module is not None: + return module + for guess in [torch, torch.nn.functional]: + if getattr(guess, name, None) is orig_method: + return guess.__name__ + raise RuntimeError(f'cannot find module for {orig_method}') + +# Borrowed from CPython typing module +# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156 +def _type_repr(obj: object) -> str: + """Return the repr() of an object, special-casing types (internal helper). + If obj is a type, we return a shorter version than the default + type.__repr__, based on the module and qualified name, which is + typically enough to uniquely identify a type. For everything + else, we fall back on repr(obj). + """ + if isinstance(obj, type): + if obj.__module__ == 'builtins': + return obj.__qualname__ + return f'{obj.__module__}.{obj.__qualname__}' + if obj is ...: + return '...' + if isinstance(obj, types.FunctionType): + return obj.__name__ + return repr(obj) + +def _get_qualified_name(func: Callable[..., Any]) -> str: + # things like getattr just appear in builtins + if getattr(builtins, func.__name__, None) is func: + return func.__name__ + # torch.Tensor.{fn} + if (isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType)) + and func is getattr(torch.Tensor, func.__name__, None)): + return f"torch.Tensor.{func.__name__}" + name = func.__name__ + if name == "": + # For lambdas, try to get their defining name in the module + try: + name = inspect.getsource(func).split("=")[0].strip() + except Exception as e: + raise RuntimeError("Unable to represent lambda") from e + module = _find_module_of_method(func) + module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module + # Fixup segment_reduce mismatch + if module == "torch" and name == "segment_reduce": + name = "_" + name + return f'{module}.{name}' + +def _format_arg(arg: object, max_list_len: float = float('inf')) -> str: + if hasattr(arg, '_custom_fx_repr_fn'): + return arg._custom_fx_repr_fn() + elif isinstance(arg, list): + items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) + maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' + return f'[{items}{maybe_len}]' + elif isinstance(arg, tuple): + items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) + maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' + maybe_comma = ',' if len(arg) == 1 else '' + return f'({items}{maybe_comma}{maybe_len})' + elif isinstance(arg, dict): + items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items()) + return f'{{{items_str}}}' + + if isinstance(arg, Node): + return '%' + str(arg) + else: + return str(arg) + +@compatibility(is_backward_compatible=True) +class Node(_NodeBase): + """ + ``Node`` is the data structure that represents individual operations within + a ``Graph``. For the most part, Nodes represent callsites to various entities, + such as operators, methods, and Modules (some exceptions include nodes that + specify function inputs and outputs). Each ``Node`` has a function specified + by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: + + - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. + ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to + the function parameters (e.g. ``x``) in the graph printout. + - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the + fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. + ``args`` and ``kwargs`` are don't-care + - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign + to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, + following the Python calling convention + - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is + as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. + ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*. + - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method + to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, + *including the self argument* + - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement + in the Graph printout. + """ + _args: Tuple['Argument', ...] + _kwargs: Dict[str, 'Argument'] + + @compatibility(is_backward_compatible=True) + def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', + args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], + return_type : Optional[Any] = None) -> None: + """ + Instantiate an instance of ``Node``. Note: most often, you want to use the + Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather + than instantiating a ``Node`` directly. + + Args: + graph (Graph): The ``Graph`` to which this ``Node`` should belong. + + name (str): The name to which the output of this ``Node`` should be assigned + + op (str): The opcode for this ``Node``. Can be one of 'placeholder', + 'call_method', 'call_module', 'call_function', 'get_attr', + 'output' + + target ('Target'): The target this op should call. See the broader + ``Node`` docstring for more details. + + args (Tuple['Argument']): The args to be passed to ``target`` + + kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target`` + + return_type (Optional[Any]): The python type expression representing the + type of the output of this node. This field can be used for + annotation of values in the generated code or for other types + of analyses. + """ + super().__init__() + self.graph = graph + self.name = name # unique name of value being created + assert op in _legal_ops + self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + if op == 'call_function': + if not callable(target): + raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' + 'but a Callable is expected') + else: + if not isinstance(target, str): + raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' + 'but a str is expected') + self.target = target # for method/module/function, the name of the method/module/function/attr + # being invoked, e.g add, layer1, or torch.add + + # All `Node`-valued inputs. Key is the Node, value is don't-care. + # The public API for this is `all_input_nodes`, this private attribute + # should not be accessed directly. + self._input_nodes : Dict[Node, None] = {} + self.__update_args_kwargs(args, kwargs) + + # All of the nodes that use the value produced by this Node + # Note one user may correspond to several uses, e.g. the node fo ``x + x`` + # would appear once here, but represents two uses. + # + # Is a dict to act as an "ordered set". Keys are significant, value dont-care + self.users : Dict[Node, None] = {} + # Type expression representing the output value of this node. + # This should contain the same class of Type objects that would appear + # as type annotations for function inputs/outputs. + # + # For placeholder nodes, this value will be used to type-annotate the + # generated function parameters. + # For the return node, this value will be used to type-annotate the + # generated function return type. (Note this is a special case. ``return`` + # does not produce a value, it's more of a notation. Thus, this value + # describes the type of args[0] in the ``return`` node. + self.type : Optional[Any] = return_type + self._sort_key: Any = () + + # If set, use this fn to print this node + self._repr_fn : Optional[Callable[[Node], str]] = None + + # Dictionary to store metadata passes need to do their + # transformations. This metadata is preserved across node copies + self.meta : Dict[str, Any] = {} + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + state["_erased"] = self._erased + state["_prev"] = self._prev + state["_next"] = self._next + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + _erased = state.pop("_erased") + _prev = state.pop("_prev") + _next = state.pop("_next") + self.__dict__.update(state) + self._erased = _erased + self._prev = _prev + self._next = _next + + @property + def next(self) -> 'Node': + """ + Returns the next ``Node`` in the linked list of Nodes. + + Returns: + + The next ``Node`` in the linked list of Nodes. + """ + return self._next + + @property + def prev(self) -> 'Node': + """ + Returns the previous ``Node`` in the linked list of Nodes. + + Returns: + + The previous ``Node`` in the linked list of Nodes. + """ + return self._prev + + @compatibility(is_backward_compatible=True) + def prepend(self, x: 'Node') -> None: + """ + Insert x before this node in the list of nodes in the graph. Example:: + + Before: p -> self + bx -> x -> ax + After: p -> x -> self + bx -> ax + + Args: + x (Node): The node to put before this node. Must be a member of the same graph. + """ + assert self.graph == x.graph, "Attempting to move a Node into a different Graph" + if self == x: + warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.") + return + x._remove_from_list() + p = self._prev + p._next, x._prev = x, p + x._next, self._prev = self, x + + # compute x._sort_key + psk = x._prev._sort_key + nsk = x._next._sort_key + if len(psk) > len(nsk): + idx: int + *prefix, idx = psk[:len(nsk) + 1] + x._sort_key = (*prefix, idx + 1) + elif len(psk) < len(nsk): + *prefix, idx = nsk[:len(psk) + 1] + x._sort_key = (*prefix, idx - 1) + else: # same length, increase length by 1 + x._sort_key = (*psk, 0) + + def __gt__(self, other: 'Node') -> bool: + return self._sort_key > other._sort_key + + def __lt__(self, other: 'Node') -> bool: + return self._sort_key < other._sort_key + + def __ge__(self, other: 'Node') -> bool: + return self > other or self == other + + def __le__(self, other: 'Node') -> bool: + return self < other or self == other + + @compatibility(is_backward_compatible=True) + def append(self, x: 'Node') -> None: + """ + Insert ``x`` after this node in the list of nodes in the graph. + Equivalent to ``self.next.prepend(x)`` + + Args: + x (Node): The node to put after this node. Must be a member of the same graph. + """ + self._next.prepend(x) + + def _remove_from_list(self) -> None: + p, n = self._prev, self._next + p._next, n._prev = n, p + + @property + def args(self) -> Tuple[Argument, ...]: + """ + The tuple of arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more + information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. + """ + return self._args + + @args.setter + def args(self, a : Tuple[Argument, ...]) -> None: + """ + Set the tuple of arguments to this Node. The interpretation of arguments + depends on the node's opcode. See the ``fx.Graph`` docstring for more + information. + """ + # DO NOT CALL `__update_args_kwargs` directly. The correct way to + # set `args` is via direct assignment, i.e. `node.args = new_args` + self.__update_args_kwargs(a, self._kwargs) + + @property + def kwargs(self) -> Dict[str, Argument]: + """ + The dict of keyword arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more + information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. + """ + return self._kwargs + + @kwargs.setter + def kwargs(self, k : Dict[str, Argument]) -> None: + """ + Set the dict of kwargs to this Node. The interpretation of arguments + depends on the node's opcode. See the ``fx.Graph`` docstring for more + information. + """ + # DO NOT CALL `__update_args_kwargs` directly. The correct way to + # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs` + self.__update_args_kwargs(self._args, k) + + @property + def all_input_nodes(self) -> List['Node']: + """ + Return all Nodes that are inputs to this Node. This is equivalent to + iterating over ``args`` and ``kwargs`` and only collecting the values that + are Nodes. + + Returns: + + List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this + ``Node``, in that order. + """ + return list(self._input_nodes.keys()) + + @compatibility(is_backward_compatible=True) + def update_arg(self, idx : int, arg : Argument) -> None: + """ + Update an existing positional argument to contain the new value + ``arg``. After calling, ``self.args[idx] == arg``. + + Args: + + idx (int): The index into ``self.args`` of the element to update + arg (Argument): The new argument value to write into ``args`` + """ + args = list(self.args) + args[idx] = arg + self.args = tuple(args) + + @compatibility(is_backward_compatible=True) + def insert_arg(self, idx : int, arg : Argument) -> None: + """ + Insert an positional argument to the argument list with given index. + + Args: + + idx (int): The index of the element in ``self.args`` to be inserted before. + arg (Argument): The new argument value to insert into ``args`` + """ + assert 0 <= idx <= len(self.args), "insert_args index must be between 0 and len(self.args)" + args_left = self.args[:idx] + args_right = self.args[idx:] + + self._args = args_left + (arg,) + args_right + + _new_input_nodes: Dict[Node, None] = {} + map_arg(arg, _new_input_nodes.setdefault) + + for new_use in _new_input_nodes.keys(): + if new_use not in self._input_nodes: + self._input_nodes.setdefault(new_use) + new_use.users.setdefault(self) + + @compatibility(is_backward_compatible=True) + def update_kwarg(self, key : str, arg : Argument) -> None: + """ + Update an existing keyword argument to contain the new value + ``arg``. After calling, ``self.kwargs[key] == arg``. + + Args: + + key (str): The key in ``self.kwargs`` of the element to update + arg (Argument): The new argument value to write into ``kwargs`` + """ + self.kwargs = {**self.kwargs, key: arg} + + @property + def stack_trace(self) -> Optional[str]: + """ + Return the Python stack trace that was recorded during tracing, if any. + When traced with fx.Tracer, this property is usually populated by + `Tracer.create_proxy`. To record stack traces during tracing for debug purposes, + set `record_stack_traces = True` on the `Tracer` instance. + When traced with dynamo, this property will be populated by default by + `OutputGraph.create_proxy`. + + stack_trace would have the innermost frame at the end of the string. + """ + return self.meta.get("stack_trace", None) + + @stack_trace.setter + def stack_trace(self, trace : Optional[str]) -> None: + self.meta["stack_trace"] = trace + + def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']) -> None: + """ + This API is internal. Do *not* call it directly. + """ + def update_users_and_input_nodes(n: Any) -> Any: + if isinstance(n, Node): + self._input_nodes.setdefault(n) + n.users.setdefault(self) + return n + + # Clear prior users and input_nodes + for old_use in self._input_nodes.keys(): + old_use.users.pop(self) + self._input_nodes = {} + + # We do three things in a single pass of the args + # - Normalize list->immutable_list, dict->immutable_dict, etc + # - Populate self._input_nodes + # - Populate arg.users[self] for each arg + self._args = map_aggregate(new_args, update_users_and_input_nodes) # type: ignore[assignment] + self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes) # type: ignore[assignment] + + def __repr__(self) -> str: + if self._repr_fn: + return self._repr_fn(self) + return self.name + + def _pretty_print_target(self, target: object) -> str: + """ + Make target printouts more user-friendly. + 1) builtins will be printed as `builtins.xyz` + 2) operators will be printed as `operator.xyz` + 3) other callables will be printed with qualified name, e.g. torch.add + """ + if isinstance(target, str): + return target + if hasattr(target, '__module__'): + name = getattr(target, '__name__', None) + if name is None: + # Just to be defensive, if we don't have `__name__`, get the + # qualname. Not sure if this happens for any members of `operator` + # or `builtins`. This fallback path is not as good, since e.g. + # things in `operator` have `_operator` as their __module__. + # TODO: THIS IS BROKEN: _get_qualified_name calls `__name__` + return _get_qualified_name(target) # type: ignore[arg-type] + if target.__module__ == 'builtins': + return f'builtins.{name}' + elif target.__module__ == '_operator': + return f'operator.{name}' + return _get_qualified_name(target) # type: ignore[arg-type] + + @compatibility(is_backward_compatible=True) + def format_node(self, + placeholder_names: Optional[List[str]] = None, + maybe_return_typename: Optional[List[str]] = None) -> Optional[str]: + """ + Return a descriptive string representation of ``self``. + + This method can be used with no arguments as a debugging + utility. + + This function is also used internally in the ``__str__`` method + of ``Graph``. Together, the strings in ``placeholder_names`` + and ``maybe_return_typename`` make up the signature of the + autogenerated ``forward`` function in this Graph's surrounding + GraphModule. ``placeholder_names`` and ``maybe_return_typename`` + should not be used otherwise. + + Args: + placeholder_names: A list that will store formatted strings + representing the placeholders in the generated + ``forward`` function. Internal use only. + maybe_return_typename: A single-element list that will store + a formatted string representing the output of the + generated ``forward`` function. Internal use only. + + Returns: + str: If 1) we're using ``format_node`` as an internal helper + in the ``__str__`` method of ``Graph``, and 2) ``self`` + is a placeholder Node, return ``None``. Otherwise, + return a descriptive string representation of the + current Node. + """ + if self.op == 'placeholder': + assert isinstance(self.target, str) + arg_str = self.target + arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else '' + if placeholder_names: + placeholder_names.append(arg_str) + return None + maybe_typename = f'{_type_repr(self.type)} ' if self.type else '' + default_val = '(default=' + str(self.args[0]) + ')' if self.args else '' + return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}' + elif self.op == 'get_attr': + maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' + return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \ + f'{self.op}[target={self._pretty_print_target(self.target)}]' + elif self.op == 'output': + if self.type and maybe_return_typename: + maybe_return_typename[0] = f' -> {_type_repr(self.type)}' + return f'return {self.args[0]}' + else: + maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' + return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \ + f'{self.op}[target={self._pretty_print_target(self.target)}](' \ + f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})' + + @compatibility(is_backward_compatible=True) + def replace_all_uses_with(self, + replace_with: 'Node', + delete_user_cb: Callable[['Node'], bool] = lambda user: True, + *, + propagate_meta: bool = False + ) -> List['Node']: + """ + Replace all uses of ``self`` in the Graph with the Node ``replace_with``. + + Args: + + replace_with (Node): The node to replace all uses of ``self`` with. + delete_user_cb (Callable): Callback that is called to determine + whether a given user of the self node should be removed. + propagate_meta (bool): Whether or not to copy all properties + on the .meta field of the original node onto the replacement node. + For safety, this is only valid to do if the replacement node + doesn't already have an existing .meta field. + + Returns: + + The list of Nodes on which this change was made. + """ + if propagate_meta: + assert len(replace_with.meta) == 0, \ + 'Called node.replace_all_uses_with(replace_with, propagate_meta=True), ' \ + 'but replace_with already has .meta keys' + for k, v in self.meta.items(): + replace_with.meta[k] = v + to_process = list(self.users) + skipped = [] + m = self.graph.owning_module + for use_node in to_process: + if not delete_user_cb(use_node): + skipped.append(use_node) + continue + + def maybe_replace_node(n : Node) -> Node: + if n == self: + return replace_with + else: + return n + + if getattr(m, "_replace_hook", None): + m._replace_hook(old=self, new=replace_with.name, user=use_node) + + new_args = map_arg(use_node.args, maybe_replace_node) + new_kwargs = map_arg(use_node.kwargs, maybe_replace_node) + assert isinstance(new_args, tuple) + assert isinstance(new_kwargs, dict) + use_node.__update_args_kwargs(new_args, new_kwargs) + + assert len(self.users) - len(skipped) == 0 + return [n for n in to_process if n not in skipped] + + @compatibility(is_backward_compatible=False) + def is_impure(self) -> bool: + """ + Returns whether this op is impure, i.e. if its op is a placeholder or + output, or if a call_function or call_module which is impure. + + Returns: + + bool: If the op is impure or not. + """ + if self.op in {"placeholder", "output"}: + return True + + # Check if an impure function based on schema. + if self.op == "call_function": + schema = getattr(self.target, "_schema", None) + schema_mutable = schema is not None and schema.is_mutable + return schema_mutable or self.target in _side_effectful_functions + + # Check if an impure module. + if self.op == "call_module": + assert ( + self.graph.owning_module is not None + ), "self.graph.owning_module not set for purity check" + target_mod = self.graph.owning_module.get_submodule(self.target) + assert ( + target_mod is not None + ), f"Did not find expected submodule target {self.target}" + return getattr(target_mod, "_is_impure", False) + + return False + + @compatibility(is_backward_compatible=False) + def normalized_arguments( + self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None, + kwarg_types : Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to Python targets. This means that + `args/kwargs` will be matched up to the module/functional's + signature and return exclusively kwargs in positional order + if `normalize_to_only_use_kwargs` is true. + Also populates default values. Does not support positional-only + parameters or varargs parameters. + + Supports module calls. + + May require `arg_types` and `kwarg_types` in order to disambiguate overloads. + + Args: + root (torch.nn.Module): Module upon which to resolve module targets. + arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args + kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns NamedTuple ArgsKwargsPair, or `None` if not successful. + """ + if self.op == 'call_function': + assert callable(self.target) + return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types) # type: ignore[arg-type] + elif self.op == 'call_module': + assert isinstance(self.target, str) + return normalize_module(root, self.target, self.args, self.kwargs) # type: ignore[arg-type] + + return None + + @compatibility(is_backward_compatible=True) + def replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None: + """ + Loop through input nodes of ``self``, and replace all instances of + ``old_input`` with ``new_input``. + + Args: + + old_input (Node): The old input node to be replaced. + new_input (Node): The new input node to replace ``old_input``. + """ + def maybe_replace_node(n : Node) -> Node: + return new_input if n == old_input else n + + m = self.graph.owning_module + if getattr(m, "_replace_hook", None): + m._replace_hook(old=old_input, new=new_input.name, user=self) + + new_args = map_arg(self.args, maybe_replace_node) + new_kwargs = map_arg(self.kwargs, maybe_replace_node) + assert isinstance(new_args, tuple) + assert isinstance(new_kwargs, dict) + self.__update_args_kwargs(new_args, new_kwargs) + + def _rename(self, candidate: str) -> None: + if candidate == self.name: + return + name = self.graph._graph_namespace.create_name(candidate, None) + self.name = name + self.graph._graph_namespace._rename_object(self, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name == 'name' and hasattr(self, "name"): + m = self.graph.owning_module + if getattr(m, "_replace_hook", None): + assert isinstance(value, str) + for user in self.users: + m._replace_hook(old=self, new=value, user=user) + update = False + if ( + hasattr(self, name) and + hasattr(self.graph, "_find_nodes_lookup_table") and + self in self.graph._find_nodes_lookup_table + ): + update = True + self.graph._find_nodes_lookup_table.remove(self) + object.__setattr__(self, name, value) + if update: + self.graph._find_nodes_lookup_table.insert(self) + +@compatibility(is_backward_compatible=True) +def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: + """ + Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. + """ + assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable" + return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) + +@compatibility(is_backward_compatible=True) +def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: + """ + Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. + """ + if isinstance(a, tuple): + t = tuple([map_aggregate(elem, fn) for elem in a]) + # Support NamedTuple (if it has `_fields`) by repacking into original type. + return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type] + elif isinstance(a, list): + return immutable_list([map_aggregate(elem, fn) for elem in a]) + elif isinstance(a, dict): + rv = immutable_dict() + for k, v in a.items(): + dict.__setitem__(rv, k, map_aggregate(v, fn)) + return rv + elif isinstance(a, slice): + return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn)) + else: + return fn(a) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/operator_schemas.py b/.venv/lib/python3.11/site-packages/torch/fx/operator_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5beed5285d9a208d7a28ebdeeb79e1bdf2c19e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/operator_schemas.py @@ -0,0 +1,451 @@ +# mypy: allow-untyped-defs +import torch +import inspect +import numbers +import types +import typing +import enum +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING +from torch._jit_internal import boolean_dispatched +from ._compatibility import compatibility +from torch._ops import OpOverloadPacket, OpOverload + +if TYPE_CHECKING: + from .node import Argument + +__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint", + "type_matches", "normalize_function", "normalize_module"] + +@compatibility(is_backward_compatible=False) +class ArgsKwargsPair(NamedTuple): + """ + Simple named tuple for wrapping args/kwargs pairs. + """ + args: Tuple[Any, ...] + kwargs: Dict[str, Any] + +_manual_overrides : Dict[Callable, List[inspect.Signature]] = {} + +def _nonzero_schemas(): + signatures = [] + + def nonzero(self): + pass + signatures.append(inspect.signature(nonzero)) + + def nonzero(self, *, as_tuple : bool): # type: ignore[no-redef] + pass + signatures.append(inspect.signature(nonzero)) + + return signatures + +_manual_overrides[torch.nonzero] = _nonzero_schemas() + +class _FakeGlobalNamespace: + def __getattr__(self, name): + if name == 'torch': + return torch + raise RuntimeError('Expected a torch namespace lookup') + +_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout, + 'number' : numbers.Number, 'Future' : torch.jit.Future, + 'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme, + '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None), + 'Storage': torch.UntypedStorage, + 't': typing.TypeVar('t')} +for k in dir(typing): + _type_eval_globals[k] = getattr(typing, k) + +def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any: + """ + Convert a TorchScript type to a Python type (including subtypes) via + eval'ing the annotation_str. _type_eval_globals sets up expressions + like "List" and "Future" to map to actual types (typing.List and jit.Future) + """ + return eval(ts_type.annotation_str, _type_eval_globals) + +def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: + from inspect import Parameter + parameters : List[Parameter] = [] + for arg in ts_schema.arguments: + arg_type = _torchscript_type_to_python_type(arg.type) + default = arg.default_value if arg.has_default_value() else Parameter.empty + # TODO: Figure out if this is safe. It seems like when generating the type signatures for + # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor + # argument name. Downstream, if someone converts that positional argument to a keyword + # argument, the name mismatch will break things, so here we're going to normalize the + # name to "input" + name = arg.name if arg.name != 'self' else 'input' + kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD + # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument + if name == "from": + assert kind == Parameter.POSITIONAL_OR_KEYWORD + # ParameterKind type is internal implementation detail to inspec package + # which makes it hard to do type annotation + kind = Parameter.POSITIONAL_ONLY # type: ignore[assignment] + # This renders all previous arguments to positional only + for idx, p in enumerate(parameters): + assert p.kind == Parameter.POSITIONAL_OR_KEYWORD + parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY, default=p.default, annotation=p.annotation) + parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type)) + return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns] + if len(return_types) == 0: + return_type = None + elif len(return_types) == 1: + return_type = return_types[0] + else: + return_type = tuple(return_types) + + return inspect.Signature(parameters, return_annotation=return_type) + +_SCHEMA_TO_SIGNATURE_CACHE : Dict[Tuple[str, str], inspect.Signature] = {} + +def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: + # Cached as it's called in the hot path of FakeTensor dispatch + cache_key = ts_schema.name, ts_schema.overload_name + cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key) + if cache_val is not None: + return cache_val + + res = _torchscript_schema_to_signature_impl(ts_schema) + _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res + return res + +@compatibility(is_backward_compatible=False) +def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']): + signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) + + if signatures and schemas: + matched_schemas = [] + + # Iterate through all of the schema until we find one that matches + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs + # values. If none matches, `new_args_and_kwargs` will be None + for candidate_signature, schema in zip(signatures, schemas): + try: + candidate_signature.bind(*args, **kwargs) + matched_schemas.append((candidate_signature, schema)) + except TypeError as e: + continue + + def throw_if_mutable(schema): + if schema.is_mutable: + raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional ' + f'code, so operations that mutate operands in-place (e.g. via `out` arguments) ' + f'are not supported') + + if len(matched_schemas) == 0: + # Did not match any schema. Cannot check for mutation + pass + elif len(matched_schemas) == 1: + # Matched exactly one schema, unambiguous + _, schema_to_check = matched_schemas[0] + throw_if_mutable(schema_to_check) + else: + # Ambiguous schema match. Since mutability checking is best effort, + # do nothing. + pass + +@compatibility(is_backward_compatible=False) +def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): + """ + Given an operator on the `torch` namespace, return a list of `inspect.Signature` + objects corresponding to the overloads of that op.. May return `None` if a signature + could not be retrieved. + + Args: + op (Callable): An operator on the `torch` namespace to look up a signature for + + Returns: + Optional[List[inspect.Signature]]: A list of signatures for the overloads of this + operator, or None if the operator signatures could not be retrieved. If + return_schemas=True, returns a tuple containing the optional Python signatures + and the optional TorchScript Function signature + """ + if isinstance(op, OpOverload): + schemas = [op._schema] + elif isinstance(op, OpOverloadPacket): + schemas = [getattr(op, overload)._schema for overload in op.overloads()] + else: + override = _manual_overrides.get(op) + if override: + return (override, None) if return_schemas else None + + aten_fn = torch.jit._builtins._find_builtin(op) + + if aten_fn is None: + return (None, None) if return_schemas else None + schemas = torch._C._jit_get_schemas_for_operator(aten_fn) + + signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] + return (signatures, schemas) if return_schemas else signatures + +@compatibility(is_backward_compatible=False) +def create_type_hint(x): + """ + Produces a type hint for the given argument. + + The :func:`create_type_hint` looks for a type hint compatible with the input argument `x`. + + If `x` is a `list` or `tuple`, it looks for an object in the list whose type is a superclass + of the rest, and uses that as `base_type` for the `List` or `Tuple` to be returned. + If no such object is found, it defaults to `List[Any]`. + + If `x` is neither a `list` nor a `tuple`, it returns `x`. + """ + try: + if isinstance(x, (list, tuple)): + # todo(chilli): Figure out the right way for mypy to handle this + if isinstance(x, list): + def ret_type(x): + return List[x] # type: ignore[valid-type] + else: + def ret_type(x): + return Tuple[x, ...] + if len(x) == 0: + return ret_type(Any) + base_type = x[0] + for t in x: + if issubclass(t, base_type): + continue + elif issubclass(base_type, t): + base_type = t + else: + return ret_type(Any) + return ret_type(base_type) + except Exception as e: + # We tried to create a type hint for list but failed. + warnings.warn(f"We were not able to successfully create type hint from the type {x}") + return x + +@compatibility(is_backward_compatible=False) +def type_matches(signature_type : Any, argument_type : Any): + sig_origin_type = getattr(signature_type, '__origin__', signature_type) + + if signature_type is argument_type: + return True + + # Union types in signature. Given type needs to match one of the + # contained types in the Union + if sig_origin_type is typing.Union and signature_type != argument_type: + sig_contained = signature_type.__args__ + return any(type_matches(c, argument_type) for c in sig_contained) + + if signature_type is List[int] and argument_type is int: + # int can be promoted to List[int] + return True + + if getattr(signature_type, '__origin__', None) in {list, List}: + sig_el_type = signature_type.__args__[0] + if not inspect.isclass(sig_el_type): + warnings.warn( + f"Does not support nested parametric types, got {signature_type}. Please file a bug.") + return False + if getattr(argument_type, '__origin__', None) in {list, List}: + return issubclass(argument_type.__args__[0], sig_el_type) + + def is_homogeneous_tuple(t): + if getattr(t, "__origin__", None) not in {tuple, Tuple}: + return False + contained = t.__args__ + if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason + return True + return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained) + + # Tuple[T] is accepted for List[T] parameters + return is_homogeneous_tuple(argument_type) + + # Dtype is an int in schemas + if signature_type is int and argument_type is torch.dtype: + return True + + if signature_type is numbers.Number and argument_type in {int, float}: + return True + if inspect.isclass(argument_type) and inspect.isclass(signature_type): + return issubclass(argument_type, signature_type) + + return False + +@compatibility(is_backward_compatible=False) +def normalize_function( + target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None, + kwarg_types : Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to PyTorch functions. This means that + `args/kwargs` will be matched up to the functional's + signature and return exclusively kwargs in positional order if + `normalize_to_only_use_kwargs` is True. + Also populates default values. Does not support positional-only + parameters or varargs parameters (*args, **kwargs). Does not support modules. + + May require `arg_types` and `kwarg_types` in order to disambiguate overloads. + + Args: + target (Callable): Function that we are normalizing + args (Tuple[Any]): Tuple of args to the function + kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function + arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args + kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns normalized_args_and_kwargs, or `None` if not successful. + """ + if kwargs is None: + kwargs = {} + new_args_and_kwargs = None + if not isinstance(target, types.BuiltinFunctionType) and not ( + isinstance(target, (OpOverloadPacket, OpOverload)) + ): + target_for_analysis = target + if target in boolean_dispatched: + # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have + # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` + # branches of the dispatch have exactly the same signature. If they do, use the `true` + # branch signature for analysis. Otherwise, leave this un-normalized + assert not isinstance(target, str) + dispatched = boolean_dispatched[target] + if_true, if_false = dispatched['if_true'], dispatched['if_false'] + if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters: + return None + target_for_analysis = if_true + + assert callable(target_for_analysis) + sig = inspect.signature(inspect.unwrap(target_for_analysis)) + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs) + else: + assert callable(target) + torch_op_schemas = get_signature_for_torch_op(target) + matched_schemas = [] + if torch_op_schemas: + # Iterate through all of the schema until we find one that matches + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs + # values. If none matches, `new_args_and_kwargs` will be None + for candidate_signature in torch_op_schemas: + try: + candidate_signature.bind(*args, **kwargs) + matched_schemas.append(candidate_signature) + except TypeError as e: + continue + + if len(matched_schemas) == 0: + # Did not match any schema. Cannot normalize + pass + elif len(matched_schemas) == 1: + # Matched exactly one schema, unambiguous + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs, + normalize_to_only_use_kwargs) + else: + if arg_types is not None or kwarg_types is not None: + arg_types = arg_types if arg_types else cast(Tuple[Any], ()) + kwarg_types = kwarg_types if kwarg_types else {} + for candidate_signature in torch_op_schemas: + sig_matches = True + try: + bound_types = candidate_signature.bind(*arg_types, **kwarg_types) + for arg_name, arg_type in bound_types.arguments.items(): + param = candidate_signature.parameters[arg_name] + sig_matches = sig_matches and type_matches(param.annotation, arg_type) + except TypeError as e: + sig_matches = False + if sig_matches: + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs, + normalize_to_only_use_kwargs) + break + else: + # Matched more than one schema. In this situation, the caller must provide the types of + # the arguments of the overload they expect. + schema_printouts = '\n'.join(str(schema) for schema in matched_schemas) + raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but ' + f'the schema match was ambiguous! Please provide argument types to ' + f'the normalize_arguments() call. Available schemas:\n{schema_printouts}') + + return new_args_and_kwargs + +@compatibility(is_backward_compatible=False) +def normalize_module( + root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to PyTorch modules. This means that + `args/kwargs` will be matched up to the functional's + signature and return exclusively kwargs in positional order if + `normalize_to_only_use_kwargs` is True. + Also populates default values. Does not support positional-only + parameters or varargs parameters (*args, **kwargs). + + Args: + root (nn.Module): root module upon which we query modules + target (Callable): Function that we are normalizing + args (Tuple[Any]): Tuple of args to the function + kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns normalized_args_and_kwargs, or `None` if not successful. + """ + try: + submod = root.get_submodule(target) + except AttributeError as e: + raise RuntimeError(f"Tried to normalize node with target {target} but root did not " + f"have that target!") from e + if hasattr(submod.__class__, '__name__'): + classname = submod.__class__.__name__ + if getattr(torch.nn, classname, None) == submod.__class__: + sig = inspect.signature(inspect.unwrap(submod.forward)) + if kwargs is None: + kwargs = {} + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, + normalize_to_only_use_kwargs) + return new_args_and_kwargs + return None + +def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...], + kwargs : Dict[str, Any], + normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]: + """ + Given a call target, args, and kwargs, return the arguments normalized into + an ArgsKwargsPair, or None if the type signature is not supported by + this normalization. + + Args: + + sig (inspect.Signature): Signature object for the target + args (Tuple): Arguments that appear at the callsite for `target` + kwargs (Dict): Keyword arguments that appear at the callsite for `target` + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if + this target is not supported. + """ + + # Don't currently support positional-only + # or varargs (*args, **kwargs) signatures + supported_parameter_types = { + inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY} + if any(p.kind not in supported_parameter_types for p in sig.parameters.values()): + # Add an exception for one signature, which is common for random/uniform, i.e.: + # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None + # `from` is Python keyword and as such functions with that signature should have + # positional-only args, but at the same time they could be dispatched as kwargs + if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']: + return None + + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + new_kwargs : Dict[str, Any] = {} + new_args : List[Any] = [] + for i, param in enumerate(sig.parameters): + if not normalize_to_only_use_kwargs and i < len(args): + new_args.append(bound_args.arguments[param]) + else: + new_kwargs[param] = bound_args.arguments[param] + + return ArgsKwargsPair(tuple(new_args), new_kwargs) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/proxy.py b/.venv/lib/python3.11/site-packages/torch/fx/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..2b86a1c609f918d63dbcddf791cddf3d72b9f924 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/proxy.py @@ -0,0 +1,609 @@ +# mypy: ignore-errors + +import enum +import dis +import copy +import sys +import torch +import inspect +import operator +import collections +import logging + +from dataclasses import is_dataclass, fields + + +from .graph import magic_methods, reflectable_magic_methods, Graph +from torch.utils._traceback import CapturedTraceback +from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable +from .node import Target, Node, Argument, base_types, map_aggregate +from ._compatibility import compatibility +from .operator_schemas import check_for_mutable_operation +import torch.fx.traceback as fx_traceback + +__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError', + 'Proxy', 'Attribute', 'ParameterProxy', 'Scope', + 'ScopeContextManager'] + + +log = logging.getLogger(__name__) + + +@compatibility(is_backward_compatible=False) +class Scope: + """ Scope object that records the module path and the module type + of a module. Scope is used to track the information of the module + that contains a Node in a Graph of GraphModule. For example:: + + class Sub(torch.nn.Module): + def forward(self, x): + # This will be a call_method Node in GraphModule, + # scope for this would be (module_path="sub", module_type=Sub) + return x.transpose(1, 2) + + class M(torch.nn.Module): + def __init__(self) -> None: + self.sub = Sub() + + def forward(self, x): + # This will be a call_method Node as well, + # scope for this would be (module_path="", None) + x = x.transpose(1, 2) + x = self.sub(x) + return x + + """ + + def __init__(self, module_path: str, module_type: Any): + super().__init__() + self.module_path = module_path + self.module_type = module_type + + +@compatibility(is_backward_compatible=False) +class ScopeContextManager: + """ A context manager to track the Scope of Node during symbolic tracing. + When entering a forward function of a Module, we'll update the scope information of + the current module, and when we exit, we'll restore the previous scope information. + """ + + def __init__( + self, + scope: Scope, + current_scope: Scope, + ): + super().__init__() + # Keep a copy of prev scope to restore on exit + self._prev_scope = copy.copy(scope) + # Update scope to current scope + scope.module_path = current_scope.module_path + scope.module_type = current_scope.module_type + # Save a reference so we can restore it + self._scope = scope + + def __enter__(self): + return self._scope + + def __exit__(self, *args): + self._scope.module_path = self._prev_scope.module_path + self._scope.module_type = self._prev_scope.module_type + return + + +_COPY_META_FIELDS = [ + "nn_module_stack", + "torch_fn", + "source_fn_stack", + "original_aten", + "recompute", + "ac_graph_id", + "from_node", + "quantization_tag", # TODO deprecated + "_numeric_debug_handle", # TODO deprecated + "custom", + "partitioner_tag" +] + + +@compatibility(is_backward_compatible=True) +class TracerBase: + graph: Graph + record_stack_traces : bool = False + # Feature flag for mutable schema checking + # Enableby default in 1.12 + check_mutable_operations : bool = False + # Feature flag for assert tracing + trace_asserts : bool = False + # Feature flag for proxying accesses to buffer values + proxy_buffer_attributes : bool = False + + # Name of the function to be traced. It will only be used when + # ``root`` is an instance of ``nn.Module`` + traced_func_name: str = "forward" + + # Maps the containing module's name to the operator name + scope : Scope + + # Records the module call stack + module_stack: OrderedDict[str, Tuple[str, Any]] + + # Mapping of node name to module scope + node_name_to_scope: Dict[str, Tuple[str, type]] + + @compatibility(is_backward_compatible=True) + def create_node(self, kind : str, target : Target, + args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, + type_expr : Optional[Any] = None) -> Node: + """ + Inserts a graph node given target, args, kwargs, and name. + + This method can be overridden to do extra checking, validation, or + modification of values used in node creation. For example, one might + want to disallow in-place operations from being recorded. + """ + + if kind == 'call_function' and self.check_mutable_operations: + check_for_mutable_operation(target, args, kwargs) + + node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) + # TODO node_name_to_scope will be depreciated in favor of + # node.meta['nn_module_stack'] + self.node_name_to_scope[node.name] = ( + self.scope.module_path, + self.scope.module_type, + ) + # Optionally set stack trace on the created Node for debugging purposes + if fx_traceback.has_preserved_node_meta(): + current_meta: Dict[str, Any] = fx_traceback.get_current_meta() + + stack_trace = current_meta.get("stack_trace") + if stack_trace: + node.stack_trace = stack_trace + # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta + # If other meta fields are needed, they can be added here + for field in _COPY_META_FIELDS: + if field in current_meta: + node.meta[field] = copy.copy(current_meta[field]) + + # Here we decrement to account for the sequence_nr having + # just been incremented while tracing this lowered aten op. + new_seq_nr = torch.autograd._get_sequence_nr() - 1 + # The sequence_nr increments every time a new autograd Node + # is created. During the FWD pass we store the sequence_nr + # corresponding to the last autograd Node created on this fx + # node's meta. A single aten op can create multiple autograd + # nodes as is the case with in-place foreach ops. During the + # BWD pass we retrieve the sequence_nr stored on the current + # executing autograd Node. See NOTE [ Sequence Number ]. + if current_meta.get("in_grad_fn", 0) > 0: + new_seq_nr = current_meta["grad_fn_seq_nr"][-1] + node.meta["seq_nr"] = new_seq_nr + + elif self.module_stack: + node.meta['nn_module_stack'] = copy.copy(self.module_stack) + + log.debug("create_node %s", node) + return node + + @compatibility(is_backward_compatible=True) + def proxy(self, node: Node) -> 'Proxy': + return Proxy(node, self) + + @compatibility(is_backward_compatible=True) + def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], + name: Optional[str] = None, type_expr : Optional[Any] = None, + proxy_factory_fn: Callable[[Node], 'Proxy'] = None): + ''' + Create a Node from the given arguments, then return the Node + wrapped in a Proxy object. + + If kind = 'placeholder', then we're creating a Node that + represents the parameter of a function. If we need to encode + a default parameter, we use the ``args`` tuple. ``args`` is + otherwise empty for ``placeholder`` Nodes. + ''' + + args_ = self.create_arg(args) + kwargs_ = self.create_arg(kwargs) + assert isinstance(args_, tuple) + assert isinstance(kwargs_, dict) + + node = self.create_node(kind, target, args_, kwargs_, name, type_expr) + + if not proxy_factory_fn: + proxy = self.proxy(node) + else: + proxy = proxy_factory_fn(node) + + if self.record_stack_traces and not proxy.node.stack_trace: + proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format()) + + + return proxy + + def _find_user_frame(self): + """ + Find the Python stack frame executing the user code during + symbolic tracing. + """ + # We have to do a little dance here. Basically, walk up the callstack and + # record the first frame not in the pytorch source. This is the frame executing + # the user code during tracing. + frame = inspect.currentframe() + + pt_files = ['torch/fx/proxy.py', + 'torch/fx/_symbolic_trace.py', + 'torch/fx/experimental/proxy_tensor.py', + 'torch/_ops.py', + 'torch/_tensor.py', + 'torch/utils/_python_dispatch.py', + 'torch/_prims_common/wrappers.py', + 'torch/_refs/__init__.py', + 'torch/_refs/nn/functional/__init__.py', + 'torch/utils/_stats.py', + ] + while frame: + frame = frame.f_back + if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files): + break + + if not frame: + return None + + return frame + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> Argument: + """ + A method that lowers the objects seen as arguments during symbolic evaluation + into Argument types that can be stored in IR. + + Can be override to support more trace-specific types. + """ + if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'): + return a.__fx_create_arg__(self) + # aggregates + elif isinstance(a, tuple) and hasattr(a, '_fields'): + # NamedTuple constructors don't seem to like getting a generator + # expression as an argument to their constructor, so build this + # intermediate tuple and unpack it into the NamedTuple constructor + args = tuple(self.create_arg(elem) for elem in a) + return type(a)(*args) # type: ignore[arg-type] + elif isinstance(a, (tuple, list)): + return type(a)(self.create_arg(elem) for elem in a) + elif isinstance(a, dict): + r = {} + for k, v in a.items(): + # Check for invalid dict keys. We do not want a Proxy to appear + # anywhere within the key. Since keys can be collection types, + # we iterate through the key with map_aggregate + k = self.create_arg(k) + + def no_node(arg): + if isinstance(arg, Node): + raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " + f"Node. Got key: {k}") + map_aggregate(k, no_node) + + r[k] = self.create_arg(v) + return r + elif isinstance(a, slice): + return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) + + elif isinstance(a, range): + return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) + + elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + return a + + if isinstance(a, Proxy): + # base case: we unwrap the Proxy object + return a.node + + if is_dataclass(a): + kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)} + return self.create_node("call_function", a.__class__, (), kwargs) + + elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...: + return a + raise NotImplementedError(f"argument of type: {type(a)}") + + @compatibility(is_backward_compatible=True) + def to_bool(self, obj: 'Proxy') -> bool: + """Called when a proxy object is being converted to a boolean, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return a value. + """ + raise TraceError('symbolically traced variables cannot be used as inputs to control flow') + + @compatibility(is_backward_compatible=True) + def iter(self, obj: 'Proxy') -> Iterator: + """Called when a proxy object is being iterated over, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return an iterator. + """ + raise TraceError('Proxy object cannot be iterated. This can be ' + 'attempted when the Proxy is used in a loop or' + ' as a *args or **kwargs function argument. ' + 'See the torch.fx docs on pytorch.org for a ' + 'more detailed explanation of what types of ' + 'control flow can be traced, and check out the' + ' Proxy docstring for help troubleshooting ' + 'Proxy iteration errors') + + @compatibility(is_backward_compatible=True) + def keys(self, obj: 'Proxy') -> Any: + """Called when a proxy object is has the keys() method called. + This is what happens when ** is called on a proxy. This should return an + iterator it ** is suppose to work in your custom tracer. + """ + return Attribute(obj, 'keys')() + + +# used in Proxy object when just appending to the graph while not tracing. +@compatibility(is_backward_compatible=True) +class GraphAppendingTracer(TracerBase): + def __init__(self, graph: Graph): + super().__init__() + self.graph = graph + self.scope = Scope("", None) + self.module_stack = collections.OrderedDict() + self.node_name_to_scope = {} + +@compatibility(is_backward_compatible=False) +def assert_fn(x): + assert x + +@compatibility(is_backward_compatible=True) +class TraceError(ValueError): + pass + +@compatibility(is_backward_compatible=True) +class Proxy: + """ + ``Proxy`` objects are ``Node`` wrappers that flow through the + program during symbolic tracing and record all the operations + (``torch`` function calls, method calls, operators) that they touch + into the growing FX Graph. + + If you're doing graph transforms, you can wrap your own ``Proxy`` + method around a raw ``Node`` so that you can use the overloaded + operators to add additional things to a ``Graph``. + + ``Proxy`` objects cannot be iterated. In other words, the symbolic + tracer will throw an error if a ``Proxy`` is used in a loop or as + an ``*args``/``**kwargs`` function argument. + + There are two main ways around this: + 1. Factor out the untraceable logic into a top-level function and + use ``fx.wrap`` on it. + 2. If the control flow is static (i.e. the loop trip count is + based on some hyperparameter), the code can be kept in its original + position and refactored into something like:: + + for i in range(self.some_hyperparameter): + indexed_item = proxied_value[i] + + For a more detailed description into the Proxy internals, check out + the "Proxy" section in `torch/fx/README.md` + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): + if tracer is None: + # This allows you to create a Proxy object around a raw Node + tracer = GraphAppendingTracer(node.graph) + self.tracer = tracer + self.node = node + + def __repr__(self) -> str: + return f'Proxy({self.node.name})' + + def __getattr__(self, k) -> 'Attribute': + # note: not added to the graph yet, if this is a method call + # we peephole optimize to the method invocation + return Attribute(self, k) + + def __getstate__(self) -> Dict: + return self.__dict__ + + def __deepcopy__(self, memo) -> Dict: + # We have to explicitly override this method, because otherwise deepcopy + # will go to __getattr__(self, "__deepcopy__") and return a + # Attribute(__deepcopy__), and may go into an infinite loop in some cases. + import copy + new_dict = {} + for k, v in self.__dict__.items(): + try: + new_obj = copy.deepcopy(v, memo) + except Exception: + log.warning( + "Shallow copy %s of Proxy because it cannot be deepcopied. " + "Proxy is created for node %s", k, self.node.name) + new_obj = copy.copy(v) + new_dict[k] = new_obj + assert "node" in new_dict + assert "tracer" in new_dict + new_proxy = Proxy(new_dict["node"], new_dict["tracer"]) + for k, v in new_dict.items(): + new_proxy.__dict__[k] = v + return new_proxy + + def __setstate__(self, d): + # This is called when being unpickled/loaded. + self.__dict__ = d + + def __call__(self, *args, **kwargs) -> 'Proxy': + return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) + + def __iter__(self) -> Iterator['Proxy']: + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + inst_list = list(dis.get_instructions(calling_frame.f_code)) + if sys.version_info >= (3, 11): + from bisect import bisect_left + inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset) + else: + inst_idx = calling_frame.f_lasti // 2 + inst = inst_list[inst_idx] + if inst.opname == 'UNPACK_SEQUENCE': + return (self[i] for i in range(inst.argval)) # type: ignore[index] + + return self.tracer.iter(self) + + def __abs__(self): + return self.tracer.create_proxy('call_function', operator.abs, (self,), {}) + + def __bool__(self) -> bool: + if self.tracer.trace_asserts: + # check if this boolean is used in an assertion, bytecode pattern for assertions + # is pretty stable for Python 3.7--3.9 + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + insts = list(dis.get_instructions(calling_frame.f_code)) + if sys.version_info >= (3, 11): + from bisect import bisect_left + cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset) + else: + cur = calling_frame.f_lasti // 2 + inst = insts[cur] + + if inst.opname == 'POP_JUMP_IF_TRUE': + first = insts[cur + 1] + assert inst.arg is not None + last = insts[inst.arg // 2 - 1] + starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError' + or first.opname == 'LOAD_ASSERTION_ERROR') + if starts_with_assert and last.opname == 'RAISE_VARARGS': + self.tracer.create_proxy('call_function', assert_fn, (self,), {}) + return True + + return self.tracer.to_bool(self) + + @compatibility(is_backward_compatible=True) + def keys(self): + return self.tracer.keys(self) + + def __len__(self): + raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want " + "this call to be recorded, please call torch.fx.wrap('len') at " + "module scope") + + @classmethod + def __torch_function__(cls, orig_method, types, args=None, kwargs=None): + args = args if args else () + kwargs = kwargs if kwargs else {} + + tracers : Dict[Any, None] = {} + + def find_tracer(a): + if isinstance(a, cls): + tracers[a.tracer] = None + torch.fx.node.map_aggregate(args, find_tracer) + torch.fx.node.map_aggregate(kwargs, find_tracer) + + if len(tracers) > 1: + raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while ' + f'trying to trace operations {orig_method}') + tracer = next(iter(tracers.keys())) + + if isinstance(orig_method, torch._C.ScriptMethod): + args = (orig_method.owner,) + args + return tracer.create_proxy('call_method', orig_method.name, args, kwargs) + if torch.overrides.is_tensor_method_or_property(orig_method): + return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) + else: + if isinstance(orig_method, torch._ops.HigherOrderOperator): + # TODO: Define how to symbolically trace HigherOrderOperators + raise RuntimeError("Unable to symbolically trace HigherOrderOperators") + return tracer.create_proxy('call_function', orig_method, args, kwargs, + name=tracer.graph._target_to_str(orig_method.__name__)) + + +@compatibility(is_backward_compatible=True) +class Attribute(Proxy): + @compatibility(is_backward_compatible=True) + def __init__(self, root: Proxy, attr: str): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node: Optional[Node] = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + + +@compatibility(is_backward_compatible=False) +class ParameterProxy(Proxy): + """ + A special proxy which lets "shape", "size", "dim", and a few other + attribute accesses pass through to the underlying module parameter object, + so that conditional tests on these attributes will not throw exception during tracing + """ + def __init__(self, tracer: TracerBase, node: Node, name, param): + super().__init__(node, tracer) + assert isinstance(param, torch.nn.Parameter) + self.param = param + self.name = name + + def __repr__(self) -> str: + return f'ParameterProxy({self.name})' + + @property + def shape(self): + return self.param.shape + + def size(self): + return self.param.size() + + def dim(self): + return self.param.dim() + + @property + def ndim(self): + return self.param.ndim + + def numel(self): + return self.param.numel() + + def nelement(self): + return self.param.nelement() + + +for method in magic_methods: + def _scope(method): + def impl(*args, **kwargs): + tracer = args[0].tracer + target = getattr(operator, method) + return tracer.create_proxy('call_function', target, args, kwargs) + impl.__name__ = method + as_magic = f'__{method.strip("_")}__' + setattr(Proxy, as_magic, impl) + _scope(method) + +def _define_reflectable(orig_method_name): + method_name = f'__r{orig_method_name.strip("_")}__' + + def impl(self, rhs): + target = getattr(operator, orig_method_name) + return self.tracer.create_proxy('call_function', target, (rhs, self), {}) + impl.__name__ = method_name + impl.__qualname__ = method_name + setattr(Proxy, method_name, impl) + +for orig_method_name in reflectable_magic_methods: + _define_reflectable(orig_method_name) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/tensor_type.py b/.venv/lib/python3.11/site-packages/torch/fx/tensor_type.py new file mode 100644 index 0000000000000000000000000000000000000000..83b5a9f8faf65e86813d55926ef275fa19ef2013 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/tensor_type.py @@ -0,0 +1,105 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.unification import Var # type: ignore[attr-defined] + +from ._compatibility import compatibility + + +@compatibility(is_backward_compatible=False) +class TensorType: + """ + TensorType defines a type for tensors, which consists of a list of dimensions. + Example: + class M(torch.nn.Module): + def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))): + return torch.add(x, y) + """ + + def __init__(self, dim): + self.__origin__ = TensorType + self.__args__ = dim + + def __repr__(self): + return f'TensorType[{self.__args__}]' + + def __eq__(self, other): + if isinstance(other, self.__class__): + return list(self.__args__) == list(other.__args__) + else: + return False + + @staticmethod + def __class_getitem__(*args): + if len(args) == 1 and isinstance(args[0], tuple): + args = args[0] + return TensorType(tuple(args)) + + +class _DynType: + """ + _DynType defines a type which stands for the absence of type information. + """ + def __init__(self) -> None: + self.__name__ = '_DynType' + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __str__(self): + return "Dyn" + + def __repr__(self): + return "Dyn" + + +Dyn = _DynType() + +@compatibility(is_backward_compatible=False) +def is_consistent(t1, t2): + """ + A binary relation denoted by ~ that determines if t1 is consistent with t2. + The relation is reflexive, symmetric but not transitive. + returns True if t1 and t2 are consistent and False otherwise. + Example: + Dyn ~ TensorType((1,2,3)) + int ~ Dyn + int ~ int + TensorType((1,Dyn,3)) ~ TensorType((1,2,3)) + """ + + if t1 == t2: + return True + + if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): + return True + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + return len(t1.__args__) == len(t2.__args__) and \ + all(is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)) + else: + return False + + +@compatibility(is_backward_compatible=False) +def is_more_precise(t1, t2): + """ + A binary relation denoted by <= that determines if t1 is more precise than t2. + The relation is reflexive and transitive. + returns True if t1 is more precise than t2 and False otherwise. + Example: + Dyn >= TensorType((1,2,3)) + int >= Dyn + int >= int + TensorType((1,Dyn,3)) <= TensorType((1,2,3)) + """ + if t1 == t2: + return True + + if isinstance(t2, _DynType): + return True + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + return len(t1.__args__) == len(t2.__args__) and \ + all(is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)) + + else: + return False diff --git a/.venv/lib/python3.11/site-packages/torch/fx/traceback.py b/.venv/lib/python3.11/site-packages/torch/fx/traceback.py new file mode 100644 index 0000000000000000000000000000000000000000..4e72a8011f63ab44a64b8ece6ccbe4b2875fdda6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/traceback.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +import traceback +from contextlib import contextmanager +from typing import List, Any, Dict +from ._compatibility import compatibility + +__all__ = ['preserve_node_meta', 'has_preserved_node_meta', + 'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr', + 'format_stack', 'set_current_meta', 'get_current_meta'] + +current_meta: Dict[str, Any] = {} +should_preserve_node_meta = False + + +@compatibility(is_backward_compatible=False) +@contextmanager +def preserve_node_meta(): + global should_preserve_node_meta + global current_meta + + saved_should_preserve_node_meta = should_preserve_node_meta + # Shallow copy is OK since fields of current_meta are not mutated + saved_current_meta = current_meta.copy() + try: + should_preserve_node_meta = True + yield + finally: + should_preserve_node_meta = saved_should_preserve_node_meta + current_meta = saved_current_meta + + +@compatibility(is_backward_compatible=False) +def set_stack_trace(stack : List[str]): + global current_meta + + if should_preserve_node_meta and stack: + current_meta["stack_trace"] = "".join(stack) + + +@compatibility(is_backward_compatible=False) +def set_grad_fn_seq_nr(seq_nr): + global current_meta + + if should_preserve_node_meta: + # The seq_nr is captured by eager mode in the grad_fn during forward + current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr] + current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1 + + +@compatibility(is_backward_compatible=False) +def reset_grad_fn_seq_nr(): + # NB: reset state properly, this would be helpful towards supporting + # reentrant autograd if we actually wanted to do that. + global current_meta + if should_preserve_node_meta: + current_level = current_meta.get("in_grad_fn", 0) + assert current_level > 0 + if current_level == 1: + del current_meta["in_grad_fn"] + del current_meta["grad_fn_seq_nr"] + else: + current_meta["in_grad_fn"] = current_level - 1 + current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1] + + +@compatibility(is_backward_compatible=False) +def format_stack() -> List[str]: + if should_preserve_node_meta: + return [current_meta.get("stack_trace", "")] + else: + # fallback to traceback.format_stack() + return traceback.format_list(traceback.extract_stack()[:-1]) + + +@compatibility(is_backward_compatible=False) +def has_preserved_node_meta() -> bool: + return should_preserve_node_meta + + +@compatibility(is_backward_compatible=False) +@contextmanager +def set_current_meta(node): + global current_meta + if should_preserve_node_meta and node.meta: + saved_meta = current_meta + try: + current_meta = node.meta.copy() + + # Append (node.name, node.target) onto "from_node" for provenance tracking + if "from_node" not in current_meta: + current_meta["from_node"] = [(node.name, node.target)] + elif current_meta["from_node"][-1][0] != node.name: + current_meta["from_node"] = current_meta["from_node"] + [(node.name, node.target)] + + yield + finally: + current_meta = saved_meta + else: + yield + + +@compatibility(is_backward_compatible=False) +def get_current_meta() -> Dict[str, Any]: + return current_meta