Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .venv/lib/python3.11/site-packages/torch/_export/error.py +56 -0
- .venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/aoti_schema.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc +3 -0
- .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/serde/aoti_schema.py +15 -0
- .venv/lib/python3.11/site-packages/torch/_export/serde/dynamic_shapes.py +321 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/closure.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/__init__.py +89 -0
- .venv/lib/python3.11/site-packages/torch/fx/__init__.pyi +15 -0
- .venv/lib/python3.11/site-packages/torch/fx/_compatibility.py +36 -0
- .venv/lib/python3.11/site-packages/torch/fx/_lazy_graph_module.py +185 -0
- .venv/lib/python3.11/site-packages/torch/fx/_pytree.py +103 -0
- .venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py +1290 -0
- .venv/lib/python3.11/site-packages/torch/fx/_utils.py +63 -0
- .venv/lib/python3.11/site-packages/torch/fx/annotate.py +32 -0
- .venv/lib/python3.11/site-packages/torch/fx/config.py +6 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/debug.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/_backward_state.py +27 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/_config.py +88 -0
.gitattributes
CHANGED
|
@@ -126,3 +126,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
|
|
| 126 |
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
|
| 127 |
.venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 128 |
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 126 |
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
|
| 127 |
.venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 128 |
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
|
| 129 |
+
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
.venv/lib/python3.11/site-packages/torch/_export/error.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ExportErrorType(Enum):
|
| 5 |
+
# User providing invalid inputs to either tracer, or other public facing APIs
|
| 6 |
+
INVALID_INPUT_TYPE = 1
|
| 7 |
+
|
| 8 |
+
# User returning values from their models that we don't support.
|
| 9 |
+
INVALID_OUTPUT_TYPE = 2
|
| 10 |
+
|
| 11 |
+
# Generated IR does not conform to Export IR Specification.
|
| 12 |
+
VIOLATION_OF_SPEC = 3
|
| 13 |
+
|
| 14 |
+
# User's code contains types and functionalities we don't support.
|
| 15 |
+
NOT_SUPPORTED = 4
|
| 16 |
+
|
| 17 |
+
# User's code didn't provide necessary details for us to successfully trace and export.
|
| 18 |
+
# For example, we use a lot of decorators and ask users to annotate their model.
|
| 19 |
+
MISSING_PROPERTY = 5
|
| 20 |
+
|
| 21 |
+
# User is using an API without proper initialization step.
|
| 22 |
+
UNINITIALIZED = 6
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def internal_assert(pred: bool, assert_msg: str) -> None:
|
| 26 |
+
"""
|
| 27 |
+
This is exir's custom assert method. It internally just throws InternalError.
|
| 28 |
+
Note that the sole purpose is to throw our own error while maintaining similar syntax
|
| 29 |
+
as python assert.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
if not pred:
|
| 33 |
+
raise InternalError(assert_msg)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class InternalError(Exception):
|
| 37 |
+
"""
|
| 38 |
+
Raised when an internal invariance is violated in EXIR stack.
|
| 39 |
+
Should hint users to report a bug to dev and expose the original
|
| 40 |
+
error message.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, message: str) -> None:
|
| 44 |
+
super().__init__(message)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class ExportError(Exception):
|
| 48 |
+
"""
|
| 49 |
+
This type of exception is raised for errors that are directly caused by the user
|
| 50 |
+
code. In general, user errors happen during model authoring, tracing, using our public
|
| 51 |
+
facing APIs, and writing graph passes.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(self, error_code: ExportErrorType, message: str) -> None:
|
| 55 |
+
prefix = f"[{error_code}]: "
|
| 56 |
+
super().__init__(prefix + message)
|
.venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/aoti_schema.cpython-311.pyc
ADDED
|
Binary file (1.02 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-311.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-311.pyc
ADDED
|
Binary file (16.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc131857ed1d25d734bce65ed9c8acad8c38ffb2614c7fcf51f2cbfebac196a1
|
| 3 |
+
size 164473
|
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc
ADDED
|
Binary file (5.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/serde/aoti_schema.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from torch._export.serde.schema import Node
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class ExternKernelNode:
|
| 9 |
+
name: str
|
| 10 |
+
node: Node
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class ExternKernelNodes:
|
| 15 |
+
nodes: List[ExternKernelNode]
|
.venv/lib/python3.11/site-packages/torch/_export/serde/dynamic_shapes.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._dynamo.exc import UserError, UserErrorType
|
| 6 |
+
from torch.export.dynamic_shapes import (
|
| 7 |
+
_check_dynamic_shapes,
|
| 8 |
+
_DerivedDim,
|
| 9 |
+
_Dim,
|
| 10 |
+
_DimHint,
|
| 11 |
+
_tree_map_with_path,
|
| 12 |
+
Dim,
|
| 13 |
+
)
|
| 14 |
+
from torch.utils._pytree import tree_map
|
| 15 |
+
|
| 16 |
+
from .serialize import _dataclass_to_dict
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclasses.dataclass
|
| 20 |
+
class RootDim:
|
| 21 |
+
"""
|
| 22 |
+
This represents a _Dim object.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
min: int
|
| 26 |
+
max: Union[int, None]
|
| 27 |
+
derived: List[str]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclasses.dataclass
|
| 31 |
+
class DynamicShapesSpec:
|
| 32 |
+
"""
|
| 33 |
+
This stores a dynamic_shapes spec for de/serialization.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None]
|
| 37 |
+
dims: Dict[str, RootDim]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _postprocess_serialized_shapes(
|
| 41 |
+
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
|
| 42 |
+
dims: Dict[str, Dict[str, Union[int, List[str], None]]],
|
| 43 |
+
to_dict: Optional[bool] = False,
|
| 44 |
+
) -> Union[DynamicShapesSpec, Dict[str, Any]]:
|
| 45 |
+
"""
|
| 46 |
+
Sorts dims and dumps to dictionary format.
|
| 47 |
+
"""
|
| 48 |
+
from torch.utils._sympy.numbers import int_oo
|
| 49 |
+
|
| 50 |
+
dims = {
|
| 51 |
+
k: RootDim(
|
| 52 |
+
min=v["min"], # type: ignore[arg-type]
|
| 53 |
+
max=None if v["max"] is int_oo else v["max"], # type: ignore[arg-type]
|
| 54 |
+
derived=sorted(v["derived"]), # type: ignore[arg-type]
|
| 55 |
+
)
|
| 56 |
+
for k, v in sorted(dims.items())
|
| 57 |
+
}
|
| 58 |
+
spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims)
|
| 59 |
+
if to_dict:
|
| 60 |
+
return _dataclass_to_dict(spec)
|
| 61 |
+
else:
|
| 62 |
+
return spec
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _dump_dynamic_shapes(
|
| 66 |
+
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
|
| 67 |
+
args: Tuple[Any],
|
| 68 |
+
kwargs: Optional[Dict[str, Any]] = None,
|
| 69 |
+
to_dict: Optional[bool] = False,
|
| 70 |
+
) -> Union[DynamicShapesSpec, Dict[str, Any]]:
|
| 71 |
+
"""
|
| 72 |
+
Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec.
|
| 73 |
+
Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims".
|
| 74 |
+
Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones).
|
| 75 |
+
|
| 76 |
+
dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export():
|
| 77 |
+
- Each tensor input is represented with a list of values, non-tensor inputs with None.
|
| 78 |
+
- dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings.
|
| 79 |
+
- static dimensions are represented with ints.
|
| 80 |
+
|
| 81 |
+
dims: A dictionary mapping each symbol name to the min/max range and derived dim names.
|
| 82 |
+
|
| 83 |
+
For example:
|
| 84 |
+
```
|
| 85 |
+
dx = Dim("dx", min=4, max=16)
|
| 86 |
+
dy = dx + 1
|
| 87 |
+
|
| 88 |
+
inputs = (
|
| 89 |
+
[
|
| 90 |
+
torch.randn(4, 4),
|
| 91 |
+
torch.randn(5, 4),
|
| 92 |
+
],
|
| 93 |
+
torch.randn(4),
|
| 94 |
+
torch.randn(4, 4),
|
| 95 |
+
"hello",
|
| 96 |
+
)
|
| 97 |
+
dynamic_shapes = {
|
| 98 |
+
"a": [
|
| 99 |
+
(dx, 4),
|
| 100 |
+
(dy, 4),
|
| 101 |
+
],
|
| 102 |
+
"b": (Dim.STATIC,),
|
| 103 |
+
"c": None,
|
| 104 |
+
"d": None,
|
| 105 |
+
}
|
| 106 |
+
out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True)
|
| 107 |
+
```
|
| 108 |
+
would generate the following output:
|
| 109 |
+
```
|
| 110 |
+
{
|
| 111 |
+
'dynamic_shapes': (
|
| 112 |
+
[
|
| 113 |
+
['dx', 4],
|
| 114 |
+
['dx + 1', 4],
|
| 115 |
+
],
|
| 116 |
+
['_DimHint.STATIC'],
|
| 117 |
+
['_DimHint.STATIC', '_DimHint.STATIC'],
|
| 118 |
+
None,
|
| 119 |
+
),
|
| 120 |
+
'dims': {
|
| 121 |
+
'dx': {
|
| 122 |
+
'min': 4,
|
| 123 |
+
'max': 16,
|
| 124 |
+
'derived': ['dx + 1'],
|
| 125 |
+
},
|
| 126 |
+
},
|
| 127 |
+
}
|
| 128 |
+
```
|
| 129 |
+
"""
|
| 130 |
+
dims: Dict[str, Dict[str, Any]] = {}
|
| 131 |
+
|
| 132 |
+
def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def]
|
| 133 |
+
"""
|
| 134 |
+
Helps standardize the dynamic_shapes tree structure we serialize,
|
| 135 |
+
returning lists for each tensor shape, handling tensor-level Nones.
|
| 136 |
+
"""
|
| 137 |
+
if not isinstance(tensor, torch.Tensor):
|
| 138 |
+
return None
|
| 139 |
+
if shape is None:
|
| 140 |
+
return [Dim.STATIC] * len(tensor.shape) # type: ignore[attr-defined]
|
| 141 |
+
|
| 142 |
+
out = []
|
| 143 |
+
if isinstance(shape, dict):
|
| 144 |
+
for i, s in enumerate(tensor.shape):
|
| 145 |
+
out.append(s if shape.get(i) is None else shape.get(i))
|
| 146 |
+
else:
|
| 147 |
+
assert isinstance(shape, (tuple, list))
|
| 148 |
+
for i, s in enumerate(tensor.shape):
|
| 149 |
+
out.append(s if shape[i] is None else shape[i])
|
| 150 |
+
return out
|
| 151 |
+
|
| 152 |
+
def _track_dim_from_dims(
|
| 153 |
+
val: Union[None, int, _DimHint, _Dim]
|
| 154 |
+
) -> Union[None, int, str]:
|
| 155 |
+
"""
|
| 156 |
+
Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
|
| 157 |
+
"""
|
| 158 |
+
if val is None or isinstance(val, int): # non-tensor input or static
|
| 159 |
+
return val
|
| 160 |
+
if isinstance(val, _DimHint): # store enum as string
|
| 161 |
+
return val.__class__.__name__ + "." + val.name
|
| 162 |
+
|
| 163 |
+
assert isinstance(val, _Dim)
|
| 164 |
+
|
| 165 |
+
# track root dim
|
| 166 |
+
root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined]
|
| 167 |
+
if root.__name__ not in dims:
|
| 168 |
+
dims[root.__name__] = {
|
| 169 |
+
"min": root.min,
|
| 170 |
+
"max": root.max,
|
| 171 |
+
"derived": set(),
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
# track derived dims
|
| 175 |
+
if isinstance(val, _DerivedDim):
|
| 176 |
+
dims[root.__name__]["derived"].add(val.__name__)
|
| 177 |
+
|
| 178 |
+
return val.__name__
|
| 179 |
+
|
| 180 |
+
if dynamic_shapes is None:
|
| 181 |
+
return {"dynamic_shapes": None, "dims": {}}
|
| 182 |
+
|
| 183 |
+
# convert to tuple of specs, for each arg/kwarg
|
| 184 |
+
kwargs = kwargs or {}
|
| 185 |
+
if isinstance(dynamic_shapes, dict):
|
| 186 |
+
dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment]
|
| 187 |
+
dynamic_shapes = tuple(dynamic_shapes)
|
| 188 |
+
combined_args = tuple(args) + tuple(kwargs.values())
|
| 189 |
+
|
| 190 |
+
# run same check when we're processing shapes for export - is this too lazy?
|
| 191 |
+
_check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes) # type: ignore[arg-type]
|
| 192 |
+
|
| 193 |
+
tree_shapes = _tree_map_with_path(
|
| 194 |
+
_standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs"
|
| 195 |
+
)
|
| 196 |
+
serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes)
|
| 197 |
+
return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _load_dynamic_shapes(
|
| 201 |
+
spec: Union[DynamicShapesSpec, Dict[str, Any]],
|
| 202 |
+
from_dict: Optional[bool] = False,
|
| 203 |
+
) -> Union[Dict[str, Any], Tuple[Any], List[Any], None]:
|
| 204 |
+
"""
|
| 205 |
+
Utility function for dynamic shapes serialization.
|
| 206 |
+
Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export().
|
| 207 |
+
"""
|
| 208 |
+
import sympy
|
| 209 |
+
|
| 210 |
+
from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence
|
| 211 |
+
|
| 212 |
+
if from_dict:
|
| 213 |
+
if not isinstance(spec, dict):
|
| 214 |
+
raise UserError(
|
| 215 |
+
UserErrorType.INVALID_INPUT,
|
| 216 |
+
f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}",
|
| 217 |
+
)
|
| 218 |
+
if sorted(spec.keys()) != ["dims", "dynamic_shapes"]:
|
| 219 |
+
raise UserError(
|
| 220 |
+
UserErrorType.INVALID_INPUT,
|
| 221 |
+
"With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, "
|
| 222 |
+
f"instead found {spec.keys()}",
|
| 223 |
+
)
|
| 224 |
+
dims = {}
|
| 225 |
+
for k, v in spec["dims"].items():
|
| 226 |
+
if not isinstance(k, str):
|
| 227 |
+
raise UserError(
|
| 228 |
+
UserErrorType.INVALID_INPUT,
|
| 229 |
+
f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}",
|
| 230 |
+
)
|
| 231 |
+
if sorted(v.keys()) != ["derived", "max", "min"]:
|
| 232 |
+
raise UserError(
|
| 233 |
+
UserErrorType.INVALID_INPUT,
|
| 234 |
+
f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, "
|
| 235 |
+
f"instead found {v.keys()}",
|
| 236 |
+
)
|
| 237 |
+
if not isinstance(v["min"], int):
|
| 238 |
+
raise UserError(
|
| 239 |
+
UserErrorType.INVALID_INPUT,
|
| 240 |
+
f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}",
|
| 241 |
+
)
|
| 242 |
+
if not isinstance(v["max"], int) or v["max"] is None:
|
| 243 |
+
raise UserError(
|
| 244 |
+
UserErrorType.INVALID_INPUT,
|
| 245 |
+
f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}",
|
| 246 |
+
)
|
| 247 |
+
if not isinstance(v["derived"], list) or any(
|
| 248 |
+
not isinstance(d, str) for d in v["derived"]
|
| 249 |
+
):
|
| 250 |
+
raise UserError(
|
| 251 |
+
UserErrorType.INVALID_INPUT,
|
| 252 |
+
"Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, "
|
| 253 |
+
f"got {k}: {v['derived']}",
|
| 254 |
+
)
|
| 255 |
+
dims[k] = RootDim(**v)
|
| 256 |
+
dynamic_shapes = spec["dynamic_shapes"]
|
| 257 |
+
else:
|
| 258 |
+
if not isinstance(spec, DynamicShapesSpec):
|
| 259 |
+
raise UserError(
|
| 260 |
+
UserErrorType.INVALID_INPUT,
|
| 261 |
+
f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}",
|
| 262 |
+
)
|
| 263 |
+
dims = spec.dims
|
| 264 |
+
dynamic_shapes = spec.dynamic_shapes
|
| 265 |
+
|
| 266 |
+
if dynamic_shapes is None:
|
| 267 |
+
return None
|
| 268 |
+
|
| 269 |
+
dim_cache = {}
|
| 270 |
+
for name, info in dims.items():
|
| 271 |
+
symbol = sympy.sympify(name)
|
| 272 |
+
if not isinstance(symbol, sympy.Symbol):
|
| 273 |
+
raise UserError(
|
| 274 |
+
UserErrorType.INVALID_INPUT,
|
| 275 |
+
f"Expected `spec['dims']` keys to be symbols, got {name}",
|
| 276 |
+
)
|
| 277 |
+
dim_cache[name] = Dim(name, min=info.min, max=info.max) # cache root dim
|
| 278 |
+
for _expr in info.derived:
|
| 279 |
+
expr = sympy.sympify(_expr)
|
| 280 |
+
if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols:
|
| 281 |
+
raise UserError(
|
| 282 |
+
UserErrorType.INVALID_INPUT,
|
| 283 |
+
f"Expected derived expressions in to have {name} as the only free symbol, got {expr}",
|
| 284 |
+
)
|
| 285 |
+
if not _is_supported_equivalence(expr):
|
| 286 |
+
raise UserError(
|
| 287 |
+
UserErrorType.INVALID_INPUT,
|
| 288 |
+
f"Expected derived expressions to be linear expressions, got {expr}",
|
| 289 |
+
)
|
| 290 |
+
modulus, remainder = sympy.polys.polytools.div(expr, symbol)
|
| 291 |
+
ddim = dim_cache[name]
|
| 292 |
+
if modulus != 1:
|
| 293 |
+
ddim = int(modulus) * ddim
|
| 294 |
+
if remainder != 0:
|
| 295 |
+
ddim = ddim + int(remainder)
|
| 296 |
+
dim_cache[_expr] = ddim # cache derived dims
|
| 297 |
+
|
| 298 |
+
def deserialize_shape(
|
| 299 |
+
val: Union[None, int, str]
|
| 300 |
+
) -> Union[None, int, _Dim, _DimHint]:
|
| 301 |
+
if val is None or isinstance(val, int):
|
| 302 |
+
return val
|
| 303 |
+
elif val == "_DimHint.AUTO":
|
| 304 |
+
return _DimHint.AUTO
|
| 305 |
+
elif val == "_DimHint.STATIC":
|
| 306 |
+
return _DimHint.STATIC
|
| 307 |
+
if not isinstance(val, str):
|
| 308 |
+
raise UserError(
|
| 309 |
+
UserErrorType.INVALID_INPUT,
|
| 310 |
+
"Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, "
|
| 311 |
+
f" or derived expressions, got {val}",
|
| 312 |
+
)
|
| 313 |
+
if val not in dim_cache:
|
| 314 |
+
raise UserError(
|
| 315 |
+
UserErrorType.INVALID_INPUT,
|
| 316 |
+
"Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, "
|
| 317 |
+
f"got {val} which is not in {dims.keys()}",
|
| 318 |
+
)
|
| 319 |
+
return dim_cache[val]
|
| 320 |
+
|
| 321 |
+
return tree_map(deserialize_shape, dynamic_shapes)
|
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (3.21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/closure.cpython-311.pyc
ADDED
|
Binary file (8.08 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc
ADDED
|
Binary file (1.57 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (1.15 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc
ADDED
|
Binary file (1.39 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc
ADDED
|
Binary file (1.06 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/__init__.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r'''
|
| 2 |
+
FX is a toolkit for developers to use to transform ``nn.Module``
|
| 3 |
+
instances. FX consists of three main components: a **symbolic tracer,**
|
| 4 |
+
an **intermediate representation**, and **Python code generation**. A
|
| 5 |
+
demonstration of these components in action:
|
| 6 |
+
|
| 7 |
+
::
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
# Simple module for demonstration
|
| 11 |
+
class MyModule(torch.nn.Module):
|
| 12 |
+
def __init__(self) -> None:
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.param = torch.nn.Parameter(torch.rand(3, 4))
|
| 15 |
+
self.linear = torch.nn.Linear(4, 5)
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
|
| 19 |
+
|
| 20 |
+
module = MyModule()
|
| 21 |
+
|
| 22 |
+
from torch.fx import symbolic_trace
|
| 23 |
+
# Symbolic tracing frontend - captures the semantics of the module
|
| 24 |
+
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
|
| 25 |
+
|
| 26 |
+
# High-level intermediate representation (IR) - Graph representation
|
| 27 |
+
print(symbolic_traced.graph)
|
| 28 |
+
"""
|
| 29 |
+
graph():
|
| 30 |
+
%x : [num_users=1] = placeholder[target=x]
|
| 31 |
+
%param : [num_users=1] = get_attr[target=param]
|
| 32 |
+
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
|
| 33 |
+
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
|
| 34 |
+
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
|
| 35 |
+
return clamp
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
# Code generation - valid Python code
|
| 39 |
+
print(symbolic_traced.code)
|
| 40 |
+
"""
|
| 41 |
+
def forward(self, x):
|
| 42 |
+
param = self.param
|
| 43 |
+
add = x + param; x = param = None
|
| 44 |
+
linear = self.linear(add); add = None
|
| 45 |
+
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
|
| 46 |
+
return clamp
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
The **symbolic tracer** performs "symbolic execution" of the Python
|
| 50 |
+
code. It feeds fake values, called Proxies, through the code. Operations
|
| 51 |
+
on theses Proxies are recorded. More information about symbolic tracing
|
| 52 |
+
can be found in the :func:`symbolic_trace` and :class:`Tracer`
|
| 53 |
+
documentation.
|
| 54 |
+
|
| 55 |
+
The **intermediate representation** is the container for the operations
|
| 56 |
+
that were recorded during symbolic tracing. It consists of a list of
|
| 57 |
+
Nodes that represent function inputs, callsites (to functions, methods,
|
| 58 |
+
or :class:`torch.nn.Module` instances), and return values. More information
|
| 59 |
+
about the IR can be found in the documentation for :class:`Graph`. The
|
| 60 |
+
IR is the format on which transformations are applied.
|
| 61 |
+
|
| 62 |
+
**Python code generation** is what makes FX a Python-to-Python (or
|
| 63 |
+
Module-to-Module) transformation toolkit. For each Graph IR, we can
|
| 64 |
+
create valid Python code matching the Graph's semantics. This
|
| 65 |
+
functionality is wrapped up in :class:`GraphModule`, which is a
|
| 66 |
+
:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a
|
| 67 |
+
``forward`` method generated from the Graph.
|
| 68 |
+
|
| 69 |
+
Taken together, this pipeline of components (symbolic tracing ->
|
| 70 |
+
intermediate representation -> transforms -> Python code generation)
|
| 71 |
+
constitutes the Python-to-Python transformation pipeline of FX. In
|
| 72 |
+
addition, these components can be used separately. For example,
|
| 73 |
+
symbolic tracing can be used in isolation to capture a form of
|
| 74 |
+
the code for analysis (and not transformation) purposes. Code
|
| 75 |
+
generation can be used for programmatically generating models, for
|
| 76 |
+
example from a config file. There are many uses for FX!
|
| 77 |
+
|
| 78 |
+
Several example transformations can be found at the
|
| 79 |
+
`examples <https://github.com/pytorch/examples/tree/master/fx>`__
|
| 80 |
+
repository.
|
| 81 |
+
'''
|
| 82 |
+
|
| 83 |
+
from .graph_module import GraphModule
|
| 84 |
+
from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
|
| 85 |
+
from .graph import Graph, CodeGen
|
| 86 |
+
from .node import Node, map_arg, has_side_effect
|
| 87 |
+
from .proxy import Proxy
|
| 88 |
+
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
|
| 89 |
+
from .subgraph_rewriter import replace_pattern
|
.venv/lib/python3.11/site-packages/torch/fx/__init__.pyi
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.fx._symbolic_trace import (
|
| 2 |
+
symbolic_trace as symbolic_trace,
|
| 3 |
+
Tracer as Tracer,
|
| 4 |
+
wrap as wrap,
|
| 5 |
+
)
|
| 6 |
+
from torch.fx.graph import Graph as Graph
|
| 7 |
+
from torch.fx.graph_module import GraphModule as GraphModule
|
| 8 |
+
from torch.fx.interpreter import Interpreter as Interpreter, Transformer as Transformer
|
| 9 |
+
from torch.fx.node import (
|
| 10 |
+
has_side_effect as has_side_effect,
|
| 11 |
+
map_arg as map_arg,
|
| 12 |
+
Node as Node,
|
| 13 |
+
)
|
| 14 |
+
from torch.fx.proxy import Proxy as Proxy
|
| 15 |
+
from torch.fx.subgraph_rewriter import replace_pattern as replace_pattern
|
.venv/lib/python3.11/site-packages/torch/fx/_compatibility.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Callable, TypeVar
|
| 2 |
+
import textwrap
|
| 3 |
+
|
| 4 |
+
_BACK_COMPAT_OBJECTS : Dict[Any, None] = {}
|
| 5 |
+
_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {}
|
| 6 |
+
|
| 7 |
+
_T = TypeVar("_T")
|
| 8 |
+
|
| 9 |
+
def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]:
|
| 10 |
+
if is_backward_compatible:
|
| 11 |
+
|
| 12 |
+
def mark_back_compat(fn: _T) -> _T:
|
| 13 |
+
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
|
| 14 |
+
docstring += """
|
| 15 |
+
.. note::
|
| 16 |
+
Backwards-compatibility for this API is guaranteed.
|
| 17 |
+
"""
|
| 18 |
+
fn.__doc__ = docstring
|
| 19 |
+
_BACK_COMPAT_OBJECTS.setdefault(fn)
|
| 20 |
+
_MARKED_WITH_COMPATIBILITY.setdefault(fn)
|
| 21 |
+
return fn
|
| 22 |
+
|
| 23 |
+
return mark_back_compat
|
| 24 |
+
else:
|
| 25 |
+
|
| 26 |
+
def mark_not_back_compat(fn: _T) -> _T:
|
| 27 |
+
docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
|
| 28 |
+
docstring += """
|
| 29 |
+
.. warning::
|
| 30 |
+
This API is experimental and is *NOT* backward-compatible.
|
| 31 |
+
"""
|
| 32 |
+
fn.__doc__ = docstring
|
| 33 |
+
_MARKED_WITH_COMPATIBILITY.setdefault(fn)
|
| 34 |
+
return fn
|
| 35 |
+
|
| 36 |
+
return mark_not_back_compat
|
.venv/lib/python3.11/site-packages/torch/fx/_lazy_graph_module.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
|
| 4 |
+
from torch.fx import GraphModule
|
| 5 |
+
from torch.fx.graph_module import (
|
| 6 |
+
_format_import_block,
|
| 7 |
+
reduce_graph_module,
|
| 8 |
+
reduce_package_graph_module,
|
| 9 |
+
)
|
| 10 |
+
from torch.package import PackageExporter, sys_importer
|
| 11 |
+
|
| 12 |
+
from ._compatibility import compatibility
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
_use_lazy_graph_module_flag = False
|
| 16 |
+
_force_skip_lazy_graph_module_flag = False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@compatibility(is_backward_compatible=False)
|
| 20 |
+
@contextmanager
|
| 21 |
+
def _force_skip_lazy_graph_module():
|
| 22 |
+
"""
|
| 23 |
+
Skip using lazy graph module disregarding the setting of _use_lazy_graph_module.
|
| 24 |
+
Use to skip _LazyGraphModule when testing inductor torchscript related backend.
|
| 25 |
+
|
| 26 |
+
torch.jit.script a _LazyGraphModule results in following error:
|
| 27 |
+
https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69
|
| 28 |
+
"""
|
| 29 |
+
try:
|
| 30 |
+
global _force_skip_lazy_graph_module_flag
|
| 31 |
+
prior = _force_skip_lazy_graph_module_flag
|
| 32 |
+
_force_skip_lazy_graph_module_flag = True
|
| 33 |
+
yield
|
| 34 |
+
finally:
|
| 35 |
+
_force_skip_lazy_graph_module_flag = prior
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@compatibility(is_backward_compatible=False)
|
| 39 |
+
@contextmanager
|
| 40 |
+
def _use_lazy_graph_module(should_use: bool):
|
| 41 |
+
try:
|
| 42 |
+
global _use_lazy_graph_module_flag
|
| 43 |
+
prior = _use_lazy_graph_module_flag
|
| 44 |
+
_use_lazy_graph_module_flag = (
|
| 45 |
+
should_use and not _force_skip_lazy_graph_module_flag
|
| 46 |
+
)
|
| 47 |
+
yield
|
| 48 |
+
finally:
|
| 49 |
+
_use_lazy_graph_module_flag = prior
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@compatibility(is_backward_compatible=False)
|
| 53 |
+
def _get_graph_module_cls():
|
| 54 |
+
return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _make_graph_module(*args, graph_module_cls=None, **kwargs):
|
| 58 |
+
if graph_module_cls is None:
|
| 59 |
+
graph_module_cls = _get_graph_module_cls()
|
| 60 |
+
|
| 61 |
+
return graph_module_cls(*args, **kwargs)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@compatibility(is_backward_compatible=False)
|
| 65 |
+
class _LazyGraphModule(GraphModule):
|
| 66 |
+
"""
|
| 67 |
+
The main difference between _LazyGraphModule and GraphModule is how recompile happens.
|
| 68 |
+
GraphModule will do a 'recompile' call to generate python code and the forward method when it's
|
| 69 |
+
constructed. Later on if the graph get updated, recompile method can be called again to refresh
|
| 70 |
+
the saved python code and forward method.
|
| 71 |
+
|
| 72 |
+
However in some cases especially in inductor, the recompilation can be a waste since we never
|
| 73 |
+
check the python code for the graph module or call its forward method. A few more concreate
|
| 74 |
+
examples regarding pattern matching fx passes in inductor:
|
| 75 |
+
1. some passes will update the graph to be compiled and then call recompile on the GraphModule.
|
| 76 |
+
2. some passes will trace small pattern function to search it in the graph being compiled and
|
| 77 |
+
replace the match with the traced graph of a replacement function. The pattern graph and
|
| 78 |
+
replacement graph are quite small but there are large amount of them. Doing GraphModule.recompile
|
| 79 |
+
for them in GraphModule.__init__ is also a waste of time.
|
| 80 |
+
|
| 81 |
+
However simply skip calling GraphModule.recompile in these scenarios is also dangeruous.
|
| 82 |
+
People may want to check the python code or call the GraphModule's forward method for debugging purposes.
|
| 83 |
+
|
| 84 |
+
The way _LazyGraphModule solves it is, we override the recompile method to just mark the
|
| 85 |
+
need for recompilation but does not do the actual recompilation. Later on if people really
|
| 86 |
+
access the compiled python code or call the GraphModule's forward method, we do the real
|
| 87 |
+
recompilation.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
@classmethod
|
| 91 |
+
def from_graphmodule(cls, gm: GraphModule):
|
| 92 |
+
if isinstance(gm, _LazyGraphModule):
|
| 93 |
+
return gm
|
| 94 |
+
else:
|
| 95 |
+
return _LazyGraphModule(gm, gm.graph)
|
| 96 |
+
|
| 97 |
+
@staticmethod
|
| 98 |
+
def force_recompile(gm):
|
| 99 |
+
"""
|
| 100 |
+
Sometimes we need force a recompile as a workaround
|
| 101 |
+
- we want to do the real recompilation before symbolic_trace to avoid error:
|
| 102 |
+
https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
|
| 103 |
+
"""
|
| 104 |
+
if isinstance(gm, _LazyGraphModule):
|
| 105 |
+
gm.real_recompile()
|
| 106 |
+
|
| 107 |
+
def real_recompile(self):
|
| 108 |
+
if self._needs_recompile():
|
| 109 |
+
self._real_recompile()
|
| 110 |
+
|
| 111 |
+
@classmethod
|
| 112 |
+
def _needs_recompile(cls):
|
| 113 |
+
return cls.forward is cls._lazy_forward
|
| 114 |
+
|
| 115 |
+
def _lazy_forward(self, *args, **kwargs):
|
| 116 |
+
# Call self.real_recompile() rather than self._real_recompile() here.
|
| 117 |
+
# The _lazy_forward method may be saved and call repeatedly.
|
| 118 |
+
# Calling self.real_recompile can make sure we skip recompilation if
|
| 119 |
+
# we have already done so.
|
| 120 |
+
self.real_recompile()
|
| 121 |
+
assert not self._needs_recompile()
|
| 122 |
+
|
| 123 |
+
# call `__call__` rather than 'forward' since recompilation may
|
| 124 |
+
# install a wrapper for `__call__` to provide a customized error
|
| 125 |
+
# message.
|
| 126 |
+
return self(*args, **kwargs)
|
| 127 |
+
|
| 128 |
+
forward = _lazy_forward
|
| 129 |
+
|
| 130 |
+
# TODO: we shold handle __reduce_deploy__ the same way as __reduce_package__,
|
| 131 |
+
# or __reduce__ by calling _real_recompile. But I don't find a good way
|
| 132 |
+
# to test __reduce_deploy__ out. Also it's very unlikely that LazyGraphModule
|
| 133 |
+
# will be used in torch::deploy. So it's skipped for now.
|
| 134 |
+
|
| 135 |
+
def __reduce_package__(self, exporter: PackageExporter):
|
| 136 |
+
"""
|
| 137 |
+
Follow GraphModule.__reduce__ but call 'self._real_recompile' rather
|
| 138 |
+
than 'self.recompile' since for a _LazyGraphModule, self.recompile just
|
| 139 |
+
mark the need of recompilation and does not return the PythonCode object.
|
| 140 |
+
"""
|
| 141 |
+
python_code = self._real_recompile()
|
| 142 |
+
dict_without_graph = self.__dict__.copy()
|
| 143 |
+
dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__
|
| 144 |
+
del dict_without_graph["_graph"]
|
| 145 |
+
|
| 146 |
+
generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
|
| 147 |
+
import_block = _format_import_block(python_code.globals, exporter.importer)
|
| 148 |
+
module_code = import_block + self.code
|
| 149 |
+
exporter.save_source_string(generated_module_name, module_code)
|
| 150 |
+
return (
|
| 151 |
+
reduce_package_graph_module,
|
| 152 |
+
(dict_without_graph, generated_module_name),
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
def __reduce__(self):
|
| 156 |
+
"""
|
| 157 |
+
Follow GraphModule.__reduce__ but call 'self._real_recompile' rather
|
| 158 |
+
than 'self.recompile' since for a _LazyGraphModule, self.recompile just
|
| 159 |
+
mark the need of recompilation and does not return the PythonCode object.
|
| 160 |
+
"""
|
| 161 |
+
python_code = self._real_recompile()
|
| 162 |
+
dict_without_graph = self.__dict__.copy()
|
| 163 |
+
import_block = _format_import_block(python_code.globals, sys_importer)
|
| 164 |
+
del dict_without_graph["_graph"]
|
| 165 |
+
return (reduce_graph_module, (dict_without_graph, import_block))
|
| 166 |
+
|
| 167 |
+
def _real_recompile(self):
|
| 168 |
+
return super().recompile()
|
| 169 |
+
|
| 170 |
+
@classmethod
|
| 171 |
+
def recompile(cls):
|
| 172 |
+
cls.forward = cls._lazy_forward
|
| 173 |
+
|
| 174 |
+
@property
|
| 175 |
+
def code(self) -> str:
|
| 176 |
+
self.real_recompile()
|
| 177 |
+
return super().code
|
| 178 |
+
|
| 179 |
+
def __str__(self) -> str:
|
| 180 |
+
"""
|
| 181 |
+
str(GraphModule) will access the _code attribute. Make sure recompile
|
| 182 |
+
happens so _code attribute is available.
|
| 183 |
+
"""
|
| 184 |
+
self.real_recompile()
|
| 185 |
+
return super().__str__()
|
.venv/lib/python3.11/site-packages/torch/fx/_pytree.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from collections import namedtuple
|
| 3 |
+
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type
|
| 4 |
+
|
| 5 |
+
import torch.return_types
|
| 6 |
+
from torch.utils._pytree import PyTree, TreeSpec
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
|
| 10 |
+
FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]
|
| 11 |
+
|
| 12 |
+
SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {}
|
| 13 |
+
SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def register_pytree_flatten_spec(
|
| 17 |
+
cls: Type[Any],
|
| 18 |
+
flatten_fn_spec: FlattenFuncSpec,
|
| 19 |
+
flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
|
| 20 |
+
) -> None:
|
| 21 |
+
SUPPORTED_NODES[cls] = flatten_fn_spec
|
| 22 |
+
SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def tree_flatten_spec(
|
| 26 |
+
pytree: PyTree,
|
| 27 |
+
spec: TreeSpec,
|
| 28 |
+
exact_structural_match=False,
|
| 29 |
+
) -> List[Any]:
|
| 30 |
+
if spec.is_leaf():
|
| 31 |
+
return [pytree]
|
| 32 |
+
if spec.type not in SUPPORTED_NODES:
|
| 33 |
+
raise RuntimeError(
|
| 34 |
+
f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with "
|
| 35 |
+
"torch.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make "
|
| 36 |
+
"sure that any custom pytrees have been registered before loading it.",
|
| 37 |
+
)
|
| 38 |
+
flatten_fn_spec = SUPPORTED_NODES[spec.type]
|
| 39 |
+
child_pytrees = flatten_fn_spec(pytree, spec)
|
| 40 |
+
if exact_structural_match:
|
| 41 |
+
flatten_fn_exact_match_spec = SUPPORTED_NODES_EXACT_MATCH[spec.type]
|
| 42 |
+
if flatten_fn_exact_match_spec and not flatten_fn_exact_match_spec(
|
| 43 |
+
pytree,
|
| 44 |
+
spec,
|
| 45 |
+
):
|
| 46 |
+
raise RuntimeError(f"Cannot flatten pytree {pytree}, given spec: {spec}")
|
| 47 |
+
result = []
|
| 48 |
+
for child, child_spec in zip(child_pytrees, spec.children_specs):
|
| 49 |
+
flat = tree_flatten_spec(child, child_spec, exact_structural_match)
|
| 50 |
+
result += flat
|
| 51 |
+
return result
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]:
|
| 55 |
+
return [d[k] for k in spec.context]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]:
|
| 59 |
+
return [d[i] for i in range(spec.num_children)]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]:
|
| 63 |
+
return [d[i] for i in range(spec.num_children)]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]:
|
| 67 |
+
return [d[i] for i in range(spec.num_children)]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool:
|
| 71 |
+
return len(d) == spec.num_children
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool:
|
| 75 |
+
return len(d) == spec.num_children
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool:
|
| 79 |
+
return len(d) == spec.num_children
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool:
|
| 83 |
+
return len(d) == spec.num_children
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match)
|
| 87 |
+
register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match)
|
| 88 |
+
register_pytree_flatten_spec(
|
| 89 |
+
tuple,
|
| 90 |
+
_tuple_flatten_spec,
|
| 91 |
+
_tuple_flatten_spec_exact_match,
|
| 92 |
+
)
|
| 93 |
+
for return_type in torch.return_types.all_return_types:
|
| 94 |
+
register_pytree_flatten_spec(
|
| 95 |
+
return_type,
|
| 96 |
+
_tuple_flatten_spec,
|
| 97 |
+
_tuple_flatten_spec_exact_match,
|
| 98 |
+
)
|
| 99 |
+
register_pytree_flatten_spec(
|
| 100 |
+
namedtuple, # type: ignore[arg-type]
|
| 101 |
+
_namedtuple_flatten_spec,
|
| 102 |
+
_namedtuple_flatten_spec_exact_match,
|
| 103 |
+
)
|
.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py
ADDED
|
@@ -0,0 +1,1290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import builtins
|
| 3 |
+
import copy
|
| 4 |
+
import contextlib
|
| 5 |
+
import functools
|
| 6 |
+
import inspect
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
import warnings
|
| 10 |
+
import collections
|
| 11 |
+
from itertools import chain
|
| 12 |
+
from types import CodeType, FunctionType, ModuleType
|
| 13 |
+
from typing import (
|
| 14 |
+
Any,
|
| 15 |
+
Callable,
|
| 16 |
+
Dict,
|
| 17 |
+
List,
|
| 18 |
+
NamedTuple,
|
| 19 |
+
Optional,
|
| 20 |
+
Set,
|
| 21 |
+
Tuple,
|
| 22 |
+
Type,
|
| 23 |
+
Union,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.utils._pytree as pytree
|
| 28 |
+
from torch._C import ScriptObject # type: ignore[attr-defined]
|
| 29 |
+
from torch._library.fake_class_registry import FakeScriptObject
|
| 30 |
+
|
| 31 |
+
from ._compatibility import compatibility
|
| 32 |
+
from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
|
| 33 |
+
from .graph_module import GraphModule
|
| 34 |
+
from ._lazy_graph_module import _make_graph_module
|
| 35 |
+
from .node import Argument, base_types, map_aggregate
|
| 36 |
+
from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager
|
| 37 |
+
|
| 38 |
+
HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
|
| 39 |
+
|
| 40 |
+
# These need to run in global scope to handle nested calls correctly
|
| 41 |
+
_orig_module_call: Callable = torch.nn.Module.__call__
|
| 42 |
+
_orig_module_getattr: Callable = torch.nn.Module.__getattr__
|
| 43 |
+
|
| 44 |
+
_proxyable_classes: Dict[Type, None] = {}
|
| 45 |
+
|
| 46 |
+
_is_fx_tracing_flag = False
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def is_fx_tracing():
|
| 50 |
+
return _is_fx_tracing_flag
|
| 51 |
+
|
| 52 |
+
@compatibility(is_backward_compatible=True)
|
| 53 |
+
class ProxyableClassMeta(type):
|
| 54 |
+
"""
|
| 55 |
+
ProxyableClassMeta allows you to make construction of a given Python class
|
| 56 |
+
symbolically traceable. For example::
|
| 57 |
+
|
| 58 |
+
import torch
|
| 59 |
+
import torch.fx
|
| 60 |
+
|
| 61 |
+
class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
|
| 62 |
+
def __init__(self, left, right):
|
| 63 |
+
self.left, self.right = left, right
|
| 64 |
+
|
| 65 |
+
def add(self, other):
|
| 66 |
+
l = self.left + other.left
|
| 67 |
+
r = self.right + other.right
|
| 68 |
+
return TensorPair(l, r)
|
| 69 |
+
|
| 70 |
+
def mul(self, other):
|
| 71 |
+
l = self.left * other.left
|
| 72 |
+
r = self.right * other.right
|
| 73 |
+
return TensorPair(l, r)
|
| 74 |
+
|
| 75 |
+
def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
|
| 76 |
+
s = x.add(TensorPair(y, y))
|
| 77 |
+
return s.mul(x)
|
| 78 |
+
|
| 79 |
+
x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
|
| 80 |
+
y = torch.randn(5, 3)
|
| 81 |
+
ref_out = use_tensor_pair_ctor(x, y)
|
| 82 |
+
|
| 83 |
+
traced = torch.fx.symbolic_trace(use_tensor_pair_ctor)
|
| 84 |
+
print(traced.code)
|
| 85 |
+
'''
|
| 86 |
+
def forward(self, x : __main___TensorPair, y : torch.Tensor):
|
| 87 |
+
tensor_pair = __main___TensorPair(y, y); y = None
|
| 88 |
+
add = x.add(tensor_pair); tensor_pair = None
|
| 89 |
+
mul = add.mul(x); add = x = None
|
| 90 |
+
return mul
|
| 91 |
+
'''
|
| 92 |
+
|
| 93 |
+
From this example, we can see that construction of a class (``TensorPair``)
|
| 94 |
+
defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic
|
| 95 |
+
tracing.
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def __init__(cls, name, bases, attrs):
|
| 99 |
+
_proxyable_classes.setdefault(cls)
|
| 100 |
+
super().__init__(name, bases, attrs)
|
| 101 |
+
|
| 102 |
+
def __call__(cls, *args, **kwargs):
|
| 103 |
+
instance = cls.__new__(cls) # type: ignore[call-overload]
|
| 104 |
+
|
| 105 |
+
if not is_fx_tracing():
|
| 106 |
+
cls.__init__(instance, *args, **kwargs) # type: ignore[misc]
|
| 107 |
+
return instance
|
| 108 |
+
|
| 109 |
+
found_proxies = []
|
| 110 |
+
|
| 111 |
+
def check_proxy(a):
|
| 112 |
+
if isinstance(a, Proxy):
|
| 113 |
+
found_proxies.append(a)
|
| 114 |
+
|
| 115 |
+
map_aggregate(args, check_proxy)
|
| 116 |
+
map_aggregate(kwargs, check_proxy)
|
| 117 |
+
|
| 118 |
+
if len(found_proxies) != 0:
|
| 119 |
+
tracer = found_proxies[0].tracer
|
| 120 |
+
return tracer.create_proxy("call_function", cls, args, kwargs)
|
| 121 |
+
else:
|
| 122 |
+
cls.__init__(instance, *args, **kwargs) # type: ignore[misc]
|
| 123 |
+
return instance
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
|
| 127 |
+
co = fn.__code__
|
| 128 |
+
co_flags = co.co_flags & ~HAS_VARSTUFF
|
| 129 |
+
co_args: tuple
|
| 130 |
+
if hasattr(co, "co_qualname"):
|
| 131 |
+
# Python-3.11+ code signature
|
| 132 |
+
co_args = (
|
| 133 |
+
nargs,
|
| 134 |
+
0,
|
| 135 |
+
0,
|
| 136 |
+
co.co_nlocals,
|
| 137 |
+
co.co_stacksize,
|
| 138 |
+
co_flags,
|
| 139 |
+
co.co_code,
|
| 140 |
+
co.co_consts,
|
| 141 |
+
co.co_names,
|
| 142 |
+
co.co_varnames,
|
| 143 |
+
co.co_filename,
|
| 144 |
+
co.co_name,
|
| 145 |
+
co.co_qualname, # type: ignore[attr-defined]
|
| 146 |
+
co.co_firstlineno,
|
| 147 |
+
co.co_lnotab,
|
| 148 |
+
co.co_exceptiontable, # type: ignore[attr-defined]
|
| 149 |
+
co.co_freevars,
|
| 150 |
+
co.co_cellvars,
|
| 151 |
+
)
|
| 152 |
+
elif hasattr(co, "co_posonlyargcount"):
|
| 153 |
+
co_args = (
|
| 154 |
+
nargs,
|
| 155 |
+
0,
|
| 156 |
+
0,
|
| 157 |
+
co.co_nlocals,
|
| 158 |
+
co.co_stacksize,
|
| 159 |
+
co_flags,
|
| 160 |
+
co.co_code,
|
| 161 |
+
co.co_consts,
|
| 162 |
+
co.co_names,
|
| 163 |
+
co.co_varnames,
|
| 164 |
+
co.co_filename,
|
| 165 |
+
co.co_name,
|
| 166 |
+
co.co_firstlineno,
|
| 167 |
+
co.co_lnotab,
|
| 168 |
+
co.co_freevars,
|
| 169 |
+
co.co_cellvars,
|
| 170 |
+
)
|
| 171 |
+
else:
|
| 172 |
+
co_args = (
|
| 173 |
+
nargs,
|
| 174 |
+
0,
|
| 175 |
+
co.co_nlocals,
|
| 176 |
+
co.co_stacksize,
|
| 177 |
+
co_flags,
|
| 178 |
+
co.co_code,
|
| 179 |
+
co.co_consts,
|
| 180 |
+
co.co_names,
|
| 181 |
+
co.co_varnames,
|
| 182 |
+
co.co_filename,
|
| 183 |
+
co.co_name,
|
| 184 |
+
co.co_firstlineno,
|
| 185 |
+
co.co_lnotab,
|
| 186 |
+
co.co_freevars,
|
| 187 |
+
co.co_cellvars,
|
| 188 |
+
)
|
| 189 |
+
new_code = CodeType(*co_args) # type: ignore[arg-type]
|
| 190 |
+
return FunctionType(
|
| 191 |
+
new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# we need to insert placeholder nodes for *args and **kwargs
|
| 195 |
+
# we can't call this function normally, otherwise it would try to unpack them
|
| 196 |
+
# instead, let's make python think that args and kwargs are normal variables
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@compatibility(is_backward_compatible=False)
|
| 200 |
+
class PHBase:
|
| 201 |
+
"""
|
| 202 |
+
Object representing an input placeholder to `concrete_args`
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
def __repr__(self):
|
| 206 |
+
return "PH"
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
PH = PHBase()
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@compatibility(is_backward_compatible=False)
|
| 213 |
+
class PHWithMeta(PHBase):
|
| 214 |
+
"""
|
| 215 |
+
Object representing an input placeholder to `concrete_args`
|
| 216 |
+
"""
|
| 217 |
+
def __init__(self, ph_key: Optional[str] = None):
|
| 218 |
+
super().__init__()
|
| 219 |
+
|
| 220 |
+
# Provide a hey for user to identify placeholder node during analysis
|
| 221 |
+
self.ph_key = ph_key
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def _transfer_attrs(fr, to):
|
| 225 |
+
for attr_name in dir(fr):
|
| 226 |
+
attr_val = getattr(fr, attr_name)
|
| 227 |
+
if (
|
| 228 |
+
not callable(attr_val)
|
| 229 |
+
and not attr_name.startswith("__")
|
| 230 |
+
and not hasattr(to, attr_name)
|
| 231 |
+
):
|
| 232 |
+
setattr(to, attr_name, attr_val)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@compatibility(is_backward_compatible=True)
|
| 236 |
+
class Tracer(TracerBase):
|
| 237 |
+
# Reference: https://github.com/pytorch/pytorch/issues/54354
|
| 238 |
+
# The first line of this docstring overrides the one Sphinx generates for the
|
| 239 |
+
# documentation. We need it so that Sphinx doesn't leak `math`s path from the
|
| 240 |
+
# build environment (e.g. `<module 'math' from '/leaked/path').
|
| 241 |
+
|
| 242 |
+
"""Tracer(autowrap_modules=(math,), autowrap_functions=())
|
| 243 |
+
|
| 244 |
+
``Tracer`` is the class that implements the symbolic tracing functionality
|
| 245 |
+
of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
|
| 246 |
+
to ``Tracer().trace(m)``.
|
| 247 |
+
|
| 248 |
+
Tracer can be subclassed to override various behaviors of the tracing
|
| 249 |
+
process. The different behaviors that can be overridden are described
|
| 250 |
+
in the docstrings of the methods on this class.
|
| 251 |
+
"""
|
| 252 |
+
|
| 253 |
+
# Not checking BC on this API because the default value for `autowrap_modules`
|
| 254 |
+
# includes the local filepath to the `math` module, which would jitter
|
| 255 |
+
# across machines.
|
| 256 |
+
@compatibility(is_backward_compatible=True)
|
| 257 |
+
def __init__(
|
| 258 |
+
self,
|
| 259 |
+
autowrap_modules: Tuple[ModuleType] = (math,),
|
| 260 |
+
autowrap_functions: Tuple[Callable, ...] = (),
|
| 261 |
+
param_shapes_constant: bool = False,
|
| 262 |
+
) -> None:
|
| 263 |
+
# This method's signature is overridden by the first line of this class'
|
| 264 |
+
# docstring. If this method's signature is modified, the signature that
|
| 265 |
+
# overrides it also should be modified accordingly.
|
| 266 |
+
|
| 267 |
+
"""
|
| 268 |
+
Construct a Tracer object.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
|
| 272 |
+
autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
|
| 273 |
+
Python modules whose functions should be wrapped automatically
|
| 274 |
+
without needing to use fx.wrap(). Backward-compatibility for
|
| 275 |
+
this parameter is guaranteed.
|
| 276 |
+
|
| 277 |
+
autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
|
| 278 |
+
Python functions that should be wrapped automatically without
|
| 279 |
+
needing to use fx.wrap(). Backward compatibility for this
|
| 280 |
+
parameter is guaranteed.
|
| 281 |
+
|
| 282 |
+
param_shapes_constant (bool): When this flag is set, calls to shape,
|
| 283 |
+
size and a few other shape like attributes of a module's parameter
|
| 284 |
+
will be evaluated directly, rather than returning a new Proxy value
|
| 285 |
+
for an attribute access. Backward compatibility for this parameter
|
| 286 |
+
is guaranteed.
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
super().__init__()
|
| 290 |
+
|
| 291 |
+
# Functions we will eagerly wrap when we see them while tracing
|
| 292 |
+
# this captures both `math.sqrt()` and `from math import sqrt` automatically
|
| 293 |
+
self._autowrap_function_ids: Set[int] = {
|
| 294 |
+
id(value)
|
| 295 |
+
for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
|
| 296 |
+
if not name.startswith("_") and callable(value)
|
| 297 |
+
}
|
| 298 |
+
self._autowrap_function_ids.update({id(f) for f in autowrap_functions})
|
| 299 |
+
|
| 300 |
+
# Python modules to apply autowrap to at the start, in addition to
|
| 301 |
+
# modules we see while tracing
|
| 302 |
+
self._autowrap_search: List[ModuleType] = list(autowrap_modules)
|
| 303 |
+
self.param_shapes_constant = param_shapes_constant
|
| 304 |
+
|
| 305 |
+
self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
|
| 306 |
+
self.root_module_name: str = ""
|
| 307 |
+
# Maps the containing module's name to the operator name
|
| 308 |
+
self.scope = Scope("", None)
|
| 309 |
+
# Records the module call stack
|
| 310 |
+
self.module_stack = collections.OrderedDict()
|
| 311 |
+
# Mapping of node name to module scope
|
| 312 |
+
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
|
| 313 |
+
|
| 314 |
+
_qualname_counter: Dict[str, int] = collections.defaultdict(int)
|
| 315 |
+
|
| 316 |
+
@compatibility(is_backward_compatible=True)
|
| 317 |
+
def get_fresh_qualname(self, prefix: str) -> str:
|
| 318 |
+
"""
|
| 319 |
+
Gets a fresh name for a prefix and returns it. This function ensures
|
| 320 |
+
that it will not clash with an existing attribute on the graph.
|
| 321 |
+
"""
|
| 322 |
+
# The idea here is that if the module doesn't have this prefix at all we
|
| 323 |
+
# should reset the counter to start from the beginning
|
| 324 |
+
# It's a ... little bit hacky (doesn't cover all cases) but the precise
|
| 325 |
+
# naming of the prefixes isn't a correctness issue, just a niceness
|
| 326 |
+
# issue
|
| 327 |
+
qualname = f"{prefix}0"
|
| 328 |
+
if not hasattr(self.root, qualname):
|
| 329 |
+
self._qualname_counter[prefix] = 0
|
| 330 |
+
return qualname
|
| 331 |
+
|
| 332 |
+
i = self._qualname_counter[prefix]
|
| 333 |
+
while True:
|
| 334 |
+
qualname = f"{prefix}{i}"
|
| 335 |
+
i += 1
|
| 336 |
+
if not hasattr(self.root, qualname):
|
| 337 |
+
break
|
| 338 |
+
self._qualname_counter[prefix] = i
|
| 339 |
+
|
| 340 |
+
return qualname
|
| 341 |
+
|
| 342 |
+
@compatibility(is_backward_compatible=True)
|
| 343 |
+
def create_arg(self, a: Any) -> "Argument":
|
| 344 |
+
"""
|
| 345 |
+
A method to specify the behavior of tracing when preparing values to
|
| 346 |
+
be used as arguments to nodes in the ``Graph``.
|
| 347 |
+
|
| 348 |
+
By default, the behavior includes:
|
| 349 |
+
|
| 350 |
+
#. Iterate through collection types (e.g. tuple, list, dict) and recursively
|
| 351 |
+
call ``create_args`` on the elements.
|
| 352 |
+
#. Given a Proxy object, return a reference to the underlying IR ``Node``
|
| 353 |
+
#. Given a non-Proxy Tensor object, emit IR for various cases:
|
| 354 |
+
|
| 355 |
+
* For a Parameter, emit a ``get_attr`` node referring to that Parameter
|
| 356 |
+
* For a non-Parameter Tensor, store the Tensor away in a special
|
| 357 |
+
attribute referring to that attribute.
|
| 358 |
+
|
| 359 |
+
This method can be overridden to support more types.
|
| 360 |
+
|
| 361 |
+
Args:
|
| 362 |
+
|
| 363 |
+
a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
Returns:
|
| 367 |
+
|
| 368 |
+
The value ``a`` converted into the appropriate ``Argument``
|
| 369 |
+
"""
|
| 370 |
+
# The base tracer is used to construct Graphs when there is no associated
|
| 371 |
+
# module hierarchy, so it can never create parameter references.
|
| 372 |
+
# The default tracer adds the ability to refer to parameters when
|
| 373 |
+
# tracing modules.
|
| 374 |
+
if isinstance(a, torch.nn.Parameter):
|
| 375 |
+
for n, p in self.root.named_parameters():
|
| 376 |
+
if a is p:
|
| 377 |
+
return self.create_node("get_attr", n, (), {})
|
| 378 |
+
raise NameError("parameter is not a member of this module")
|
| 379 |
+
elif isinstance(a, torch.Tensor):
|
| 380 |
+
for n_, p_ in self.root.named_buffers():
|
| 381 |
+
if a is p_:
|
| 382 |
+
return self.create_node("get_attr", n_, (), {})
|
| 383 |
+
elif isinstance(a, torch.nn.Module):
|
| 384 |
+
for n_, p_ in self.root.named_modules():
|
| 385 |
+
if a is p_:
|
| 386 |
+
return self.create_node("get_attr", n_, (), {})
|
| 387 |
+
# For NamedTuple instances that appear literally as args, we emit
|
| 388 |
+
# a node to construct the NamedTuple and use that Node as the argument.
|
| 389 |
+
if isinstance(a, tuple) and hasattr(a, "_fields"):
|
| 390 |
+
args = tuple(self.create_arg(elem) for elem in a)
|
| 391 |
+
return self.create_node("call_function", a.__class__, args, {})
|
| 392 |
+
|
| 393 |
+
# Tensors do not have a reliable string repr() from which they can be
|
| 394 |
+
# constructed (and we probably don't want to rely on that, either), so
|
| 395 |
+
# for any constant Tensor values we encounter, first search for if they
|
| 396 |
+
# are an attribute of some module in the module hierarchy. If so, emit
|
| 397 |
+
# a get_attr to retrieve that tensor. Otherwise, we'll store away the
|
| 398 |
+
# tensor value into a special attribute on the Module s.t. we can
|
| 399 |
+
# retrieve it with a get_attr.
|
| 400 |
+
if isinstance(a, (torch.Tensor, ScriptObject, FakeScriptObject)):
|
| 401 |
+
qualname: Optional[str] = self.tensor_attrs.get(a)
|
| 402 |
+
|
| 403 |
+
# Tensor was not found in the Module hierarchy, stow it away in a
|
| 404 |
+
# special attribute and set the qualname to refer to that
|
| 405 |
+
if not qualname:
|
| 406 |
+
base_name = "_tensor_constant" if isinstance(a, torch.Tensor) else "_torchbind_obj"
|
| 407 |
+
qualname = self.get_fresh_qualname(base_name)
|
| 408 |
+
assert isinstance(qualname, str)
|
| 409 |
+
self.tensor_attrs[a] = qualname
|
| 410 |
+
setattr(self.root, qualname, a)
|
| 411 |
+
|
| 412 |
+
return self.create_node("get_attr", qualname, (), {})
|
| 413 |
+
|
| 414 |
+
if type(a) in _proxyable_classes:
|
| 415 |
+
# This is an instance of a proxyable class for which we did not
|
| 416 |
+
# witness its construction. Intern this as a constant attribute
|
| 417 |
+
|
| 418 |
+
# TODO: binary search
|
| 419 |
+
qualname = self.get_fresh_qualname(f"_{a.__class__.__name__}_constant_")
|
| 420 |
+
assert isinstance(qualname, str)
|
| 421 |
+
setattr(self.root, qualname, a)
|
| 422 |
+
|
| 423 |
+
return self.create_node("get_attr", qualname, (), {})
|
| 424 |
+
|
| 425 |
+
return super().create_arg(a)
|
| 426 |
+
|
| 427 |
+
@compatibility(is_backward_compatible=True)
|
| 428 |
+
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
|
| 429 |
+
"""
|
| 430 |
+
A method to specify whether a given ``nn.Module`` is a "leaf" module.
|
| 431 |
+
|
| 432 |
+
Leaf modules are the atomic units that appear in
|
| 433 |
+
the IR, referenced by ``call_module`` calls. By default,
|
| 434 |
+
Modules in the PyTorch standard library namespace (torch.nn)
|
| 435 |
+
are leaf modules. All other modules are traced through and
|
| 436 |
+
their constituent ops are recorded, unless specified otherwise
|
| 437 |
+
via this parameter.
|
| 438 |
+
|
| 439 |
+
Args:
|
| 440 |
+
|
| 441 |
+
m (Module): The module being queried about
|
| 442 |
+
module_qualified_name (str): The path to root of this module. For example,
|
| 443 |
+
if you have a module hierarchy where submodule ``foo`` contains
|
| 444 |
+
submodule ``bar``, which contains submodule ``baz``, that module will
|
| 445 |
+
appear with the qualified name ``foo.bar.baz`` here.
|
| 446 |
+
"""
|
| 447 |
+
return (
|
| 448 |
+
(m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn"))
|
| 449 |
+
and not isinstance(m, torch.nn.Sequential)
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
@compatibility(is_backward_compatible=True)
|
| 453 |
+
def path_of_module(self, mod: torch.nn.Module) -> str:
|
| 454 |
+
"""
|
| 455 |
+
Helper method to find the qualified name of ``mod`` in the Module hierarchy
|
| 456 |
+
of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
|
| 457 |
+
a submodule named ``bar``, passing ``bar`` into this function will return
|
| 458 |
+
the string "foo.bar".
|
| 459 |
+
|
| 460 |
+
Args:
|
| 461 |
+
|
| 462 |
+
mod (str): The ``Module`` to retrieve the qualified name for.
|
| 463 |
+
"""
|
| 464 |
+
# Prefer the O(1) algorithm
|
| 465 |
+
if self.submodule_paths:
|
| 466 |
+
path = self.submodule_paths.get(mod)
|
| 467 |
+
if path is None:
|
| 468 |
+
raise NameError("module is not installed as a submodule")
|
| 469 |
+
assert isinstance(path, str)
|
| 470 |
+
return path
|
| 471 |
+
# O(N^2) fallback in the case that we didn't store the submodule
|
| 472 |
+
# paths.
|
| 473 |
+
else:
|
| 474 |
+
for n, p in self.root.named_modules():
|
| 475 |
+
if mod is p:
|
| 476 |
+
return n
|
| 477 |
+
raise NameError("module is not installed as a submodule")
|
| 478 |
+
|
| 479 |
+
@compatibility(is_backward_compatible=True)
|
| 480 |
+
def call_module(
|
| 481 |
+
self,
|
| 482 |
+
m: torch.nn.Module,
|
| 483 |
+
forward: Callable[..., Any],
|
| 484 |
+
args: Tuple[Any, ...],
|
| 485 |
+
kwargs: Dict[str, Any],
|
| 486 |
+
) -> Any:
|
| 487 |
+
"""
|
| 488 |
+
Method that specifies the behavior of this ``Tracer`` when it encounters
|
| 489 |
+
a call to an ``nn.Module`` instance.
|
| 490 |
+
|
| 491 |
+
By default, the behavior is to check if the called module is a leaf module
|
| 492 |
+
via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
|
| 493 |
+
``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
|
| 494 |
+
the operations in its ``forward`` function.
|
| 495 |
+
|
| 496 |
+
This method can be overridden to--for example--create nested traced
|
| 497 |
+
GraphModules, or any other behavior you would want while tracing across
|
| 498 |
+
``Module`` boundaries.
|
| 499 |
+
|
| 500 |
+
Args:
|
| 501 |
+
|
| 502 |
+
m (Module): The module for which a call is being emitted
|
| 503 |
+
forward (Callable): The forward() method of the ``Module`` to be invoked
|
| 504 |
+
args (Tuple): args of the module callsite
|
| 505 |
+
kwargs (Dict): kwargs of the module callsite
|
| 506 |
+
|
| 507 |
+
Return:
|
| 508 |
+
|
| 509 |
+
The return value from the Module call. In the case that a ``call_module``
|
| 510 |
+
node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
|
| 511 |
+
value was returned from the ``Module`` invocation.
|
| 512 |
+
"""
|
| 513 |
+
module_qualified_name = self.path_of_module(m)
|
| 514 |
+
with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
|
| 515 |
+
# module_stack is an ordered dict so writing then deleting the
|
| 516 |
+
# entry is equivalent to push/pop on a list
|
| 517 |
+
self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type)
|
| 518 |
+
if not self.is_leaf_module(m, module_qualified_name):
|
| 519 |
+
ret_val = forward(*args, **kwargs)
|
| 520 |
+
else:
|
| 521 |
+
ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs)
|
| 522 |
+
key, _ = self.module_stack.popitem(last=True)
|
| 523 |
+
assert key == _scope.module_path, f" Unexpected key {key}"
|
| 524 |
+
|
| 525 |
+
return ret_val
|
| 526 |
+
|
| 527 |
+
@compatibility(is_backward_compatible=False)
|
| 528 |
+
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
|
| 529 |
+
"""
|
| 530 |
+
Method that specifies the behavior of this ``Tracer`` when we call getattr
|
| 531 |
+
on a call to an ``nn.Module`` instance.
|
| 532 |
+
|
| 533 |
+
By default, the behavior is to return a proxy value for the attribute. It
|
| 534 |
+
also stores the proxy value in the ``parameter_proxy_cache``, so that future
|
| 535 |
+
calls will reuse the proxy rather than creating a new one.
|
| 536 |
+
|
| 537 |
+
This method can be overridden to --for example-- not return proxies when
|
| 538 |
+
querying parameters.
|
| 539 |
+
|
| 540 |
+
Args:
|
| 541 |
+
|
| 542 |
+
attr (str): The name of the attribute being queried
|
| 543 |
+
attr_val (Any): The value of the attribute
|
| 544 |
+
parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies
|
| 545 |
+
|
| 546 |
+
Return:
|
| 547 |
+
|
| 548 |
+
The return value from the getattr call.
|
| 549 |
+
"""
|
| 550 |
+
def maybe_get_proxy_for_attr(
|
| 551 |
+
attr_val, collection_to_search, parameter_proxy_cache
|
| 552 |
+
):
|
| 553 |
+
for n, p in collection_to_search:
|
| 554 |
+
if attr_val is p:
|
| 555 |
+
if n not in parameter_proxy_cache:
|
| 556 |
+
kwargs = {}
|
| 557 |
+
if (
|
| 558 |
+
"proxy_factory_fn"
|
| 559 |
+
in inspect.signature(self.create_proxy).parameters
|
| 560 |
+
):
|
| 561 |
+
kwargs["proxy_factory_fn"] = (
|
| 562 |
+
None
|
| 563 |
+
if not self.param_shapes_constant
|
| 564 |
+
else lambda node: ParameterProxy(
|
| 565 |
+
self, node, n, attr_val
|
| 566 |
+
)
|
| 567 |
+
)
|
| 568 |
+
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
|
| 569 |
+
parameter_proxy_cache[n] = val_proxy
|
| 570 |
+
return parameter_proxy_cache[n]
|
| 571 |
+
return None
|
| 572 |
+
|
| 573 |
+
if isinstance(attr_val, torch.nn.Parameter):
|
| 574 |
+
maybe_parameter_proxy = maybe_get_proxy_for_attr(
|
| 575 |
+
attr_val, self.root.named_parameters(), parameter_proxy_cache
|
| 576 |
+
)
|
| 577 |
+
if maybe_parameter_proxy is not None:
|
| 578 |
+
return maybe_parameter_proxy
|
| 579 |
+
|
| 580 |
+
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
|
| 581 |
+
maybe_buffer_proxy = maybe_get_proxy_for_attr(
|
| 582 |
+
attr_val, self.root.named_buffers(), parameter_proxy_cache
|
| 583 |
+
)
|
| 584 |
+
if maybe_buffer_proxy is not None:
|
| 585 |
+
return maybe_buffer_proxy
|
| 586 |
+
|
| 587 |
+
return attr_val
|
| 588 |
+
|
| 589 |
+
# This method will be refactored
|
| 590 |
+
@compatibility(is_backward_compatible=False)
|
| 591 |
+
def create_args_for_root(self, root_fn, is_module, concrete_args=None):
|
| 592 |
+
"""
|
| 593 |
+
Create ``placeholder`` nodes corresponding to the signature of the ``root``
|
| 594 |
+
Module. This method introspects root's signature and emits those
|
| 595 |
+
nodes accordingly, also supporting ``*args`` and ``**kwargs``.
|
| 596 |
+
"""
|
| 597 |
+
# In some cases, a function or method has been decorated with a wrapper
|
| 598 |
+
# defined via ``functools.wraps``. In this case, the outer code object
|
| 599 |
+
# will likely not contain the actual parameters we care about, so unwrap
|
| 600 |
+
# the function to get to the innermost callable.
|
| 601 |
+
fn_for_analysis = inspect.unwrap(root_fn)
|
| 602 |
+
co = fn_for_analysis.__code__
|
| 603 |
+
total_args = co.co_argcount + co.co_kwonlyargcount
|
| 604 |
+
orig_args = list(co.co_varnames)
|
| 605 |
+
names_iter = iter(co.co_varnames)
|
| 606 |
+
args: List[Any] = []
|
| 607 |
+
skip_arg_idx = 0
|
| 608 |
+
if is_module:
|
| 609 |
+
if total_args == 0:
|
| 610 |
+
raise RuntimeError(
|
| 611 |
+
"``self`` argument cannot be part of *args expansion!"
|
| 612 |
+
)
|
| 613 |
+
skip_arg_idx = 1
|
| 614 |
+
next(names_iter) # skip self
|
| 615 |
+
args.append(self.root)
|
| 616 |
+
|
| 617 |
+
sig = inspect.signature(fn_for_analysis)
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
# This covers the very specific case where we are passing in flat
|
| 621 |
+
# concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
|
| 622 |
+
# In this case, just take the concrete_args and pass them through.
|
| 623 |
+
name_idx = 0
|
| 624 |
+
if isinstance(concrete_args, tuple) and \
|
| 625 |
+
len(concrete_args) > 0 and \
|
| 626 |
+
(co.co_flags & HAS_VARSTUFF) and \
|
| 627 |
+
total_args == 1:
|
| 628 |
+
for concrete_arg in concrete_args:
|
| 629 |
+
out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
|
| 630 |
+
if isinstance(concrete_arg, PHBase):
|
| 631 |
+
if concrete_arg != PH:
|
| 632 |
+
# Transfer attrs in the case where you're using a placeholder other
|
| 633 |
+
# than the singleton PH (PH has no attributes to transfer).
|
| 634 |
+
# Proxies were created out of the placeholders.
|
| 635 |
+
# Transfer any metadata (put on the placeholders in the form of
|
| 636 |
+
# attributes set by the user) from the placeholder to the
|
| 637 |
+
# underlying nodes (the proxy is unwrapped by the user, but
|
| 638 |
+
# the metadata should hold).
|
| 639 |
+
_transfer_attrs(fr=concrete_arg, to=out.node)
|
| 640 |
+
args.append(out)
|
| 641 |
+
name_idx += 1
|
| 642 |
+
return root_fn, args
|
| 643 |
+
|
| 644 |
+
arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
|
| 645 |
+
if isinstance(concrete_args, tuple):
|
| 646 |
+
if len(arg_names) != len(concrete_args):
|
| 647 |
+
raise RuntimeError(
|
| 648 |
+
f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
|
| 649 |
+
)
|
| 650 |
+
concrete_args = dict(zip(arg_names, concrete_args))
|
| 651 |
+
|
| 652 |
+
def proxy_placeholder(name):
|
| 653 |
+
return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis)
|
| 654 |
+
|
| 655 |
+
args.extend(proxy_placeholder(names) for names in arg_names)
|
| 656 |
+
|
| 657 |
+
if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
|
| 658 |
+
# TODO: type annotations for *args and **kwargs
|
| 659 |
+
if co.co_flags & inspect.CO_VARARGS:
|
| 660 |
+
args.append(proxy_placeholder("*" + next(names_iter)))
|
| 661 |
+
if co.co_flags & inspect.CO_VARKEYWORDS:
|
| 662 |
+
args.append(proxy_placeholder("**" + next(names_iter)))
|
| 663 |
+
root_fn = _patch_function(root_fn, len(args))
|
| 664 |
+
|
| 665 |
+
flat_args, in_spec = pytree.tree_flatten(tuple(args))
|
| 666 |
+
if not all(child.is_leaf() for child in in_spec.children_specs):
|
| 667 |
+
# In the case that we have pytree-flattened inputs in
|
| 668 |
+
# `concrete_args`, generate a flattening wrapper around the
|
| 669 |
+
# original root function and return that.
|
| 670 |
+
self.graph._codegen = _PyTreeCodeGen(
|
| 671 |
+
_PyTreeInfo(orig_args[:total_args], in_spec, None)
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
def flatten_fn(*args):
|
| 675 |
+
tree_args = pytree.tree_unflatten(list(args), in_spec)
|
| 676 |
+
tree_out = root_fn(*tree_args)
|
| 677 |
+
out_args, out_spec = pytree.tree_flatten(tree_out)
|
| 678 |
+
assert isinstance(self.graph._codegen, _PyTreeCodeGen)
|
| 679 |
+
self.graph._codegen.pytree_info = (
|
| 680 |
+
self.graph._codegen.pytree_info._replace(out_spec=out_spec)
|
| 681 |
+
)
|
| 682 |
+
return out_args
|
| 683 |
+
|
| 684 |
+
return flatten_fn, flat_args
|
| 685 |
+
return root_fn, args
|
| 686 |
+
|
| 687 |
+
@compatibility(is_backward_compatible=True)
|
| 688 |
+
def trace(
|
| 689 |
+
self,
|
| 690 |
+
root: Union[torch.nn.Module, Callable[..., Any]],
|
| 691 |
+
concrete_args: Optional[Dict[str, Any]] = None,
|
| 692 |
+
) -> Graph:
|
| 693 |
+
"""
|
| 694 |
+
Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
|
| 695 |
+
can either be an ``nn.Module`` instance or a Python callable.
|
| 696 |
+
|
| 697 |
+
Note that after this call, ``self.root`` may be different from the ``root`` passed
|
| 698 |
+
in here. For example, when a free function is passed to ``trace()``, we will
|
| 699 |
+
create an ``nn.Module`` instance to use as the root and add embedded constants
|
| 700 |
+
to.
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
Args:
|
| 704 |
+
|
| 705 |
+
root (Union[Module, Callable]): Either a ``Module`` or a function to be
|
| 706 |
+
traced through. Backwards-compatibility for this parameter is
|
| 707 |
+
guaranteed.
|
| 708 |
+
concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
|
| 709 |
+
not be treated as Proxies. This parameter is experimental and
|
| 710 |
+
its backwards-compatibility is *NOT* guaranteed.
|
| 711 |
+
|
| 712 |
+
Returns:
|
| 713 |
+
|
| 714 |
+
A ``Graph`` representing the semantics of the passed-in ``root``.
|
| 715 |
+
"""
|
| 716 |
+
global _is_fx_tracing_flag
|
| 717 |
+
old_is_fx_tracing_flag = _is_fx_tracing_flag
|
| 718 |
+
_is_fx_tracing_flag = True
|
| 719 |
+
try:
|
| 720 |
+
if isinstance(root, torch.nn.Module):
|
| 721 |
+
|
| 722 |
+
# do real recompilation for _LazyGraphModule before retracing since the trace
|
| 723 |
+
# method can not trace the _lazy_forward method. Got error:
|
| 724 |
+
# https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
|
| 725 |
+
# without this.
|
| 726 |
+
from torch.fx._lazy_graph_module import _LazyGraphModule
|
| 727 |
+
_LazyGraphModule.force_recompile(root)
|
| 728 |
+
|
| 729 |
+
self.root = root
|
| 730 |
+
|
| 731 |
+
assert hasattr(
|
| 732 |
+
type(root), self.traced_func_name
|
| 733 |
+
), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
|
| 734 |
+
|
| 735 |
+
fn = getattr(type(root), self.traced_func_name)
|
| 736 |
+
self.root_module_name = root._get_name()
|
| 737 |
+
self.submodule_paths = {mod: name for name, mod in root.named_modules()}
|
| 738 |
+
else:
|
| 739 |
+
self.root = torch.nn.Module()
|
| 740 |
+
fn = root
|
| 741 |
+
|
| 742 |
+
tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None)
|
| 743 |
+
self.graph = Graph(tracer_cls=tracer_cls)
|
| 744 |
+
if hasattr(fn, '__code__'):
|
| 745 |
+
code = fn.__code__
|
| 746 |
+
self.graph._co_fields = {
|
| 747 |
+
'co_name': code.co_name,
|
| 748 |
+
'co_filename': code.co_filename,
|
| 749 |
+
'co_firstlineno': code.co_firstlineno,
|
| 750 |
+
}
|
| 751 |
+
|
| 752 |
+
# When we encounter a Tensor value that's not a parameter, we look if it
|
| 753 |
+
# is some other attribute on the model. Construct a dict mapping Tensor
|
| 754 |
+
# values to the qualified name here for efficiency. This is used downstream
|
| 755 |
+
# in create_arg
|
| 756 |
+
self.tensor_attrs: Dict[
|
| 757 |
+
Union[
|
| 758 |
+
torch.Tensor,
|
| 759 |
+
ScriptObject,
|
| 760 |
+
FakeScriptObject
|
| 761 |
+
], str
|
| 762 |
+
] = {}
|
| 763 |
+
|
| 764 |
+
def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
|
| 765 |
+
for k, v in m.__dict__.items():
|
| 766 |
+
if isinstance(v, (torch.Tensor, ScriptObject, FakeScriptObject)):
|
| 767 |
+
self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
|
| 768 |
+
for k, v in m.named_children():
|
| 769 |
+
collect_tensor_attrs(v, prefix_atoms + [k])
|
| 770 |
+
|
| 771 |
+
collect_tensor_attrs(self.root, [])
|
| 772 |
+
|
| 773 |
+
assert isinstance(fn, FunctionType)
|
| 774 |
+
|
| 775 |
+
fn_globals = fn.__globals__ # run before it gets patched
|
| 776 |
+
fn, args = self.create_args_for_root(
|
| 777 |
+
fn, isinstance(root, torch.nn.Module), concrete_args
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
parameter_proxy_cache: Dict[
|
| 781 |
+
str, Proxy
|
| 782 |
+
] = {} # Reduce number of get_attr calls
|
| 783 |
+
|
| 784 |
+
# Method dispatch on parameters is not recorded unless it's directly used.
|
| 785 |
+
# Thus, we need to insert a proxy when __getattr__ requests a parameter.
|
| 786 |
+
@functools.wraps(_orig_module_getattr)
|
| 787 |
+
def module_getattr_wrapper(mod, attr):
|
| 788 |
+
attr_val = _orig_module_getattr(mod, attr)
|
| 789 |
+
return self.getattr(attr, attr_val, parameter_proxy_cache)
|
| 790 |
+
|
| 791 |
+
@functools.wraps(_orig_module_call)
|
| 792 |
+
def module_call_wrapper(mod, *args, **kwargs):
|
| 793 |
+
def forward(*args, **kwargs):
|
| 794 |
+
return _orig_module_call(mod, *args, **kwargs)
|
| 795 |
+
|
| 796 |
+
_autowrap_check(
|
| 797 |
+
patcher, # type: ignore[has-type]
|
| 798 |
+
getattr(getattr(mod, "forward", mod), "__globals__", {}),
|
| 799 |
+
self._autowrap_function_ids,
|
| 800 |
+
)
|
| 801 |
+
return self.call_module(mod, forward, args, kwargs)
|
| 802 |
+
|
| 803 |
+
with _new_patcher() as patcher:
|
| 804 |
+
# allow duplicate patches to support the case of nested calls
|
| 805 |
+
patcher.patch_method(
|
| 806 |
+
torch.nn.Module,
|
| 807 |
+
"__getattr__",
|
| 808 |
+
module_getattr_wrapper,
|
| 809 |
+
deduplicate=False,
|
| 810 |
+
)
|
| 811 |
+
patcher.patch_method(
|
| 812 |
+
torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False
|
| 813 |
+
)
|
| 814 |
+
_patch_wrapped_functions(patcher)
|
| 815 |
+
_autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
|
| 816 |
+
for module in self._autowrap_search:
|
| 817 |
+
_autowrap_check(
|
| 818 |
+
patcher, module.__dict__, self._autowrap_function_ids
|
| 819 |
+
)
|
| 820 |
+
self.create_node(
|
| 821 |
+
"output",
|
| 822 |
+
"output",
|
| 823 |
+
(self.create_arg(fn(*args)),),
|
| 824 |
+
{},
|
| 825 |
+
type_expr=fn.__annotations__.get("return", None),
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
self.submodule_paths = None
|
| 829 |
+
finally:
|
| 830 |
+
_is_fx_tracing_flag = old_is_fx_tracing_flag
|
| 831 |
+
return self.graph
|
| 832 |
+
|
| 833 |
+
def __deepcopy__(self, memo):
|
| 834 |
+
# _autowrap_search contains modules, which cannot be deepcopied.
|
| 835 |
+
new_tracer = Tracer.__new__(Tracer)
|
| 836 |
+
|
| 837 |
+
for k, v in self.__dict__.items():
|
| 838 |
+
if k in {'_autowrap_search'}:
|
| 839 |
+
new_obj = copy.copy(v)
|
| 840 |
+
else:
|
| 841 |
+
new_obj = copy.deepcopy(v, memo)
|
| 842 |
+
|
| 843 |
+
new_tracer.__dict__[k] = new_obj
|
| 844 |
+
|
| 845 |
+
return new_tracer
|
| 846 |
+
|
| 847 |
+
def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis):
|
| 848 |
+
if concrete_args is not None and name in concrete_args:
|
| 849 |
+
cnt = 0
|
| 850 |
+
|
| 851 |
+
def replace_ph(x):
|
| 852 |
+
nonlocal cnt
|
| 853 |
+
cnt += 1
|
| 854 |
+
param = sig.parameters[name]
|
| 855 |
+
default = (
|
| 856 |
+
()
|
| 857 |
+
if param.default is inspect.Parameter.empty
|
| 858 |
+
else (param.default,)
|
| 859 |
+
)
|
| 860 |
+
out = self.create_proxy(
|
| 861 |
+
"placeholder", f"{name}_{str(cnt)}", default, {}
|
| 862 |
+
)
|
| 863 |
+
if isinstance(x, PHBase):
|
| 864 |
+
if x != PH:
|
| 865 |
+
# Transfer attrs in the case where you're using a placeholder other
|
| 866 |
+
# than the singleton PH (PH has no attributes to transfer).
|
| 867 |
+
# Proxies were created out of the placeholders.
|
| 868 |
+
# Transfer any metadata (put on the placeholders in the form of
|
| 869 |
+
# attributes set by the user) from the placeholder to the
|
| 870 |
+
# underlying nodes (the proxy is unwrapped by the user, but
|
| 871 |
+
# the metadata should hold).
|
| 872 |
+
_transfer_attrs(fr=x, to=out.node)
|
| 873 |
+
|
| 874 |
+
return out
|
| 875 |
+
# Union[int, bool] == bool in Python <= 3.6
|
| 876 |
+
if (
|
| 877 |
+
type(x) == bool
|
| 878 |
+
or type(x) in base_types
|
| 879 |
+
and type(x) != torch.Tensor
|
| 880 |
+
):
|
| 881 |
+
torch._assert(
|
| 882 |
+
out == x,
|
| 883 |
+
f"{name} has been specialized to have value {x} but got another value",
|
| 884 |
+
)
|
| 885 |
+
elif x is None:
|
| 886 |
+
args = (
|
| 887 |
+
out,
|
| 888 |
+
f"{name} has been specialized to have value None but got another value",
|
| 889 |
+
)
|
| 890 |
+
self.create_proxy("call_function", _assert_is_none, args, {})
|
| 891 |
+
else:
|
| 892 |
+
warnings.warn(
|
| 893 |
+
f"Was not able to add assertion to guarantee correct input {name} to "
|
| 894 |
+
f"specialized function. It is up to the user to make sure that your inputs match the "
|
| 895 |
+
f"inputs you specialized the function with."
|
| 896 |
+
)
|
| 897 |
+
|
| 898 |
+
return x
|
| 899 |
+
|
| 900 |
+
return pytree.tree_map(replace_ph, concrete_args[name])
|
| 901 |
+
if name[0] == "*":
|
| 902 |
+
default = ()
|
| 903 |
+
else:
|
| 904 |
+
param = sig.parameters[name]
|
| 905 |
+
default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment]
|
| 906 |
+
return self.create_proxy(
|
| 907 |
+
"placeholder",
|
| 908 |
+
name,
|
| 909 |
+
default,
|
| 910 |
+
{},
|
| 911 |
+
type_expr=fn_for_analysis.__annotations__.get(name, None)
|
| 912 |
+
)
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
# Dictionary of (id(globals dict), function name) => globals_dict to patch for
|
| 916 |
+
# the purposes of the wrap() API.
|
| 917 |
+
# We key by the globals dict id and function name to ensure we're wrapping a given
|
| 918 |
+
# function only once.
|
| 919 |
+
_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {}
|
| 920 |
+
|
| 921 |
+
# List of methods on classes to wrap (class type, function name)
|
| 922 |
+
# this currently only works for Tensor.* methods that aren't traced properly
|
| 923 |
+
_wrapped_methods_to_patch: List[Tuple[type, str]] = []
|
| 924 |
+
|
| 925 |
+
if os.environ.get("FX_PATCH_GETITEM") == "1":
|
| 926 |
+
# This change is needed to trace models like PositionalEmbedding from BERT:
|
| 927 |
+
# https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py
|
| 928 |
+
# but causes issues in quantization documented here:
|
| 929 |
+
# https://github.com/pytorch/pytorch/issues/50710
|
| 930 |
+
# once that is fixed we can make this the default behavior.
|
| 931 |
+
_wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
def _find_proxy(*objects_to_search):
|
| 935 |
+
"""
|
| 936 |
+
Recursively search a data structure for a Proxy() and return it,
|
| 937 |
+
return None if not found.
|
| 938 |
+
"""
|
| 939 |
+
proxy = None
|
| 940 |
+
|
| 941 |
+
def find_proxy(x):
|
| 942 |
+
nonlocal proxy
|
| 943 |
+
if isinstance(x, Proxy):
|
| 944 |
+
proxy = x
|
| 945 |
+
|
| 946 |
+
map_aggregate(objects_to_search, find_proxy)
|
| 947 |
+
return proxy
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
def _create_wrapped_func(orig_fn):
|
| 951 |
+
@functools.wraps(orig_fn)
|
| 952 |
+
def wrapped(*args, **kwargs):
|
| 953 |
+
"""
|
| 954 |
+
Given an closed-over ``orig_function`` to invoke, search the args and kwargs for
|
| 955 |
+
a Proxy object. If there is one, emit a ``call_function`` node to preserve the
|
| 956 |
+
call to this leaf function directly. Otherwise, just return the results of
|
| 957 |
+
this function call, as this function is not being traced.
|
| 958 |
+
"""
|
| 959 |
+
proxy = _find_proxy(args, kwargs)
|
| 960 |
+
if proxy is not None:
|
| 961 |
+
return_proxy = proxy.tracer.create_proxy(
|
| 962 |
+
"call_function", orig_fn, args, kwargs
|
| 963 |
+
)
|
| 964 |
+
return_proxy.node.meta["is_wrapped"] = True
|
| 965 |
+
return return_proxy
|
| 966 |
+
return orig_fn(*args, **kwargs)
|
| 967 |
+
|
| 968 |
+
return wrapped
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
def _create_wrapped_method(cls, name):
|
| 972 |
+
orig_fn = getattr(cls, name)
|
| 973 |
+
|
| 974 |
+
@functools.wraps(orig_fn)
|
| 975 |
+
def wrapped(*args, **kwargs):
|
| 976 |
+
"""
|
| 977 |
+
Search the args and kwargs for a Proxy object. If there is one,
|
| 978 |
+
emit a ``call_method`` node to preserve the call to this method
|
| 979 |
+
directly. Otherwise, just return the results of this function
|
| 980 |
+
call, as this function is not being traced.
|
| 981 |
+
"""
|
| 982 |
+
proxy = _find_proxy(args, kwargs)
|
| 983 |
+
if proxy is not None:
|
| 984 |
+
return proxy.tracer.create_proxy("call_method", name, args, kwargs)
|
| 985 |
+
return orig_fn(*args, **kwargs)
|
| 986 |
+
|
| 987 |
+
return wrapped
|
| 988 |
+
|
| 989 |
+
|
| 990 |
+
class _PatchedFn(NamedTuple):
|
| 991 |
+
frame_dict: Any
|
| 992 |
+
fn_name: str
|
| 993 |
+
orig_fn: Any
|
| 994 |
+
new_fn: Any
|
| 995 |
+
|
| 996 |
+
def revert(self):
|
| 997 |
+
raise NotImplementedError
|
| 998 |
+
|
| 999 |
+
def patch(self):
|
| 1000 |
+
raise NotImplementedError
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
class _PatchedFnSetItem(_PatchedFn):
|
| 1004 |
+
def revert(self):
|
| 1005 |
+
self.frame_dict[self.fn_name] = self.orig_fn
|
| 1006 |
+
|
| 1007 |
+
def patch(self):
|
| 1008 |
+
self.frame_dict[self.fn_name] = self.new_fn
|
| 1009 |
+
|
| 1010 |
+
class _PatchedFnDel(_PatchedFn):
|
| 1011 |
+
def revert(self):
|
| 1012 |
+
del self.frame_dict[self.fn_name]
|
| 1013 |
+
|
| 1014 |
+
def patch(self):
|
| 1015 |
+
self.frame_dict[self.fn_name] = self.new_fn
|
| 1016 |
+
|
| 1017 |
+
|
| 1018 |
+
class _PatchedFnSetAttr(_PatchedFn):
|
| 1019 |
+
def revert(self):
|
| 1020 |
+
setattr(self.frame_dict, self.fn_name, self.orig_fn)
|
| 1021 |
+
|
| 1022 |
+
def patch(self):
|
| 1023 |
+
setattr(self.frame_dict, self.fn_name, self.new_fn)
|
| 1024 |
+
|
| 1025 |
+
class _Patcher:
|
| 1026 |
+
def __init__(self) -> None:
|
| 1027 |
+
super().__init__()
|
| 1028 |
+
self.patches_made: List[_PatchedFn] = []
|
| 1029 |
+
self.visited: Set[int] = set()
|
| 1030 |
+
|
| 1031 |
+
def patch(
|
| 1032 |
+
self,
|
| 1033 |
+
frame_dict: Dict[str, Any],
|
| 1034 |
+
name: str,
|
| 1035 |
+
new_fn: Callable,
|
| 1036 |
+
deduplicate: bool = True,
|
| 1037 |
+
):
|
| 1038 |
+
"""
|
| 1039 |
+
Replace frame_dict[name] with new_fn until we exit the context manager.
|
| 1040 |
+
"""
|
| 1041 |
+
new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined]
|
| 1042 |
+
if name not in frame_dict and hasattr(builtins, name):
|
| 1043 |
+
self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn))
|
| 1044 |
+
self.patches_made[-1].patch()
|
| 1045 |
+
elif getattr(frame_dict[name], "__fx_already_patched", False):
|
| 1046 |
+
return # already patched, no need to do it again
|
| 1047 |
+
else:
|
| 1048 |
+
self.patches_made.append(
|
| 1049 |
+
_PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn)
|
| 1050 |
+
)
|
| 1051 |
+
self.patches_made[-1].patch()
|
| 1052 |
+
|
| 1053 |
+
def patch_method(
|
| 1054 |
+
self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True
|
| 1055 |
+
):
|
| 1056 |
+
"""
|
| 1057 |
+
Replace object_or_dict.name with new_fn until we exit the context manager.
|
| 1058 |
+
"""
|
| 1059 |
+
new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined]
|
| 1060 |
+
orig_fn = getattr(cls, name)
|
| 1061 |
+
if getattr(orig_fn, "__fx_already_patched", False):
|
| 1062 |
+
return # already patched, no need to do it again
|
| 1063 |
+
self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn))
|
| 1064 |
+
self.patches_made[-1].patch()
|
| 1065 |
+
|
| 1066 |
+
def visit_once(self, thing: Any):
|
| 1067 |
+
"""Return True on the first call to with thing, otherwise false"""
|
| 1068 |
+
idx = id(thing)
|
| 1069 |
+
if idx in self.visited:
|
| 1070 |
+
return False
|
| 1071 |
+
self.visited.add(idx)
|
| 1072 |
+
return True
|
| 1073 |
+
|
| 1074 |
+
def revert_all_patches(self):
|
| 1075 |
+
"""
|
| 1076 |
+
Remove all the stored patcheds. It doesn't modify patches_made.
|
| 1077 |
+
"""
|
| 1078 |
+
for patch in self.patches_made:
|
| 1079 |
+
patch.revert()
|
| 1080 |
+
return self.patches_made
|
| 1081 |
+
|
| 1082 |
+
def reapply_all_patches(self):
|
| 1083 |
+
"""
|
| 1084 |
+
Patch all the stored patcheds. It doesn't modify patches_made.
|
| 1085 |
+
"""
|
| 1086 |
+
for patch in self.patches_made:
|
| 1087 |
+
patch.patch()
|
| 1088 |
+
return self.patches_made
|
| 1089 |
+
|
| 1090 |
+
def __enter__(self):
|
| 1091 |
+
return self
|
| 1092 |
+
|
| 1093 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 1094 |
+
"""
|
| 1095 |
+
Undo all the changes made via self.patch() and self.patch_method()
|
| 1096 |
+
"""
|
| 1097 |
+
while self.patches_made:
|
| 1098 |
+
# unpatch in reverse order to handle duplicates correctly
|
| 1099 |
+
self.patches_made.pop().revert()
|
| 1100 |
+
self.visited.clear()
|
| 1101 |
+
|
| 1102 |
+
|
| 1103 |
+
CURRENT_PATCHER: Optional[_Patcher] = None
|
| 1104 |
+
|
| 1105 |
+
@contextlib.contextmanager
|
| 1106 |
+
def _new_patcher():
|
| 1107 |
+
global CURRENT_PATCHER
|
| 1108 |
+
prior_patcher = CURRENT_PATCHER
|
| 1109 |
+
try:
|
| 1110 |
+
CURRENT_PATCHER = _Patcher()
|
| 1111 |
+
yield CURRENT_PATCHER
|
| 1112 |
+
finally:
|
| 1113 |
+
# Clear all the patches made by when using current patcher.
|
| 1114 |
+
assert CURRENT_PATCHER is not None
|
| 1115 |
+
CURRENT_PATCHER.revert_all_patches()
|
| 1116 |
+
CURRENT_PATCHER = prior_patcher
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
@contextlib.contextmanager
|
| 1120 |
+
def _maybe_revert_all_patches():
|
| 1121 |
+
current_patcher = CURRENT_PATCHER
|
| 1122 |
+
patches_made = None
|
| 1123 |
+
patches_removed = None
|
| 1124 |
+
try:
|
| 1125 |
+
if current_patcher is not None:
|
| 1126 |
+
patches_removed = current_patcher.revert_all_patches()
|
| 1127 |
+
yield
|
| 1128 |
+
finally:
|
| 1129 |
+
if current_patcher is not None:
|
| 1130 |
+
patches_made = current_patcher.reapply_all_patches()
|
| 1131 |
+
assert patches_made == patches_removed, "CURRENT_PATCHER was changed during a revert_all_patches"
|
| 1132 |
+
|
| 1133 |
+
def _patch_wrapped_functions(patcher: _Patcher):
|
| 1134 |
+
"""
|
| 1135 |
+
Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap
|
| 1136 |
+
the listed global functions in the `_create_wrapped_func` wrapper.
|
| 1137 |
+
"""
|
| 1138 |
+
for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items():
|
| 1139 |
+
if name not in frame_dict and hasattr(builtins, name):
|
| 1140 |
+
orig_fn = getattr(builtins, name)
|
| 1141 |
+
else:
|
| 1142 |
+
orig_fn = frame_dict[name]
|
| 1143 |
+
patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))
|
| 1144 |
+
|
| 1145 |
+
for cls, name in _wrapped_methods_to_patch:
|
| 1146 |
+
patcher.patch_method(cls, name, _create_wrapped_method(cls, name))
|
| 1147 |
+
|
| 1148 |
+
|
| 1149 |
+
def _autowrap_check(
|
| 1150 |
+
patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int]
|
| 1151 |
+
):
|
| 1152 |
+
"""
|
| 1153 |
+
Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them.
|
| 1154 |
+
This method searches a scope for them and patches them if found.
|
| 1155 |
+
"""
|
| 1156 |
+
if patcher.visit_once(frame_dict):
|
| 1157 |
+
for name, value in frame_dict.items():
|
| 1158 |
+
if (
|
| 1159 |
+
not name.startswith("_")
|
| 1160 |
+
and callable(value)
|
| 1161 |
+
and id(value) in function_ids
|
| 1162 |
+
):
|
| 1163 |
+
patcher.patch(frame_dict, name, _create_wrapped_func(value))
|
| 1164 |
+
|
| 1165 |
+
|
| 1166 |
+
@compatibility(is_backward_compatible=True)
|
| 1167 |
+
def wrap(fn_or_name: Union[str, Callable]):
|
| 1168 |
+
"""
|
| 1169 |
+
This function can be called at module-level scope to register fn_or_name as a "leaf function".
|
| 1170 |
+
A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being
|
| 1171 |
+
traced through::
|
| 1172 |
+
|
| 1173 |
+
# foo/bar/baz.py
|
| 1174 |
+
def my_custom_function(x, y):
|
| 1175 |
+
return x * x + y * y
|
| 1176 |
+
|
| 1177 |
+
torch.fx.wrap('my_custom_function')
|
| 1178 |
+
|
| 1179 |
+
def fn_to_be_traced(x, y):
|
| 1180 |
+
# When symbolic tracing, the below call to my_custom_function will be inserted into
|
| 1181 |
+
# the graph rather than tracing it.
|
| 1182 |
+
return my_custom_function(x, y)
|
| 1183 |
+
|
| 1184 |
+
This function can also equivalently be used as a decorator::
|
| 1185 |
+
|
| 1186 |
+
# foo/bar/baz.py
|
| 1187 |
+
@torch.fx.wrap
|
| 1188 |
+
def my_custom_function(x, y):
|
| 1189 |
+
return x * x + y * y
|
| 1190 |
+
|
| 1191 |
+
A wrapped function can be thought of a "leaf function", analogous to the concept of
|
| 1192 |
+
"leaf modules", that is, they are functions that are left as calls in the FX trace
|
| 1193 |
+
rather than traced through.
|
| 1194 |
+
|
| 1195 |
+
Args:
|
| 1196 |
+
|
| 1197 |
+
fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
|
| 1198 |
+
graph when it's called
|
| 1199 |
+
"""
|
| 1200 |
+
if not callable(fn_or_name) and not isinstance(fn_or_name, str):
|
| 1201 |
+
raise RuntimeError(
|
| 1202 |
+
"Unsupported type for global function! Must be either a callable or "
|
| 1203 |
+
"string name"
|
| 1204 |
+
)
|
| 1205 |
+
|
| 1206 |
+
if callable(fn_or_name):
|
| 1207 |
+
assert not isinstance(fn_or_name, str) # to make mypy happy
|
| 1208 |
+
fn_name = fn_or_name.__name__
|
| 1209 |
+
else:
|
| 1210 |
+
assert isinstance(
|
| 1211 |
+
fn_or_name, str
|
| 1212 |
+
), "fn_or_name must be a global function or string name"
|
| 1213 |
+
fn_name = fn_or_name
|
| 1214 |
+
|
| 1215 |
+
currentframe = inspect.currentframe()
|
| 1216 |
+
assert currentframe is not None
|
| 1217 |
+
f = currentframe.f_back
|
| 1218 |
+
assert f is not None
|
| 1219 |
+
if f.f_code.co_name != "<module>":
|
| 1220 |
+
raise NotImplementedError("wrap must be called at the top level of a module")
|
| 1221 |
+
|
| 1222 |
+
# consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search
|
| 1223 |
+
# semantics would be slightly different, but would add support `from x import wrapped_function`
|
| 1224 |
+
_wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals
|
| 1225 |
+
return fn_or_name
|
| 1226 |
+
|
| 1227 |
+
|
| 1228 |
+
@compatibility(is_backward_compatible=True)
|
| 1229 |
+
def symbolic_trace(
|
| 1230 |
+
root: Union[torch.nn.Module, Callable[..., Any]],
|
| 1231 |
+
concrete_args: Optional[Dict[str, Any]] = None,
|
| 1232 |
+
) -> GraphModule:
|
| 1233 |
+
"""
|
| 1234 |
+
Symbolic tracing API
|
| 1235 |
+
|
| 1236 |
+
Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
|
| 1237 |
+
constructed by recording operations seen while tracing through ``root``.
|
| 1238 |
+
|
| 1239 |
+
``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures.
|
| 1240 |
+
|
| 1241 |
+
For example::
|
| 1242 |
+
|
| 1243 |
+
def f(a, b):
|
| 1244 |
+
if b == True:
|
| 1245 |
+
return a
|
| 1246 |
+
else:
|
| 1247 |
+
return a*2
|
| 1248 |
+
|
| 1249 |
+
FX can typically not trace through this due to the presence of control
|
| 1250 |
+
flow. However, we can use `concrete_args` to specialize on the value of
|
| 1251 |
+
`b` to trace through this::
|
| 1252 |
+
|
| 1253 |
+
f = fx.symbolic_trace(f, concrete_args={'b': False})
|
| 1254 |
+
assert f(3, False) == 6
|
| 1255 |
+
|
| 1256 |
+
Note that although you can still pass in different values of `b`, they will be ignored.
|
| 1257 |
+
|
| 1258 |
+
We can also use `concrete_args` to eliminate data-structure handling from
|
| 1259 |
+
our function. This will use pytrees to flatten your input. To avoid
|
| 1260 |
+
overspecializing, pass in `fx.PH` for values that shouldn't be
|
| 1261 |
+
specialized. For example::
|
| 1262 |
+
|
| 1263 |
+
def f(x):
|
| 1264 |
+
out = 0
|
| 1265 |
+
for v in x.values():
|
| 1266 |
+
out += v
|
| 1267 |
+
return out
|
| 1268 |
+
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
|
| 1269 |
+
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
|
| 1270 |
+
|
| 1271 |
+
|
| 1272 |
+
Args:
|
| 1273 |
+
root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
|
| 1274 |
+
into a Graph representation.
|
| 1275 |
+
concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized
|
| 1276 |
+
|
| 1277 |
+
Returns:
|
| 1278 |
+
GraphModule: a Module created from the recorded operations from ``root``.
|
| 1279 |
+
"""
|
| 1280 |
+
tracer = Tracer()
|
| 1281 |
+
graph = tracer.trace(root, concrete_args)
|
| 1282 |
+
name = (
|
| 1283 |
+
root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
| 1284 |
+
)
|
| 1285 |
+
return _make_graph_module(tracer.root, graph, name)
|
| 1286 |
+
|
| 1287 |
+
|
| 1288 |
+
@wrap
|
| 1289 |
+
def _assert_is_none(value, msg):
|
| 1290 |
+
assert value is None, msg
|
.venv/lib/python3.11/site-packages/torch/fx/_utils.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Dict, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch._logging import LazyString
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
|
| 10 |
+
"""
|
| 11 |
+
Returns a LazyString that formats the graph code.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def format_name():
|
| 15 |
+
if maybe_id is not None:
|
| 16 |
+
return f"{name} {maybe_id}"
|
| 17 |
+
else:
|
| 18 |
+
return name
|
| 19 |
+
|
| 20 |
+
if "print_output" not in kwargs:
|
| 21 |
+
kwargs["print_output"] = False
|
| 22 |
+
|
| 23 |
+
if "colored" in kwargs and not sys.stdout.isatty():
|
| 24 |
+
kwargs["colored"] = False
|
| 25 |
+
|
| 26 |
+
return LazyString(
|
| 27 |
+
lambda: _format_graph_code(
|
| 28 |
+
f"===== {format_name()} =====\n",
|
| 29 |
+
gm.forward.__code__.co_filename,
|
| 30 |
+
gm.print_readable(**kwargs),
|
| 31 |
+
)
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _format_graph_code(name, filename, graph_str):
|
| 36 |
+
"""
|
| 37 |
+
Returns a string that formats the graph code.
|
| 38 |
+
"""
|
| 39 |
+
return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
|
| 43 |
+
"""
|
| 44 |
+
Returns the nn_module_stack of the first call_function node.
|
| 45 |
+
"""
|
| 46 |
+
for node in graph.nodes:
|
| 47 |
+
if node.op == "call_function" and "nn_module_stack" in node.meta:
|
| 48 |
+
return node.meta["nn_module_stack"]
|
| 49 |
+
return None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_node_context(node, num_nodes=2) -> str:
|
| 53 |
+
"""
|
| 54 |
+
Returns a string of the last num_nodes nodes in the graph.
|
| 55 |
+
"""
|
| 56 |
+
node_contexts = []
|
| 57 |
+
cur = node
|
| 58 |
+
for i in range(num_nodes):
|
| 59 |
+
node_contexts.append(cur.format_node())
|
| 60 |
+
if cur.op == "root":
|
| 61 |
+
break
|
| 62 |
+
cur = cur.prev
|
| 63 |
+
return "\n".join(node_contexts[::-1])
|
.venv/lib/python3.11/site-packages/torch/fx/annotate.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from torch.fx.proxy import Proxy
|
| 3 |
+
from ._compatibility import compatibility
|
| 4 |
+
|
| 5 |
+
@compatibility(is_backward_compatible=False)
|
| 6 |
+
def annotate(val, type):
|
| 7 |
+
"""
|
| 8 |
+
Annotates a Proxy object with a given type.
|
| 9 |
+
|
| 10 |
+
This function annotates a val with a given type if a type of the val is a torch.fx.Proxy object
|
| 11 |
+
Args:
|
| 12 |
+
val (object): An object to be annotated if its type is torch.fx.Proxy.
|
| 13 |
+
type (object): A type to be assigned to a given proxy object as val.
|
| 14 |
+
Returns:
|
| 15 |
+
The given val.
|
| 16 |
+
Raises:
|
| 17 |
+
RuntimeError: If a val already has a type in its node.
|
| 18 |
+
"""
|
| 19 |
+
if isinstance(val, Proxy):
|
| 20 |
+
if val.node.type:
|
| 21 |
+
raise RuntimeError(f"Tried to annotate a value that already had a type on it!"
|
| 22 |
+
f" Existing type is {val.node.type} "
|
| 23 |
+
f"and new type is {type}. "
|
| 24 |
+
f"This could happen if you tried to annotate a function parameter "
|
| 25 |
+
f"value (in which case you should use the type slot "
|
| 26 |
+
f"on the function signature) or you called "
|
| 27 |
+
f"annotate on the same value twice")
|
| 28 |
+
else:
|
| 29 |
+
val.node.type = type
|
| 30 |
+
return val
|
| 31 |
+
else:
|
| 32 |
+
return val
|
.venv/lib/python3.11/site-packages/torch/fx/config.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Whether to disable showing progress on compilation passes
|
| 2 |
+
# Need to add a new config otherwise wil get a circular import if dynamo config is imported here
|
| 3 |
+
disable_progress = True
|
| 4 |
+
|
| 5 |
+
# If True this also shows the node names in each pass, for small models this is great but larger models it's quite noisy
|
| 6 |
+
verbose_progress = False
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc
ADDED
|
Binary file (1.47 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_config.cpython-311.pyc
ADDED
|
Binary file (1.95 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-311.pyc
ADDED
|
Binary file (47.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-311.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/debug.cpython-311.pyc
ADDED
|
Binary file (1.69 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-311.pyc
ADDED
|
Binary file (49.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-311.pyc
ADDED
|
Binary file (7.27 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-311.pyc
ADDED
|
Binary file (16.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-311.pyc
ADDED
|
Binary file (8.32 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-311.pyc
ADDED
|
Binary file (26.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-311.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-311.pyc
ADDED
|
Binary file (8.31 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-311.pyc
ADDED
|
Binary file (6.99 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-311.pyc
ADDED
|
Binary file (61.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-311.pyc
ADDED
|
Binary file (5.06 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc
ADDED
|
Binary file (41.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/_backward_state.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.fx
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BackwardState:
|
| 5 |
+
"""
|
| 6 |
+
BackwardState is used to pass Python hooks from the forwards pass
|
| 7 |
+
into the backwards pass in Dynamo+Compiled Autograd.
|
| 8 |
+
|
| 9 |
+
It is created by TorchDynamo and has special handling there.
|
| 10 |
+
Dynamo will pass an empty BackwardState to the forwards, then populate
|
| 11 |
+
members on it (via setattr) only after the forwards graph is finished.
|
| 12 |
+
Later on, in CompileAutograd we will inline and add the needed guards
|
| 13 |
+
on the BackwardState.
|
| 14 |
+
|
| 15 |
+
BackwardState is identified and has special handling in AOTAutograd.
|
| 16 |
+
During AOTAutograd:
|
| 17 |
+
1) BackwardState is an input to the forwards graph
|
| 18 |
+
2) It must only be used in the backwards
|
| 19 |
+
3) It will be empty in the forwards
|
| 20 |
+
4) In the forwards we add a wrapper to save it
|
| 21 |
+
5) In the backwards it becomes an input
|
| 22 |
+
6) There can only be one per graph
|
| 23 |
+
|
| 24 |
+
BackwardState requires CompiledAutograd.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
proxy: torch.fx.Proxy
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/_config.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# [@compile_ignored: debug] Uses z3 for validating the guard optimizations transformations.
|
| 7 |
+
translation_validation = (
|
| 8 |
+
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
|
| 9 |
+
)
|
| 10 |
+
# Timeout (in milliseconds) for z3 finding a solution.
|
| 11 |
+
# [@compile_ignored: debug]
|
| 12 |
+
translation_validation_timeout = int(
|
| 13 |
+
os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000")
|
| 14 |
+
)
|
| 15 |
+
# Disables bisection for translation validation.
|
| 16 |
+
#
|
| 17 |
+
# Translation validation bisection is enabled by default, if translation validation
|
| 18 |
+
# is also enabled. This should help finding guard simplification issues. However,
|
| 19 |
+
# since validation uses Z3 for bisecting, it might take a lot of time.
|
| 20 |
+
#
|
| 21 |
+
# Set this configuration option so as to avoid bisecting.
|
| 22 |
+
# [@compile_ignored: debug]
|
| 23 |
+
translation_validation_no_bisect = (
|
| 24 |
+
os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1"
|
| 25 |
+
)
|
| 26 |
+
# Checks whether replaying ShapeEnv events on a freshly constructed one yields
|
| 27 |
+
# the a ShapeEnv with the same state. This should be used only in testing.
|
| 28 |
+
check_shape_env_recorded_events = False
|
| 29 |
+
|
| 30 |
+
# TODO: Perhaps consider allowing unions for the configs below (so you can hit
|
| 31 |
+
# multiple reps at the same time)
|
| 32 |
+
|
| 33 |
+
# Give extended debug information if the string representation of a guard
|
| 34 |
+
# matches this. For example, set this to "Ne(s0, 10)" and whenever we issue
|
| 35 |
+
# this guard, we will generate full Python and C++ backtrace
|
| 36 |
+
# [@compile_ignored: debug]
|
| 37 |
+
extended_debug_guard_added = os.environ.get(
|
| 38 |
+
"TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Give extended debug information when a particular symbol is allocated. For
|
| 42 |
+
# example, set this to "u2" and whenever we create this symbol, we will
|
| 43 |
+
# generate full Python and C++ backtrace
|
| 44 |
+
# [@compile_ignored: debug]
|
| 45 |
+
extended_debug_create_symbol = os.environ.get(
|
| 46 |
+
"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Give extended debug information (C++ backtrace) for all extended debug
|
| 50 |
+
# settings as well as errors. The C++ backtrace is slow and very spammy so we
|
| 51 |
+
# don't include it by default even when you're requesting extended debug.
|
| 52 |
+
# [@compile_ignored: debug]
|
| 53 |
+
extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != ""
|
| 54 |
+
|
| 55 |
+
# Give extended debug information (line of code) when a torch function
|
| 56 |
+
# is called during export. This is useful for showing progress and detecting
|
| 57 |
+
# where export might be stuck. Currently only works for strict=False.
|
| 58 |
+
# [@compile_ignored: debug]
|
| 59 |
+
extended_debug_current_loc = (
|
| 60 |
+
os.environ.get("TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC", "0") == "1"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# [@compile_ignored: debug] Show a warning for every specialization
|
| 64 |
+
print_specializations = False
|
| 65 |
+
|
| 66 |
+
# wraps (un)equalities with 'Not' class after recording the correct expression
|
| 67 |
+
# in the FX graph. This should incorrectly construct the divisible and replacement
|
| 68 |
+
# lists, and incorrectly issue guards.
|
| 69 |
+
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False
|
| 70 |
+
|
| 71 |
+
# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly
|
| 72 |
+
validate_shape_env_version_key = False
|
| 73 |
+
|
| 74 |
+
# If we produce more than this many guards on a symbol, force the symbol to
|
| 75 |
+
# get specialized and bail out if this many guards mention this particular
|
| 76 |
+
# symbol. This may be slightly more aggressive than the true number of guards
|
| 77 |
+
# issued (as we test if we've hit the limit on-the-fly, whereas we may
|
| 78 |
+
# do further simplifications at final guard issuance time that make guards
|
| 79 |
+
# irrelevant.)
|
| 80 |
+
symbol_guard_limit_before_specialize: Optional[int] = None
|
| 81 |
+
|
| 82 |
+
# This flag changes whether we should use the same symbolic variable to represent input sizes that are the same.
|
| 83 |
+
use_duck_shape = True
|
| 84 |
+
|
| 85 |
+
from torch.utils._config_module import install_config_module
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
install_config_module(sys.modules[__name__])
|