Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/__pycache__/converter.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/__pycache__/tools.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py +1 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/_node_metadata_hook.py +80 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +227 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/collect_tracepoints_pass.py +102 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/constant_folding.py +299 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py +94 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/lift_constants_pass.py +318 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py +27 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py +179 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py +673 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py +110 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py +65 -0
- .venv/lib/python3.11/site-packages/torch/_export/passes/replace_with_hop_pass_util.py +178 -0
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/converter.cpython-311.pyc
ADDED
|
Binary file (82.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc
ADDED
|
Binary file (2.78 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc
ADDED
|
Binary file (24.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc
ADDED
|
Binary file (27.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/tools.cpython-311.pyc
ADDED
|
Binary file (7.09 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (43.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc
ADDED
|
Binary file (25.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc
ADDED
|
Binary file (7.99 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (189 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc
ADDED
|
Binary file (8.44 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc
ADDED
|
Binary file (1.36 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc
ADDED
|
Binary file (1.74 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (4.38 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc
ADDED
|
Binary file (1.27 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc
ADDED
|
Binary file (1.34 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc
ADDED
|
Binary file (1.11 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc
ADDED
|
Binary file (2.01 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc
ADDED
|
Binary file (1.13 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc
ADDED
|
Binary file (1.52 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc
ADDED
|
Binary file (1.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc
ADDED
|
Binary file (1.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc
ADDED
|
Binary file (1.49 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (305 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-311.pyc
ADDED
|
Binary file (4.17 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc
ADDED
|
Binary file (5.73 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-311.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc
ADDED
|
Binary file (5.28 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc
ADDED
|
Binary file (1.63 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-311.pyc
ADDED
|
Binary file (8.71 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-311.pyc
ADDED
|
Binary file (28.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc
ADDED
|
Binary file (5.98 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc
ADDED
|
Binary file (3.96 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-311.pyc
ADDED
|
Binary file (8.48 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/_export/passes/_node_metadata_hook.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.fx.graph_module import GraphModule
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
_EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _node_metadata_hook(node: torch.fx.Node, stack_trace: str) -> None:
|
| 12 |
+
"""
|
| 13 |
+
Hook for adding the appropriate metadata to nodes that are created during a
|
| 14 |
+
pass using graph.create_node. An example of how to use it:
|
| 15 |
+
|
| 16 |
+
```
|
| 17 |
+
with _set_node_metadata_hook(gm,
|
| 18 |
+
functools.partial(_node_metadata_hook, stack_trace="file")
|
| 19 |
+
):
|
| 20 |
+
pass(gm)
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
This hook should not work for all generic cases -- specifically it assumes
|
| 24 |
+
that nodes being added are only call_function nodes, and copies over the
|
| 25 |
+
first argument node's nn_module_stack.
|
| 26 |
+
"""
|
| 27 |
+
assert node.op == "call_function" and callable(node.target)
|
| 28 |
+
|
| 29 |
+
arg_meta = [arg.meta for arg in node.args if isinstance(arg, torch.fx.Node)]
|
| 30 |
+
assert len(arg_meta) >= 1
|
| 31 |
+
arg_meta = arg_meta[0]
|
| 32 |
+
|
| 33 |
+
if (
|
| 34 |
+
isinstance(node.target, torch._ops.OpOverload)
|
| 35 |
+
and len(node.target._schema.returns) == 0
|
| 36 |
+
):
|
| 37 |
+
node.meta["val"] = None
|
| 38 |
+
else:
|
| 39 |
+
fake_args = [
|
| 40 |
+
arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
|
| 41 |
+
for arg in node.args
|
| 42 |
+
]
|
| 43 |
+
fake_res = node.target(*fake_args)
|
| 44 |
+
node.meta["val"] = fake_res
|
| 45 |
+
|
| 46 |
+
node.meta["stack_trace"] = stack_trace
|
| 47 |
+
node.meta["nn_module_stack"] = arg_meta.get(
|
| 48 |
+
"nn_module_stack",
|
| 49 |
+
{
|
| 50 |
+
_EMPTY_NN_MODULE_STACK_KEY: (
|
| 51 |
+
_EMPTY_NN_MODULE_STACK_KEY,
|
| 52 |
+
_EMPTY_NN_MODULE_STACK_KEY,
|
| 53 |
+
)
|
| 54 |
+
},
|
| 55 |
+
)
|
| 56 |
+
node.meta["torch_fn"] = (
|
| 57 |
+
f"{node.target.__name__}_0",
|
| 58 |
+
f"{node.target.__class__.__name__}.{node.target.__name__}",
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@contextlib.contextmanager
|
| 63 |
+
def _set_node_metadata_hook(gm: torch.fx.GraphModule, f):
|
| 64 |
+
"""
|
| 65 |
+
Takes a callable which will be called after we create a new node. The
|
| 66 |
+
callable takes the newly created node as input and returns None.
|
| 67 |
+
"""
|
| 68 |
+
assert callable(f), "node_metadata_hook must be a callable."
|
| 69 |
+
|
| 70 |
+
# Add the hook to all submodules
|
| 71 |
+
for m in gm.modules():
|
| 72 |
+
if isinstance(m, GraphModule):
|
| 73 |
+
m._register_create_node_hook(f)
|
| 74 |
+
try:
|
| 75 |
+
yield
|
| 76 |
+
finally:
|
| 77 |
+
# Restore hook for all submodules
|
| 78 |
+
for m in gm.modules():
|
| 79 |
+
if isinstance(m, GraphModule):
|
| 80 |
+
m._unregister_create_node_hook(f)
|
.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import math
|
| 3 |
+
import operator
|
| 4 |
+
import traceback
|
| 5 |
+
from functools import partial
|
| 6 |
+
from typing import Callable, Dict, List, NamedTuple, Set
|
| 7 |
+
|
| 8 |
+
import sympy
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.fx
|
| 12 |
+
from torch.utils._sympy.value_ranges import ValueRanges
|
| 13 |
+
from torch.utils._sympy.numbers import int_oo
|
| 14 |
+
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
| 15 |
+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
| 16 |
+
|
| 17 |
+
__all__ = ["InputDim"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class InputDim(NamedTuple):
|
| 21 |
+
input_name: str
|
| 22 |
+
dim: int
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _convert_to_int(val):
|
| 26 |
+
# Convert simple sympy Integers into concrete int
|
| 27 |
+
if val in (sympy.oo, int_oo):
|
| 28 |
+
return math.inf
|
| 29 |
+
if val in (-sympy.oo, -int_oo):
|
| 30 |
+
return -math.inf
|
| 31 |
+
if isinstance(val, sympy.Integer):
|
| 32 |
+
return int(val)
|
| 33 |
+
raise RuntimeError(
|
| 34 |
+
"Export constraints cannot be non-integer expressions"
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _convert_range_to_int(range: ValueRanges):
|
| 39 |
+
assert isinstance(range, ValueRanges)
|
| 40 |
+
min_val = _convert_to_int(range.lower)
|
| 41 |
+
max_val = _convert_to_int(range.upper)
|
| 42 |
+
return min_val, max_val
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
range_constraints: Dict[sympy.Symbol, ValueRanges],
|
| 49 |
+
):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
|
| 52 |
+
self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set()
|
| 53 |
+
self.counter = 0
|
| 54 |
+
|
| 55 |
+
def _assert_range_constraint(self, node, lower, upper, assert_msg):
|
| 56 |
+
last_node = node
|
| 57 |
+
if lower > -math.inf:
|
| 58 |
+
last_node = self._insert_assert_async(last_node, operator.ge, node, lower, assert_msg)
|
| 59 |
+
|
| 60 |
+
if upper < math.inf:
|
| 61 |
+
last_node = self._insert_assert_async(last_node, operator.le, node, upper, assert_msg)
|
| 62 |
+
|
| 63 |
+
def _insert_assert_async(self, last_node, op, lower, upper, assert_msg):
|
| 64 |
+
"""
|
| 65 |
+
Inserts assert_async call_function nodes in the graph. This function is
|
| 66 |
+
called **during** the interpreter-based pass.
|
| 67 |
+
"""
|
| 68 |
+
self.counter += 1
|
| 69 |
+
graph = last_node.graph
|
| 70 |
+
with graph.inserting_after(last_node):
|
| 71 |
+
cmp = graph.call_function(op, (lower, upper), {})
|
| 72 |
+
with graph.inserting_after(cmp):
|
| 73 |
+
cmp_tensor = graph.call_function(torch.ops.aten.scalar_tensor.default, (cmp,), {})
|
| 74 |
+
with graph.inserting_after(cmp_tensor):
|
| 75 |
+
assert_async = graph.call_function(
|
| 76 |
+
torch.ops.aten._assert_async.msg,
|
| 77 |
+
(cmp_tensor, assert_msg),
|
| 78 |
+
{},
|
| 79 |
+
)
|
| 80 |
+
return assert_async
|
| 81 |
+
|
| 82 |
+
def call(self, graph_module) -> PassResult:
|
| 83 |
+
self.existing_inline_assertions = _get_existing_inline_assertions(
|
| 84 |
+
graph_module, self.range_constraints
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
for module in graph_module.modules():
|
| 88 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 89 |
+
continue
|
| 90 |
+
for node in module.graph.nodes:
|
| 91 |
+
if node.op != "call_function":
|
| 92 |
+
continue
|
| 93 |
+
if "val" not in node.meta:
|
| 94 |
+
continue
|
| 95 |
+
|
| 96 |
+
val = node.meta["val"]
|
| 97 |
+
# In general, we may have to deal the case such as: ret[1].shape[0].
|
| 98 |
+
# We need first find out what symbols require assertion, then we need to follow the path
|
| 99 |
+
# from ret to the symbol, construct the proxies along the way and construct the messages
|
| 100 |
+
# piece-wise at the same time.
|
| 101 |
+
#
|
| 102 |
+
# We use post-order traversal to collect all the proxies callbacks needed, construct
|
| 103 |
+
# the error message callbacks, and at the top-level traversal tree we execute all the callbacks.
|
| 104 |
+
# We need the callbacks because, in order to call the function to create a proxy for shape[0], we
|
| 105 |
+
# need the proxy for shape, which further requires the proxy for ret[1], etc.
|
| 106 |
+
|
| 107 |
+
def add_assertions(val):
|
| 108 |
+
call_backs: List[Callable] = []
|
| 109 |
+
messages: List[str] = []
|
| 110 |
+
if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
|
| 111 |
+
symbol = val.node.expr
|
| 112 |
+
if symbol in self.existing_inline_assertions:
|
| 113 |
+
return call_backs, messages
|
| 114 |
+
if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol):
|
| 115 |
+
if symbol in self._asserts_generated_unbacked_symbols:
|
| 116 |
+
return call_backs, messages
|
| 117 |
+
# We only care about unbacked symints for these inline
|
| 118 |
+
# constraints, which are prefixed with 'u'
|
| 119 |
+
constraint = self.range_constraints[symbol]
|
| 120 |
+
min_val, max_val = _convert_range_to_int(constraint)
|
| 121 |
+
assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
|
| 122 |
+
call_backs.append(
|
| 123 |
+
partial(self._assert_range_constraint, lower=min_val, upper=max_val)
|
| 124 |
+
)
|
| 125 |
+
messages.append(assert_msg)
|
| 126 |
+
self._asserts_generated_unbacked_symbols.add(symbol)
|
| 127 |
+
|
| 128 |
+
elif isinstance(val, torch.Tensor):
|
| 129 |
+
for i, sym in enumerate(val.shape):
|
| 130 |
+
cbs, msgs = add_assertions(sym)
|
| 131 |
+
for cb, msg in zip(cbs, msgs):
|
| 132 |
+
def sym_size_cb(node, assert_msg, dim):
|
| 133 |
+
with node.graph.inserting_after(node):
|
| 134 |
+
dim_node = module.graph.call_function(
|
| 135 |
+
torch.ops.aten.sym_size.int,
|
| 136 |
+
(node, dim),
|
| 137 |
+
{},
|
| 138 |
+
)
|
| 139 |
+
cb(node=dim_node, assert_msg=assert_msg)
|
| 140 |
+
call_backs.append(partial(sym_size_cb, dim=i))
|
| 141 |
+
messages.append(f".shape[{i}]" + msg)
|
| 142 |
+
return call_backs, messages
|
| 143 |
+
|
| 144 |
+
callbacks, messages = add_assertions(val)
|
| 145 |
+
for cb, msg in zip(callbacks, messages):
|
| 146 |
+
cb(node=node, assert_msg=f"{node}" + msg)
|
| 147 |
+
|
| 148 |
+
module.recompile()
|
| 149 |
+
|
| 150 |
+
# Sometimes this pass would return a wrong graph where we have mismatched
|
| 151 |
+
# node names in signature. Before we fix it, let's just skip it.
|
| 152 |
+
if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass:
|
| 153 |
+
return PassResult(graph_module, False)
|
| 154 |
+
|
| 155 |
+
# Populate the stack trace with dummy vals to respect IR
|
| 156 |
+
for node in graph_module.graph.nodes:
|
| 157 |
+
if not node.meta.get("stack_trace", None) and node.op not in ["placeholder", "output"]:
|
| 158 |
+
node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))
|
| 159 |
+
return PassResult(graph_module, True)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _get_existing_inline_assertions(
|
| 163 |
+
graph_module: torch.fx.GraphModule,
|
| 164 |
+
range_constraints: Dict[sympy.Symbol, ValueRanges],
|
| 165 |
+
) -> Dict[sympy.Symbol, ValueRanges]:
|
| 166 |
+
existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {}
|
| 167 |
+
|
| 168 |
+
for module in graph_module.modules():
|
| 169 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 170 |
+
continue
|
| 171 |
+
|
| 172 |
+
# Find all the existing inline assertions. They will look something like:
|
| 173 |
+
# %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {})
|
| 174 |
+
# %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {})
|
| 175 |
+
# %_assert_scalar = call_function[target=torch.ops.aten._assert_scalar.default](args = (%scalar_tensor, "..."), kwargs = {})
|
| 176 |
+
for node in module.graph.nodes:
|
| 177 |
+
if node.target != torch.ops.aten._assert_scalar.default:
|
| 178 |
+
continue
|
| 179 |
+
|
| 180 |
+
compare_arg = node.args[0]
|
| 181 |
+
if not (
|
| 182 |
+
isinstance(compare_arg, torch.fx.Node) and
|
| 183 |
+
compare_arg.op == "call_function" and
|
| 184 |
+
compare_arg.target in (operator.le, operator.ge) and
|
| 185 |
+
len(compare_arg.args) == 2
|
| 186 |
+
):
|
| 187 |
+
continue
|
| 188 |
+
|
| 189 |
+
compare_op = compare_arg.target
|
| 190 |
+
lhs, rhs = compare_arg.args
|
| 191 |
+
|
| 192 |
+
def maybe_get_symint(x):
|
| 193 |
+
if (
|
| 194 |
+
isinstance(x, torch.fx.Node) and
|
| 195 |
+
"val" in x.meta and
|
| 196 |
+
isinstance(x.meta["val"], torch.SymInt)
|
| 197 |
+
):
|
| 198 |
+
return x.meta["val"].node.expr
|
| 199 |
+
return x
|
| 200 |
+
|
| 201 |
+
lhs = maybe_get_symint(lhs)
|
| 202 |
+
rhs = maybe_get_symint(rhs)
|
| 203 |
+
|
| 204 |
+
if compare_op == operator.ge:
|
| 205 |
+
lhs, rhs = rhs, lhs
|
| 206 |
+
|
| 207 |
+
if isinstance(lhs, sympy.Symbol) and isinstance(rhs, int):
|
| 208 |
+
symint = lhs
|
| 209 |
+
scalar = rhs
|
| 210 |
+
elif isinstance(rhs, sympy.Symbol) and isinstance(lhs, int):
|
| 211 |
+
symint = rhs
|
| 212 |
+
scalar = lhs
|
| 213 |
+
else:
|
| 214 |
+
continue
|
| 215 |
+
|
| 216 |
+
if symint not in range_constraints:
|
| 217 |
+
raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}")
|
| 218 |
+
|
| 219 |
+
previous_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf))
|
| 220 |
+
|
| 221 |
+
if symint is lhs:
|
| 222 |
+
bounds = ValueRanges(-math.inf, scalar)
|
| 223 |
+
else:
|
| 224 |
+
bounds = ValueRanges(scalar, math.inf)
|
| 225 |
+
existing_inline_assertions[symint] = previous_range & bounds
|
| 226 |
+
|
| 227 |
+
return existing_inline_assertions
|
.venv/lib/python3.11/site-packages/torch/_export/passes/collect_tracepoints_pass.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import operator
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.export.exported_program import ConstantArgument, TensorArgument
|
| 6 |
+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__all__ = ["CollectTracepointsPass"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CollectTracepointsPass(PassBase):
|
| 13 |
+
"""
|
| 14 |
+
Performs constant folding and constant propagation.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, specs, sig) -> None:
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.specs = specs
|
| 20 |
+
self.sig = sig
|
| 21 |
+
|
| 22 |
+
def call(self, gm):
|
| 23 |
+
def get_arg_spec(arg):
|
| 24 |
+
if isinstance(arg, torch.fx.Node):
|
| 25 |
+
if isinstance(arg.meta.get("val"), torch.Tensor):
|
| 26 |
+
return TensorArgument(name=arg.name)
|
| 27 |
+
else:
|
| 28 |
+
raise AssertionError(
|
| 29 |
+
"Symint input is not implemented yet for submodule call signature."
|
| 30 |
+
)
|
| 31 |
+
else:
|
| 32 |
+
return ConstantArgument(name="", value=arg)
|
| 33 |
+
|
| 34 |
+
for module in gm.modules():
|
| 35 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 36 |
+
continue
|
| 37 |
+
nn_module_stack = None
|
| 38 |
+
for node in module.graph.nodes:
|
| 39 |
+
if node.op != "call_function":
|
| 40 |
+
continue
|
| 41 |
+
if node.target == torch.ops.higher_order._export_tracepoint:
|
| 42 |
+
kind = node.kwargs["kind"]
|
| 43 |
+
if kind == "module_call_outputs":
|
| 44 |
+
nn_module_stack = node.meta["nn_module_stack"]
|
| 45 |
+
elif kind == "module_call_inputs":
|
| 46 |
+
nn_module_stack = None
|
| 47 |
+
else:
|
| 48 |
+
raise AssertionError(f"Unknown tracepoint kind: {kind}")
|
| 49 |
+
elif node.meta["nn_module_stack"] == nn_module_stack:
|
| 50 |
+
node.meta["nn_module_stack"].popitem()
|
| 51 |
+
else:
|
| 52 |
+
nn_module_stack = None
|
| 53 |
+
nn_module_stack = None
|
| 54 |
+
for node in reversed(module.graph.nodes):
|
| 55 |
+
if node.op != "call_function":
|
| 56 |
+
continue
|
| 57 |
+
if node.target == torch.ops.higher_order._export_tracepoint:
|
| 58 |
+
kind = node.kwargs["kind"]
|
| 59 |
+
if kind == "module_call_inputs":
|
| 60 |
+
nn_module_stack = node.meta["nn_module_stack"]
|
| 61 |
+
elif kind == "module_call_outputs":
|
| 62 |
+
nn_module_stack = None
|
| 63 |
+
else:
|
| 64 |
+
raise AssertionError(f"Unknown tracepoint kind: {kind}")
|
| 65 |
+
elif node.meta["nn_module_stack"] == nn_module_stack:
|
| 66 |
+
node.meta["nn_module_stack"].popitem()
|
| 67 |
+
else:
|
| 68 |
+
nn_module_stack = None
|
| 69 |
+
for module in gm.modules():
|
| 70 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 71 |
+
continue
|
| 72 |
+
for node in module.graph.nodes:
|
| 73 |
+
if node.op != "call_function":
|
| 74 |
+
continue
|
| 75 |
+
if node.target == torch.ops.higher_order._export_tracepoint:
|
| 76 |
+
for i, arg in enumerate(node.args):
|
| 77 |
+
kind = node.kwargs["kind"]
|
| 78 |
+
if kind == "module_call_inputs":
|
| 79 |
+
self.specs[node.kwargs["path"]].inputs.append(
|
| 80 |
+
get_arg_spec(arg)
|
| 81 |
+
)
|
| 82 |
+
elif kind == "module_call_outputs":
|
| 83 |
+
self.specs[node.kwargs["path"]].outputs.append(
|
| 84 |
+
get_arg_spec(arg)
|
| 85 |
+
)
|
| 86 |
+
else:
|
| 87 |
+
raise AssertionError(f"Unknown tracepoint kind: {kind}")
|
| 88 |
+
if isinstance(arg, torch.fx.Node):
|
| 89 |
+
for user in node.users:
|
| 90 |
+
assert user.op == "call_function"
|
| 91 |
+
assert user.target == operator.getitem
|
| 92 |
+
assert isinstance(user.args[1], int)
|
| 93 |
+
if user.args[1] == i:
|
| 94 |
+
user.replace_all_uses_with(arg)
|
| 95 |
+
self.sig.replace_all_uses(user.name, arg.name)
|
| 96 |
+
break
|
| 97 |
+
users = list(node.users)
|
| 98 |
+
for user in users:
|
| 99 |
+
assert len(user.users) == 0
|
| 100 |
+
gm.graph.erase_node(user)
|
| 101 |
+
gm.graph.erase_node(node)
|
| 102 |
+
return PassResult(gm, True)
|
.venv/lib/python3.11/site-packages/torch/_export/passes/constant_folding.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import collections
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from typing import Any, Callable, Dict, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.utils._pytree as pytree
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
aten = torch.ops.aten
|
| 11 |
+
|
| 12 |
+
# We would like to split modules into two subgraphs for runtime weight updates to work correctly.
|
| 13 |
+
# The use case and more information could be found at:
|
| 14 |
+
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
|
| 15 |
+
META_TAG = "MODULE_TYPE"
|
| 16 |
+
MODULE_TAG = "_MAIN_MODULE"
|
| 17 |
+
CONST_MODULE_TAG = "_CONST_MODULE"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def replace_node_with_constant(gm, node, constant, name=None):
|
| 21 |
+
g = gm.graph
|
| 22 |
+
|
| 23 |
+
if name:
|
| 24 |
+
qualname = name
|
| 25 |
+
else:
|
| 26 |
+
if not hasattr(gm, "_frozen_param_count"):
|
| 27 |
+
gm._frozen_param_count = 0
|
| 28 |
+
i = gm._frozen_param_count
|
| 29 |
+
|
| 30 |
+
while True:
|
| 31 |
+
qualname = f"_frozen_param{i}"
|
| 32 |
+
if not hasattr(gm, qualname):
|
| 33 |
+
break
|
| 34 |
+
i += 1
|
| 35 |
+
|
| 36 |
+
gm._frozen_param_count = i + 1
|
| 37 |
+
|
| 38 |
+
with g.inserting_before(node):
|
| 39 |
+
new_input_node = g.create_node("get_attr", qualname, (), {})
|
| 40 |
+
node.replace_all_uses_with(new_input_node)
|
| 41 |
+
new_input_node.meta.update(node.meta)
|
| 42 |
+
g.erase_node(node)
|
| 43 |
+
|
| 44 |
+
# needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
|
| 45 |
+
gm.register_buffer(qualname, constant)
|
| 46 |
+
setattr(gm, qualname, constant)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ConstantFolder(torch.fx.Interpreter):
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
gm,
|
| 53 |
+
skip_constructors=False,
|
| 54 |
+
):
|
| 55 |
+
super().__init__(gm)
|
| 56 |
+
self.node_replacements: Dict[torch.fx.Node, Any] = {}
|
| 57 |
+
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
|
| 58 |
+
self.unknown_value = object()
|
| 59 |
+
self.skip_constructors: bool = skip_constructors
|
| 60 |
+
|
| 61 |
+
# overwrite this to deallocate env values if their only remaining use
|
| 62 |
+
# is the output
|
| 63 |
+
self.user_to_last_uses = self.node_to_last_non_output_use()
|
| 64 |
+
|
| 65 |
+
def is_impure(self, node: torch.fx.node.Node):
|
| 66 |
+
if (
|
| 67 |
+
node.target == torch.ops.prims.convert_element_type.default
|
| 68 |
+
and node.args[0].op == "get_attr" # type: ignore[union-attr]
|
| 69 |
+
and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr]
|
| 70 |
+
and node.args[1] == torch.bfloat16
|
| 71 |
+
):
|
| 72 |
+
# For int8_weight -> dq -> bf16_weight
|
| 73 |
+
return True
|
| 74 |
+
if node.target in [
|
| 75 |
+
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
| 76 |
+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
| 77 |
+
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
| 78 |
+
]:
|
| 79 |
+
# For the pattern fp32_weight -> q -> dq
|
| 80 |
+
# We only folding fp32_weight -> q
|
| 81 |
+
# int8_weight and leave dq in graph to be fused
|
| 82 |
+
return True
|
| 83 |
+
return False
|
| 84 |
+
|
| 85 |
+
def node_to_last_non_output_use(self):
|
| 86 |
+
last_non_output_use = collections.defaultdict(list)
|
| 87 |
+
seen_uses = set()
|
| 88 |
+
output_node = next(iter(reversed(self.module.graph.nodes)))
|
| 89 |
+
|
| 90 |
+
for node in reversed(self.module.graph.nodes):
|
| 91 |
+
if node.target == "output":
|
| 92 |
+
continue
|
| 93 |
+
|
| 94 |
+
def add_use(inp):
|
| 95 |
+
if inp in seen_uses:
|
| 96 |
+
return
|
| 97 |
+
|
| 98 |
+
seen_uses.add(inp)
|
| 99 |
+
last_non_output_use[node].append(inp)
|
| 100 |
+
|
| 101 |
+
# In-place is fine since we don't mutate
|
| 102 |
+
pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))
|
| 103 |
+
|
| 104 |
+
# if this node is only used in output, we want to gc it right away
|
| 105 |
+
if len(node.users) == 1 and output_node in node.users:
|
| 106 |
+
last_non_output_use[node].append(node)
|
| 107 |
+
|
| 108 |
+
return last_non_output_use
|
| 109 |
+
|
| 110 |
+
def run_node(self, node):
|
| 111 |
+
if node.target == "output":
|
| 112 |
+
# because we remove nodes from env on last non output use,
|
| 113 |
+
# re-define them now or we'll get error in interpreter
|
| 114 |
+
def set_env(arg):
|
| 115 |
+
self.env[arg] = self.unknown_value
|
| 116 |
+
|
| 117 |
+
# In-place is fine since we don't mutate
|
| 118 |
+
pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
|
| 119 |
+
return super().run_node(node)
|
| 120 |
+
|
| 121 |
+
args, kwargs = self.fetch_args_kwargs_from_env(node)
|
| 122 |
+
flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
|
| 123 |
+
|
| 124 |
+
# We need to do this weird thing because in cases where flattened_inputs
|
| 125 |
+
# contains a ScriptObject, equality checking results in a type error if
|
| 126 |
+
# the types are different.
|
| 127 |
+
if any(
|
| 128 |
+
type(self.unknown_value) == type(input_) and self.unknown_value == input_
|
| 129 |
+
for input_ in flattened_inputs
|
| 130 |
+
):
|
| 131 |
+
return self.unknown_value
|
| 132 |
+
|
| 133 |
+
# TODO - fix errors with this
|
| 134 |
+
if (
|
| 135 |
+
node.op == "call_function"
|
| 136 |
+
and node.target == aten._efficientzerotensor.default
|
| 137 |
+
):
|
| 138 |
+
return self.unknown_value
|
| 139 |
+
|
| 140 |
+
# TODO - constant folding triton kernel returns the inputs -- fix this
|
| 141 |
+
if (
|
| 142 |
+
node.op == "call_function"
|
| 143 |
+
and node.name == "triton_kernel_wrapper_functional_proxy"
|
| 144 |
+
):
|
| 145 |
+
return self.unknown_value
|
| 146 |
+
|
| 147 |
+
# skip constructors, since inductor generates optimal code for them already
|
| 148 |
+
# and turning into tensor would result in an additional global memory read
|
| 149 |
+
# TODO - more complicated strategy
|
| 150 |
+
if (
|
| 151 |
+
self.skip_constructors
|
| 152 |
+
and node.op != "get_attr"
|
| 153 |
+
and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
|
| 154 |
+
):
|
| 155 |
+
return self.unknown_value
|
| 156 |
+
|
| 157 |
+
# All mutations should either be removed or on inputs which we did not make constant
|
| 158 |
+
if (
|
| 159 |
+
isinstance(node.target, torch._ops.OpOverload)
|
| 160 |
+
and torch.Tag.nondeterministic_seeded in node.target.tags
|
| 161 |
+
):
|
| 162 |
+
return self.unknown_value
|
| 163 |
+
|
| 164 |
+
out = super().run_node(node)
|
| 165 |
+
|
| 166 |
+
if node.op != "get_attr" and isinstance(out, torch.Tensor):
|
| 167 |
+
if out.device.type == "meta":
|
| 168 |
+
return out
|
| 169 |
+
|
| 170 |
+
if not self.insertable_tensor_check(out):
|
| 171 |
+
return out
|
| 172 |
+
|
| 173 |
+
if self.is_impure(node):
|
| 174 |
+
return self.unknown_value
|
| 175 |
+
|
| 176 |
+
self.add_node_replacement(node, out)
|
| 177 |
+
|
| 178 |
+
flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
|
| 179 |
+
|
| 180 |
+
for n in flattened_node_inps:
|
| 181 |
+
if not isinstance(n, torch.fx.Node):
|
| 182 |
+
continue
|
| 183 |
+
|
| 184 |
+
self.replaced_uses[n] += 1
|
| 185 |
+
|
| 186 |
+
for to_delete in self.user_to_last_uses.get(node, []):
|
| 187 |
+
if self.replaced_uses[to_delete] == len(to_delete.users):
|
| 188 |
+
self.node_replacements.pop(to_delete, None)
|
| 189 |
+
|
| 190 |
+
return out
|
| 191 |
+
|
| 192 |
+
def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
|
| 193 |
+
return True
|
| 194 |
+
|
| 195 |
+
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
|
| 196 |
+
self.node_replacements[node] = tensor
|
| 197 |
+
|
| 198 |
+
def run(self):
|
| 199 |
+
env = {}
|
| 200 |
+
for n in self.module.graph.find_nodes(op="placeholder"):
|
| 201 |
+
env[n] = self.unknown_value
|
| 202 |
+
return super().run(initial_env=env)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
|
| 206 |
+
with torch.utils._python_dispatch._disable_current_modes():
|
| 207 |
+
cf = ConstantFolder(gm, skip_constructors=True)
|
| 208 |
+
cf.run()
|
| 209 |
+
|
| 210 |
+
for node, constant in cf.node_replacements.items():
|
| 211 |
+
if constraint_fn is not None and not constraint_fn(node):
|
| 212 |
+
continue
|
| 213 |
+
replace_node_with_constant(gm, node, constant)
|
| 214 |
+
|
| 215 |
+
erased_params = []
|
| 216 |
+
# Get all attr users by looking up the graph instead from node.users, because in this case
|
| 217 |
+
# _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor.
|
| 218 |
+
|
| 219 |
+
# opcode name target args kwargs
|
| 220 |
+
# ------------- ------------------- ---------------- --------------------------- --------
|
| 221 |
+
# placeholder arg0_1 arg0 () {}
|
| 222 |
+
# get_attr _tensor_constant0 state () {}
|
| 223 |
+
# call_function add aten.add.Tensor (arg0_1, _tensor_constant0) {}
|
| 224 |
+
# get_attr _tensor_constant0_1 state () {}
|
| 225 |
+
# call_function add_ aten.add_.Tensor (_tensor_constant0_1, 1) {}
|
| 226 |
+
# output output output ([add],) {}
|
| 227 |
+
|
| 228 |
+
get_attr_node_users = defaultdict(list)
|
| 229 |
+
for node in gm.graph.nodes:
|
| 230 |
+
if node.op == "get_attr":
|
| 231 |
+
get_attr_node_users[node.target].extend(node.users.keys())
|
| 232 |
+
for node in gm.graph.find_nodes(op="get_attr"):
|
| 233 |
+
if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0:
|
| 234 |
+
if hasattr(gm, node.target):
|
| 235 |
+
delattr(gm, node.target)
|
| 236 |
+
erased_params.append(node)
|
| 237 |
+
for node in erased_params:
|
| 238 |
+
gm.graph.erase_node(node)
|
| 239 |
+
|
| 240 |
+
gm.graph.eliminate_dead_code()
|
| 241 |
+
gm.graph.lint()
|
| 242 |
+
gm.recompile()
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def constant_graph_tag(gm: torch.fx.GraphModule):
|
| 246 |
+
with torch.utils._python_dispatch._disable_current_modes():
|
| 247 |
+
cf = ConstantFolder(gm, skip_constructors=True)
|
| 248 |
+
cf.run()
|
| 249 |
+
|
| 250 |
+
for node in gm.graph.nodes:
|
| 251 |
+
if (
|
| 252 |
+
node.op == "get_attr"
|
| 253 |
+
or node in cf.node_replacements
|
| 254 |
+
or node in cf.replaced_uses
|
| 255 |
+
):
|
| 256 |
+
node.meta[META_TAG] = CONST_MODULE_TAG
|
| 257 |
+
else:
|
| 258 |
+
node.meta[META_TAG] = MODULE_TAG
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
| 262 |
+
"""
|
| 263 |
+
Construct a GraphModule which corresponds to the part which could be
|
| 264 |
+
constant folded in provided gm.
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
constant_graph_tag(gm)
|
| 268 |
+
# We rewrite the tags, if it's a constant being directly consumed, without
|
| 269 |
+
# any folding opportunity, we keep it in main gm.
|
| 270 |
+
for node in gm.graph.find_nodes(op="get_attr"):
|
| 271 |
+
used_to_fold = False
|
| 272 |
+
for u in node.users:
|
| 273 |
+
if u.meta[META_TAG] == CONST_MODULE_TAG:
|
| 274 |
+
used_to_fold = True
|
| 275 |
+
break
|
| 276 |
+
if not used_to_fold:
|
| 277 |
+
node.meta[META_TAG] = MODULE_TAG
|
| 278 |
+
|
| 279 |
+
new_graph = torch.fx.Graph()
|
| 280 |
+
|
| 281 |
+
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
|
| 282 |
+
output_nodes = []
|
| 283 |
+
for node in gm.graph.nodes:
|
| 284 |
+
if node.meta[META_TAG] == MODULE_TAG:
|
| 285 |
+
continue
|
| 286 |
+
|
| 287 |
+
new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
|
| 288 |
+
node_remapping[node] = new_node
|
| 289 |
+
|
| 290 |
+
for user in node.users:
|
| 291 |
+
if user.meta[META_TAG] == MODULE_TAG:
|
| 292 |
+
output_nodes.append(new_node)
|
| 293 |
+
break
|
| 294 |
+
|
| 295 |
+
new_graph.output(tuple(output_nodes))
|
| 296 |
+
new_graph.lint()
|
| 297 |
+
new_gm = torch.fx.GraphModule(gm, new_graph)
|
| 298 |
+
|
| 299 |
+
return new_gm
|
.venv/lib/python3.11/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Dict, Optional, Tuple, List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument
|
| 6 |
+
from torch._export.pass_infra.node_metadata import NodeMetadata
|
| 7 |
+
from torch._export.pass_infra.proxy_value import ProxyValue
|
| 8 |
+
from torch._ops import OpOverload
|
| 9 |
+
|
| 10 |
+
aten = torch.ops.aten
|
| 11 |
+
|
| 12 |
+
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = {
|
| 13 |
+
aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
|
| 14 |
+
aten._assert_async.msg: aten._functional_assert_async.msg,
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
|
| 19 |
+
"""
|
| 20 |
+
Functionalize ops with side effect in graph module by replacing the op with
|
| 21 |
+
functional version of it. A new dependency token (`dep_token`) will be
|
| 22 |
+
created and propagated through functional ops to output.
|
| 23 |
+
For example:
|
| 24 |
+
```
|
| 25 |
+
def f(x):
|
| 26 |
+
sym_constrain_range(x.shape[0], min=1, max=3)
|
| 27 |
+
return x.add(3)
|
| 28 |
+
```
|
| 29 |
+
Will be transformed to:
|
| 30 |
+
```
|
| 31 |
+
def f(x):
|
| 32 |
+
dep_token0 = _make_dep_token()
|
| 33 |
+
dep_token1 = _functional_sym_constrain_range(
|
| 34 |
+
x.shape[0], min=1, max=3, dep_token=dep_token0
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
return x.add(3), dep_token1
|
| 38 |
+
```
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(self) -> None:
|
| 42 |
+
super().__init__()
|
| 43 |
+
self._dep_token: Optional[ProxyValue] = None
|
| 44 |
+
self._next_dep_token_index: Optional[int] = None
|
| 45 |
+
|
| 46 |
+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
|
| 47 |
+
# Early return if no non-functional assertions.
|
| 48 |
+
if not any(
|
| 49 |
+
n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS
|
| 50 |
+
for n in graph_module.graph.nodes
|
| 51 |
+
):
|
| 52 |
+
return PassResult(graph_module=graph_module, modified=False)
|
| 53 |
+
|
| 54 |
+
gm = copy.deepcopy(graph_module)
|
| 55 |
+
self._dep_token = None
|
| 56 |
+
self._next_dep_token_index = None
|
| 57 |
+
return super().call(gm)
|
| 58 |
+
|
| 59 |
+
def call_operator(
|
| 60 |
+
self,
|
| 61 |
+
op: OpOverload,
|
| 62 |
+
args: Tuple[Argument, ...],
|
| 63 |
+
kwargs: Dict[str, Argument],
|
| 64 |
+
meta: NodeMetadata,
|
| 65 |
+
) -> ProxyValue:
|
| 66 |
+
if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
|
| 67 |
+
return super().call_operator(op, args, kwargs, meta)
|
| 68 |
+
|
| 69 |
+
if self._dep_token is None:
|
| 70 |
+
self._dep_token = super().call_operator(
|
| 71 |
+
aten._make_dep_token,
|
| 72 |
+
args=(),
|
| 73 |
+
kwargs={},
|
| 74 |
+
meta=self._create_dummy_node_metadata(),
|
| 75 |
+
)
|
| 76 |
+
self._dep_token.node.name = "dep_token0"
|
| 77 |
+
self._next_dep_token_index = 1
|
| 78 |
+
|
| 79 |
+
self._dep_token = super().call_operator(
|
| 80 |
+
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op],
|
| 81 |
+
args=args,
|
| 82 |
+
kwargs={**kwargs, "dep_token": self._dep_token},
|
| 83 |
+
meta=meta,
|
| 84 |
+
)
|
| 85 |
+
assert self._next_dep_token_index is not None
|
| 86 |
+
self._dep_token.node.name = f"dep_token{self._next_dep_token_index}"
|
| 87 |
+
self._next_dep_token_index += 1
|
| 88 |
+
|
| 89 |
+
return self._dep_token
|
| 90 |
+
|
| 91 |
+
def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
|
| 92 |
+
assert self._dep_token is not None
|
| 93 |
+
|
| 94 |
+
return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type]
|
.venv/lib/python3.11/site-packages/torch/_export/passes/lift_constants_pass.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import collections
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import Any, Dict, List, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._export.verifier import SpecViolationError
|
| 8 |
+
from torch._guards import detect_fake_mode
|
| 9 |
+
from torch._library.fake_class_registry import FakeScriptObject
|
| 10 |
+
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
| 11 |
+
from torch.export.exported_program import (
|
| 12 |
+
ArgumentSpec,
|
| 13 |
+
CustomObjArgument,
|
| 14 |
+
ExportGraphSignature,
|
| 15 |
+
InputKind,
|
| 16 |
+
InputSpec,
|
| 17 |
+
TensorArgument,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ConstantAttrMap(collections.abc.MutableMapping):
|
| 22 |
+
"""A mapping class that understands how to use module constants (tensors,
|
| 23 |
+
ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally,
|
| 24 |
+
but ScriptObjects are stored by hash, because different torch.ScriptObjects can point to
|
| 25 |
+
the same underlying value (but we guarantee that they will `hash()` to the same value
|
| 26 |
+
if that's the case).
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self) -> None:
|
| 30 |
+
# Underlying dict that we use to implement this mapping.
|
| 31 |
+
self._constant_attrs: Dict[
|
| 32 |
+
Union[int, torch.Tensor, FakeScriptObject], List[Any]
|
| 33 |
+
] = {}
|
| 34 |
+
# Map from the hash(ScriptObject) to the ScriptObject itself. Used for
|
| 35 |
+
# APIs like `__iter__` that should look like they're returning the
|
| 36 |
+
# original ScriptObjects.
|
| 37 |
+
self._script_object_map: Dict[int, torch.ScriptObject] = {}
|
| 38 |
+
|
| 39 |
+
def __getitem__(
|
| 40 |
+
self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
|
| 41 |
+
) -> Any:
|
| 42 |
+
real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
|
| 43 |
+
assert isinstance(real_key, (int, torch.Tensor, FakeScriptObject))
|
| 44 |
+
return self._constant_attrs[real_key]
|
| 45 |
+
|
| 46 |
+
def __setitem__(self, key: Union[torch.Tensor, torch.ScriptObject], value):
|
| 47 |
+
# we shouldn't actually call this, should go to add() instead to handle aliasing
|
| 48 |
+
raise NotImplementedError(
|
| 49 |
+
"""Directly setting values for ConstantAttrMap is not supported, please use add(key, value) instead.
|
| 50 |
+
The same key can be mapped to multiple values, for handling constant aliasing."""
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
def add(
|
| 54 |
+
self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject], value: Any
|
| 55 |
+
) -> None:
|
| 56 |
+
if isinstance(key, torch.ScriptObject):
|
| 57 |
+
if hash(key) not in self._constant_attrs:
|
| 58 |
+
self._constant_attrs[hash(key)] = []
|
| 59 |
+
self._constant_attrs[hash(key)].append(value)
|
| 60 |
+
self._script_object_map[hash(key)] = key
|
| 61 |
+
elif isinstance(key, (torch.Tensor, FakeScriptObject)):
|
| 62 |
+
if key not in self._constant_attrs:
|
| 63 |
+
self._constant_attrs[key] = []
|
| 64 |
+
self._constant_attrs[key].append(value)
|
| 65 |
+
else:
|
| 66 |
+
raise TypeError(
|
| 67 |
+
f"Expected key to be a tensor or ScriptObject, got {type(key)}"
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def __delitem__(self, key):
|
| 71 |
+
real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
|
| 72 |
+
|
| 73 |
+
del self._constant_attrs[real_key]
|
| 74 |
+
|
| 75 |
+
def __iter__(self):
|
| 76 |
+
for key in self._constant_attrs:
|
| 77 |
+
if isinstance(key, int):
|
| 78 |
+
yield self._script_object_map[key]
|
| 79 |
+
else:
|
| 80 |
+
yield key
|
| 81 |
+
|
| 82 |
+
def __len__(self):
|
| 83 |
+
return len(self._constant_attrs)
|
| 84 |
+
|
| 85 |
+
def __contains__(self, key: object) -> bool:
|
| 86 |
+
real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
|
| 87 |
+
return real_key in self._constant_attrs
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str:
|
| 91 |
+
# The FQN of the constant tensor in the state dict should
|
| 92 |
+
# correspond to the module where the constant tensor was
|
| 93 |
+
# originally used.
|
| 94 |
+
if len(node.meta["nn_module_stack"]) == 0:
|
| 95 |
+
return constant_name
|
| 96 |
+
parent_fqn = list(node.meta["nn_module_stack"].values())[-1][0]
|
| 97 |
+
if len(parent_fqn) > 0:
|
| 98 |
+
return f"{parent_fqn}.{constant_name}"
|
| 99 |
+
else:
|
| 100 |
+
return constant_name
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _get_first_fqn(
|
| 104 |
+
const_attrs: ConstantAttrMap,
|
| 105 |
+
key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],
|
| 106 |
+
) -> Any:
|
| 107 |
+
fqns = const_attrs.get(key)
|
| 108 |
+
return fqns[0] if fqns else None
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def lift_constants_pass(
|
| 112 |
+
gm: torch.fx.GraphModule,
|
| 113 |
+
graph_signature: ExportGraphSignature,
|
| 114 |
+
constant_attrs: ConstantAttrMap,
|
| 115 |
+
) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]:
|
| 116 |
+
"""
|
| 117 |
+
Takes a graph module, graph signature, and modifies them implace to lift any
|
| 118 |
+
constants (tensors or custom classes) as inputs to the graph. Returns a
|
| 119 |
+
dictionary of names to constants.
|
| 120 |
+
|
| 121 |
+
Arguments:
|
| 122 |
+
gm (torch.fx.GraphModule): The graph module containing the graph and constants to lift.
|
| 123 |
+
graph_signature (ExportGraphSignature): This graph signature will be
|
| 124 |
+
mutated to add additional CONSTANT_TENSOR and CUSTOM_OBJ inputs.
|
| 125 |
+
constant_attrs (ConstantAttr): A mapping from a constant value to its
|
| 126 |
+
fully-qualified path in `gm`. This is used to maintain consistent
|
| 127 |
+
location of constants between the original module and the exported
|
| 128 |
+
version.
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
A dictionary of fqn => constant value.
|
| 132 |
+
"""
|
| 133 |
+
all_constants: Dict[
|
| 134 |
+
str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
|
| 135 |
+
] = {}
|
| 136 |
+
|
| 137 |
+
inputs = graph_signature.input_specs
|
| 138 |
+
num_custom_obj = sum(
|
| 139 |
+
input_specs.kind == InputKind.CUSTOM_OBJ for input_specs in inputs
|
| 140 |
+
)
|
| 141 |
+
num_tensor_constants = sum(
|
| 142 |
+
input_specs.kind == InputKind.CONSTANT_TENSOR for input_specs in inputs
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
fake_mode = detect_fake_mode(
|
| 146 |
+
tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
first_user_input_loc, first_user_input = 0, None
|
| 150 |
+
for node in gm.graph.nodes:
|
| 151 |
+
if node.op == "placeholder" and node.name in graph_signature.user_inputs:
|
| 152 |
+
first_user_input = node
|
| 153 |
+
break
|
| 154 |
+
first_user_input_loc += 1
|
| 155 |
+
|
| 156 |
+
lifted_objs = ConstantAttrMap()
|
| 157 |
+
for node in gm.graph.nodes:
|
| 158 |
+
if node.op == "get_attr":
|
| 159 |
+
constant_val = getattr(gm, node.target)
|
| 160 |
+
if constant_val in lifted_objs:
|
| 161 |
+
# We already lifted this constant elsewhere. Just rewrite uses
|
| 162 |
+
# of this get_attr to point to the already-existing placeholder
|
| 163 |
+
# node.
|
| 164 |
+
const_placeholder_node = _get_first_fqn(lifted_objs, constant_val)
|
| 165 |
+
node.replace_all_uses_with(const_placeholder_node)
|
| 166 |
+
gm.graph.erase_node(node)
|
| 167 |
+
continue
|
| 168 |
+
|
| 169 |
+
# For ScriptObject, Tensor and FakeScriptObject constants:
|
| 170 |
+
# First check if the constant was an attribute on some module by
|
| 171 |
+
# consulting `constant_attrs` map. If it is, use the fqn that keeps
|
| 172 |
+
# its location consistent with the eager module.
|
| 173 |
+
#
|
| 174 |
+
# If it's not in the `constant_attrs` map, that means it's an inline
|
| 175 |
+
# constant (e.g. x + torch.tensor(0)), and thus did not have a
|
| 176 |
+
# specific location in the eager module. In that case, just generate
|
| 177 |
+
# some name and attach it to the module in which it was used.
|
| 178 |
+
if isinstance(constant_val, (torch.ScriptObject, FakeScriptObject)):
|
| 179 |
+
constant_kind = InputKind.CUSTOM_OBJ
|
| 180 |
+
constant_fqn = _get_first_fqn(constant_attrs, constant_val)
|
| 181 |
+
if constant_fqn is not None:
|
| 182 |
+
constant_name = constant_fqn.replace(".", "_")
|
| 183 |
+
else:
|
| 184 |
+
constant_name = f"lifted_custom_{num_custom_obj}"
|
| 185 |
+
constant_fqn = get_constant_fqn(node, constant_name)
|
| 186 |
+
num_custom_obj += 1
|
| 187 |
+
elif isinstance(constant_val, torch.Tensor):
|
| 188 |
+
# Remove the parameterness of constant_val
|
| 189 |
+
if isinstance(constant_val, torch.nn.Parameter):
|
| 190 |
+
warnings.warn(
|
| 191 |
+
f"{node.target} created when tracing {node.meta['stack_trace']} is a parameter. But"
|
| 192 |
+
f"it's not registered with register_parameter(). export will treat it as a constant tensor"
|
| 193 |
+
)
|
| 194 |
+
# We get the real data out of the parameter by disabling the surrounding fake mode.
|
| 195 |
+
with unset_fake_temporarily():
|
| 196 |
+
constant_val = constant_val.data
|
| 197 |
+
constant_kind = InputKind.CONSTANT_TENSOR
|
| 198 |
+
constant_fqn = _get_first_fqn(constant_attrs, constant_val)
|
| 199 |
+
if constant_fqn is not None:
|
| 200 |
+
constant_name = constant_fqn.replace(".", "_")
|
| 201 |
+
else:
|
| 202 |
+
constant_name = f"lifted_tensor_{num_tensor_constants}"
|
| 203 |
+
constant_fqn = get_constant_fqn(node, constant_name)
|
| 204 |
+
num_tensor_constants += 1
|
| 205 |
+
elif isinstance(constant_val, torch.fx.GraphModule):
|
| 206 |
+
continue
|
| 207 |
+
elif "LoweredBackendModule" in type(constant_val).__name__:
|
| 208 |
+
continue
|
| 209 |
+
else:
|
| 210 |
+
raise SpecViolationError(
|
| 211 |
+
f"getattr node {node} referencing unsupported type {type(constant_val)}"
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
with gm.graph.inserting_before(first_user_input):
|
| 215 |
+
# Insert the constant node before the first user input
|
| 216 |
+
const_placeholder_node = gm.graph.placeholder(constant_name)
|
| 217 |
+
# match target name with its node name in case there is name collision
|
| 218 |
+
# and suffix is added to node name in fx
|
| 219 |
+
const_placeholder_node.target = const_placeholder_node.name
|
| 220 |
+
|
| 221 |
+
for k, v in node.meta.items():
|
| 222 |
+
const_placeholder_node.meta[k] = v
|
| 223 |
+
|
| 224 |
+
# Once the FQN has been used, remove nn_module_stack, stack_trace
|
| 225 |
+
const_placeholder_node.meta.pop("nn_module_stack")
|
| 226 |
+
const_placeholder_node.meta.pop("stack_trace", None)
|
| 227 |
+
|
| 228 |
+
input_spec_arg: ArgumentSpec
|
| 229 |
+
if isinstance(constant_val, torch.Tensor):
|
| 230 |
+
if fake_mode is not None:
|
| 231 |
+
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
|
| 232 |
+
constant_val, static_shapes=True
|
| 233 |
+
)
|
| 234 |
+
const_placeholder_node.meta["val"].constant = constant_val
|
| 235 |
+
else:
|
| 236 |
+
const_placeholder_node.meta["val"] = constant_val
|
| 237 |
+
input_spec_arg = TensorArgument(name=const_placeholder_node.name)
|
| 238 |
+
elif isinstance(constant_val, torch._C.ScriptObject):
|
| 239 |
+
class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined]
|
| 240 |
+
const_placeholder_node.meta["val"] = CustomObjArgument(
|
| 241 |
+
constant_fqn, class_fqn
|
| 242 |
+
)
|
| 243 |
+
input_spec_arg = CustomObjArgument(
|
| 244 |
+
name=const_placeholder_node.name, class_fqn=class_fqn
|
| 245 |
+
)
|
| 246 |
+
elif isinstance(constant_val, FakeScriptObject):
|
| 247 |
+
class_fqn = constant_val.script_class_name
|
| 248 |
+
const_placeholder_node.meta["val"] = CustomObjArgument(
|
| 249 |
+
constant_fqn, class_fqn, constant_val
|
| 250 |
+
)
|
| 251 |
+
input_spec_arg = CustomObjArgument(
|
| 252 |
+
name=const_placeholder_node.name,
|
| 253 |
+
class_fqn=class_fqn,
|
| 254 |
+
fake_val=constant_val,
|
| 255 |
+
)
|
| 256 |
+
else:
|
| 257 |
+
raise SpecViolationError(
|
| 258 |
+
f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}"
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
lifted_objs.add(constant_val, const_placeholder_node)
|
| 262 |
+
node.replace_all_uses_with(const_placeholder_node)
|
| 263 |
+
gm.graph.erase_node(node)
|
| 264 |
+
|
| 265 |
+
# Add the constant as a buffer to the graph signature
|
| 266 |
+
graph_signature.input_specs.insert(
|
| 267 |
+
first_user_input_loc,
|
| 268 |
+
InputSpec(
|
| 269 |
+
kind=constant_kind,
|
| 270 |
+
arg=input_spec_arg,
|
| 271 |
+
target=constant_fqn,
|
| 272 |
+
),
|
| 273 |
+
)
|
| 274 |
+
if constant_val in constant_attrs:
|
| 275 |
+
for fqn in constant_attrs[constant_val]:
|
| 276 |
+
all_constants[fqn] = constant_val
|
| 277 |
+
else:
|
| 278 |
+
all_constants[constant_fqn] = constant_val
|
| 279 |
+
first_user_input_loc += 1
|
| 280 |
+
|
| 281 |
+
return all_constants
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def rewrite_script_object_meta(
|
| 285 |
+
gm: torch.fx.GraphModule,
|
| 286 |
+
) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]:
|
| 287 |
+
"""When tracing, we produce a graph with FakeScriptObject in the
|
| 288 |
+
meta["val"].
|
| 289 |
+
|
| 290 |
+
For now, we rewrie meta["val"] to be a placeholder CustomObjArgument
|
| 291 |
+
"""
|
| 292 |
+
constants: Dict[
|
| 293 |
+
str,
|
| 294 |
+
Union[
|
| 295 |
+
torch.Tensor,
|
| 296 |
+
torch.ScriptObject,
|
| 297 |
+
FakeScriptObject,
|
| 298 |
+
],
|
| 299 |
+
] = {}
|
| 300 |
+
for node in gm.graph.nodes:
|
| 301 |
+
if "val" not in node.meta:
|
| 302 |
+
continue
|
| 303 |
+
|
| 304 |
+
old_meta = node.meta["val"]
|
| 305 |
+
|
| 306 |
+
if isinstance(old_meta, torch.ScriptObject):
|
| 307 |
+
class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined]
|
| 308 |
+
new_meta = CustomObjArgument(node.name, class_fqn)
|
| 309 |
+
constants[node.name] = old_meta
|
| 310 |
+
node.meta["val"] = new_meta
|
| 311 |
+
|
| 312 |
+
elif isinstance(old_meta, FakeScriptObject):
|
| 313 |
+
class_fqn = old_meta.script_class_name # type: ignore[attr-defined]
|
| 314 |
+
new_meta = CustomObjArgument(node.name, class_fqn, old_meta)
|
| 315 |
+
constants[node.name] = old_meta
|
| 316 |
+
node.meta["val"] = new_meta
|
| 317 |
+
|
| 318 |
+
return constants
|
.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import torch
|
| 3 |
+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class _RemoveRuntimeAssertionsPass(PassBase):
    """
    Remove runtime assertions inserted by the
    _AddRuntimeAssertionsForInlineConstraintsPass.
    """

    def call(self, graph_module) -> PassResult:
        # Tracks whether any node was actually erased so the PassResult can
        # report whether the graph changed.
        modified = False
        for module in graph_module.modules():
            # Only fx GraphModules own a mutable fx graph; skip other modules.
            if not isinstance(module, torch.fx.GraphModule):
                continue
            for node in module.graph.nodes:
                if node.target == torch.ops.aten._assert_async.msg:
                    assert_async_node = node
                    # An assert node that still has users cannot be erased
                    # safely; leave it in place.
                    if len(assert_async_node.users) > 0:
                        continue
                    module.graph.erase_node(assert_async_node)
                    # the upstream scalar_tensor <- {le, ge} <- sym_size
                    # linear chain of nodes is removed by the
                    # downstream dead code elimination
                    modified = True
        return PassResult(graph_module, modified)
|
.venv/lib/python3.11/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._higher_order_ops.wrap import wrap_with_autocast
|
| 6 |
+
|
| 7 |
+
from ..utils import node_inline_, nodes_filter, nodes_first, sequential_split
|
| 8 |
+
from .replace_with_hop_pass_util import (
|
| 9 |
+
_replace_with_hop_helper,
|
| 10 |
+
_replace_with_hop_pass_helper,
|
| 11 |
+
_sequential_split_and_maybe_inline_subgraphs_helper,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _is_autocast_node(node: torch.fx.Node):
|
| 16 |
+
return (
|
| 17 |
+
node
|
| 18 |
+
and node.op == "call_function"
|
| 19 |
+
and node.target
|
| 20 |
+
in [
|
| 21 |
+
torch.amp.autocast_mode._enter_autocast,
|
| 22 |
+
torch.amp.autocast_mode._exit_autocast,
|
| 23 |
+
]
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _is_enter_autocast_node(node: torch.fx.Node):
|
| 28 |
+
return (
|
| 29 |
+
node
|
| 30 |
+
and node.op == "call_function"
|
| 31 |
+
and node.target == torch.amp.autocast_mode._enter_autocast
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _is_exit_autocast_node(node: torch.fx.Node):
|
| 36 |
+
return (
|
| 37 |
+
node
|
| 38 |
+
and node.op == "call_function"
|
| 39 |
+
and node.target == torch.amp.autocast_mode._exit_autocast
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _is_autocast_sub_mod(node: torch.fx.Node):
    """
    Check if the first non-placeholder node is `torch.amp.autocast_mode._enter_autocast`.
    """
    if node.op == "call_module":
        assert isinstance(node.target, str)
        subgm = getattr(node.graph.owning_module, node.target)
        # Placeholders always come first in an fx graph; find the first
        # "real" node of the submodule.
        first_non_ph = nodes_first(
            subgm.graph.nodes, lambda node: node.op != "placeholder"
        )
        if (
            first_non_ph
            and first_non_ph.op == "call_function"
            and first_non_ph.target == torch.amp.autocast_mode._enter_autocast
        ):
            # TODO: check if current auto-cast type is the same as the args of
            # _enter_autocast. If so, return False, i.e. do not create a submodule.
            return True
    return False
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _check_valid_autocast_block(enter_autocast_node, exit_autocast_node):
    # Sanity-check that the two nodes delimit a well-formed autocast region:
    # an enter marker, an exit marker, and the exit referring back to this
    # exact enter node as its first argument.
    assert _is_enter_autocast_node(enter_autocast_node)
    assert _is_exit_autocast_node(exit_autocast_node)
    assert exit_autocast_node.args[0] == enter_autocast_node
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _replace_with_hop(node: torch.fx.Node):
    """Rewrite the autocast submodule called by ``node`` as a
    ``wrap_with_autocast`` higher-order-op call."""
    assert node.op == "call_module"
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)
    sub_graph = sub_gm.graph
    autocast_nodes = nodes_filter(sub_graph.nodes, _is_autocast_node)
    if len(autocast_nodes) > 0:
        assert len(autocast_nodes) > 1  # need at least an enter node and an exit node
        enter_autocast_node = autocast_nodes[0]
        exit_autocast_node = autocast_nodes[-1]
        _check_valid_autocast_block(enter_autocast_node, exit_autocast_node)

        _replace_with_hop_helper(
            node, enter_autocast_node, _is_autocast_node, wrap_with_autocast
        )
        # The enter/exit markers are redundant once the region is wrapped in
        # the HOP; erase the exit first since it consumes the enter node.
        sub_graph.erase_node(exit_autocast_node)
        sub_graph.erase_node(enter_autocast_node)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    split_autocast creates a new graph module that splits the input graph module into multiple submodules
    based on the `_enter_autocast` and `_exit_autocast` nodes. It doesn't mutate the input graph module.

    Nodes between the **outer-most** `_enter_autocast` and `_exit_autocast(_enter_autocast)` are splitted
    into a submodule. Nested autocast regions are not splitted.
    `_enter_autocast` and `_exit_autocast(_enter_autocast)` nodes are in the submodule as well.

    Below is an example of splitting. A, B, C, D, E are blocks of non-autocast nodes in the original graph
    module. Nodes marked with the same number are grouped into the same submodule.
        A               # 0
        enter_autocast  # 1
        B               # 1
        exit_autocast   # 1
        C               # 2
        enter_autocast  # 3
        D               # 3
        exit_autocast   # 3
        E               # 4
    """
    # Stack of currently-open _enter_autocast nodes; depth > 1 means we are
    # inside a nested autocast region, which must not start a new split.
    enter_autocast_node_stack: List[torch.fx.Node] = []
    # Set when the outermost _exit_autocast was just consumed, so that the
    # very next node begins a fresh submodule.
    first_node_after_outer_most_exit: bool = False

    def node_call_back(node: torch.fx.Node):
        # Returns True when ``node`` should start a new submodule.
        nonlocal enter_autocast_node_stack, first_node_after_outer_most_exit
        if first_node_after_outer_most_exit or (
            len(enter_autocast_node_stack) == 0 and _is_enter_autocast_node(node)
        ):
            assert len(enter_autocast_node_stack) == 0
            first_node_after_outer_most_exit = False
            if _is_enter_autocast_node(node):
                enter_autocast_node_stack.append(node)
            return True
        if _is_exit_autocast_node(node):
            assert len(enter_autocast_node_stack) > 0
            last_enter_autocast_node = enter_autocast_node_stack.pop()
            # An exit must pair with the most recent (innermost) enter.
            assert node.args[0] == last_enter_autocast_node
            if len(enter_autocast_node_stack) == 0:
                # next node should be in the next submodule since
                # autocast block ends
                first_node_after_outer_most_exit = True
        return False

    return sequential_split(gm, node_call_back)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _sequential_split_and_maybe_inline_subgraphs(
    gm: torch.fx.GraphModule, graph_signature
):
    """
    Helper function for replace_autocast_with_hop_pass().
    Split the graph module into multiple subgraphs based on the autocast nodes.
    For each subgraph, decides whether to construct a HOO subgraph, or inline the calls
    back into the parent graph module.
    Nodes between `_enter_autocast` and `_exit_autocast(_enter_autocast)` are considered
    as a subgraph.
    """
    # Fast path: nothing to do when the graph carries no autocast markers.
    need_replacing = any(_is_autocast_node(node) for node in gm.graph.nodes)
    if not need_replacing:
        return gm, graph_signature

    # split_autocast returns a new graph module that could have different output
    # args names. We need to fix the graph signature in `_sequential_split_and_maybe_inline_subgraphs_helper`.
    new_gm = _split_autocast(gm)

    def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
        # Submodules that begin with _enter_autocast become HOPs; every
        # other submodule is inlined back into the parent graph.
        if _is_autocast_sub_mod(node):
            _replace_with_hop(node)
        else:
            assert node.op == "call_module"
            assert isinstance(node.target, str)
            node_inline_(node)

    return _sequential_split_and_maybe_inline_subgraphs_helper(
        new_gm, graph_signature, _maybe_inline_or_replace_with_hop
    )
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def replace_autocast_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
    """Rewrite autocast regions in ``gm`` as higher-order-op calls.

    ``gm`` is split into submodules via
    ``_sequential_split_and_maybe_inline_subgraphs`` and the shared helper
    then recurses into each submodule, applying the same rewrite.
    """
    split_fn = _sequential_split_and_maybe_inline_subgraphs
    return _replace_with_hop_pass_helper(gm, graph_signature, split_fn)
|
.venv/lib/python3.11/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py
ADDED
|
@@ -0,0 +1,673 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
import operator
|
| 4 |
+
from typing import List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.export._trace
|
| 8 |
+
from torch._ops import OpOverload
|
| 9 |
+
from torch.ao.quantization.fx._decomposed import (
|
| 10 |
+
dequantize_per_channel,
|
| 11 |
+
dequantize_per_tensor,
|
| 12 |
+
quantize_per_tensor,
|
| 13 |
+
)
|
| 14 |
+
from torch.ao.quantization.utils import calculate_qmin_qmax
|
| 15 |
+
from torch.fx.graph_module import _assign_attr
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Module-level logger for this pass.
log = logging.getLogger(__name__)

# Those values will need to be carried over multiple operators: quantized ops
# record the (dtype, scale, zero_point) of their output here, and later
# transforms (e.g. the scalar add/mul rewrite) read them back instead of
# carrying their own qparams.
_INPUT_Q_DTYPE: Optional[Union[torch.dtype, torch.fx.Node]] = None
_SCALE: Optional[Union[float, torch.fx.Node]] = None
_ZERO_POINT: Optional[Union[float, torch.fx.Node]] = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def int_to_valid_dtype(val: int) -> torch.dtype:
|
| 27 |
+
from torch._export.converter import _TORCH_ENUM_TO_DTYPE # No circular import.
|
| 28 |
+
|
| 29 |
+
if isinstance(val, torch.dtype):
|
| 30 |
+
return val
|
| 31 |
+
dtype = _TORCH_ENUM_TO_DTYPE[val]
|
| 32 |
+
if dtype == torch.quint8:
|
| 33 |
+
return torch.uint8
|
| 34 |
+
elif dtype == torch.qint8:
|
| 35 |
+
return torch.int8
|
| 36 |
+
return dtype
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def fx_enum_to_dtype(gm: torch.fx.GraphModule, val: int) -> torch.fx.Node:
    """Emit a node that converts a ScalarType enum to a dtype at runtime."""
    conversion_args = (val,)
    return gm.graph.call_function(int_to_valid_dtype, conversion_args)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def insert_quantized_node(
    gm: torch.fx.GraphModule,
    val_node: torch.fx.Node,
    scale_node: Union[float, torch.fx.Node],
    zero_point_node: Union[float, torch.fx.Node],
    qmin_node: Union[float, int, torch.fx.Node],
    qmax_node: Union[float, int, torch.fx.Node],
    dtype_node: Union[torch.dtype, torch.fx.Node],
    qscheme: Optional[torch.qscheme],
) -> torch.fx.Node:
    """Append a ``quantize_per_tensor`` call node to ``gm``'s graph.

    The arguments mirror the decomposed quantize op. ``qscheme`` is accepted
    for signature parity with the dequantize helper but is not used here.
    """
    quant_args = (
        val_node,
        scale_node,
        zero_point_node,
        qmin_node,
        qmax_node,
        dtype_node,
    )
    return gm.graph.call_function(quantize_per_tensor, quant_args)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_dequantized(
    val: torch.Tensor,
    scale: Union[float, torch.Tensor],
    zero_point: Union[float, torch.Tensor],
    qmin: Union[float, int],
    qmax: Union[float, int],
    dtype: torch.dtype,
    axis: Optional[int],
    qscheme: Optional[torch.qscheme],
) -> torch.Tensor:
    """Eagerly dequantize ``val`` via the decomposed dequantize ops.

    Dispatches on ``qscheme``: per-tensor affine ignores ``axis``, while
    per-channel affine requires it. Any other scheme is rejected.
    """
    if qscheme is torch.per_tensor_affine:
        dequant_fn = dequantize_per_tensor
        dequant_args = (val, scale, zero_point, qmin, qmax, dtype)
    elif qscheme is torch.per_channel_affine:
        dequant_fn = dequantize_per_channel
        dequant_args = (val, scale, zero_point, axis, qmin, qmax, dtype)
    else:
        raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")
    return dequant_fn(*dequant_args)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def insert_dequantized_node(
    gm: torch.fx.GraphModule,
    val_node: torch.fx.Node,
    scale_node: Union[float, torch.fx.Node],
    zero_point_node: Union[float, torch.fx.Node],
    qmin_node: Union[float, int, torch.fx.Node],
    qmax_node: Union[float, int, torch.fx.Node],
    dtype_node: Union[torch.dtype, torch.fx.Node],
    axis_node: Optional[Union[int, torch.fx.Node]],
    qscheme: Optional[torch.qscheme],
) -> torch.fx.Node:
    """Append a dequantize call node to ``gm``'s graph.

    Chooses the per-tensor or per-channel decomposed op based on ``qscheme``
    (``axis_node`` is only used for per-channel); other schemes are rejected.
    """
    if qscheme is torch.per_tensor_affine:
        dequant_fn = dequantize_per_tensor
        dequant_args = (
            val_node,
            scale_node,
            zero_point_node,
            qmin_node,
            qmax_node,
            dtype_node,
        )
    elif qscheme is torch.per_channel_affine:
        dequant_fn = dequantize_per_channel
        dequant_args = (
            val_node,
            scale_node,
            zero_point_node,
            axis_node,
            qmin_node,
            qmax_node,
            dtype_node,
        )
    else:
        raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")
    return gm.graph.call_function(dequant_fn, dequant_args)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_qmin_qmax(dtype: torch.dtype) -> Tuple[Union[int, float], Union[int, float]]:
|
| 140 |
+
return calculate_qmin_qmax(None, None, False, dtype, False) # type: ignore[arg-type]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def insert_qmin_qmax_node(
    gm: torch.fx.GraphModule, dtype_node: Union[torch.dtype, torch.fx.Node]
) -> Tuple[torch.fx.Node, torch.fx.Node]:
    """Emit nodes computing (qmin, qmax) for ``dtype_node`` at runtime.

    A single ``calculate_qmin_qmax`` call node is created, followed by two
    ``getitem`` nodes that pick out the min and max bounds.
    """
    range_node = gm.graph.call_function(
        calculate_qmin_qmax, (None, None, False, dtype_node, False)
    )
    bounds = tuple(
        gm.graph.call_function(operator.getitem, (range_node, idx)) for idx in (0, 1)
    )
    return bounds[0], bounds[1]
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def get_script_object(
    gm: torch.nn.Module, node: torch.fx.Node
) -> torch._C.ScriptObject:
    """Resolve a ``get_attr`` node to the ScriptObject it refers to."""
    assert isinstance(node, torch.fx.Node)
    assert node.op == "get_attr"
    attr_name = node.target
    assert isinstance(attr_name, str)

    # ``target`` may be a dotted path into submodules; attrgetter walks it
    # one component at a time, exactly like a getattr loop.
    script_obj = operator.attrgetter(attr_name)(gm)
    assert isinstance(script_obj, torch._C.ScriptObject)
    return script_obj
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
    gm: torch.fx.GraphModule,
    param_node: torch.fx.Node,
) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]:
    """Directly inline tensor from a get_attr fx node."""
    # ``param_node`` refers to a packed-parameter ScriptObject; unpack()
    # yields the (still quantized) weight and the optional bias.
    mod = get_script_object(gm, param_node)
    w_qtensor, b_qtensor = mod.unpack()  # type: ignore[attr-defined]
    # Name the new module attributes after the packed param they came from.
    w_attr_name, b_attr_name = (
        f"dequantized_{param_node.target}_w",
        f"dequantized_{param_node.target}_b",
    )
    return insert_weight_and_bias_get_attr_node(
        gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name
    )
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
    gm: torch.fx.GraphModule,
    get_attr_to_weight_node: torch.fx.Node,
    get_attr_to_bias_node: Optional[torch.fx.Node],
) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]:
    """Materialize dequantized weight/bias attributes from quantized-tensor
    ``get_attr`` nodes and return new ``get_attr`` nodes pointing at them."""
    assert isinstance(get_attr_to_weight_node.target, str)
    w_qtensor = getattr(gm, get_attr_to_weight_node.target)
    w_attr_name = f"dequantized_{get_attr_to_weight_node.target}_w"

    if get_attr_to_bias_node is not None:
        assert isinstance(get_attr_to_bias_node.target, str)
        b_qtensor = getattr(gm, get_attr_to_bias_node.target)
        b_attr_name = f"dequantized_{get_attr_to_bias_node.target}_b"
    else:
        # No bias: the downstream helper treats (None, "") as "skip bias".
        b_qtensor, b_attr_name = None, ""

    return insert_weight_and_bias_get_attr_node(
        gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name
    )
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def insert_weight_and_bias_get_attr_node(
    gm: torch.fx.GraphModule,
    w_qtensor: torch.Tensor,
    b_qtensor: Optional[torch.Tensor],
    w_attr_name: str,
    b_attr_name: str,
) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]:
    """Register dequantized weight/bias tensors as attributes on ``gm`` and
    emit ``get_attr`` nodes for them (bias node is None when no bias)."""
    w_tensor = get_tensor_from_qtensor(w_qtensor)
    _assign_attr(w_tensor, gm, w_attr_name)
    w_tensor_attr = gm.graph.get_attr(w_attr_name)

    if b_qtensor is not None:
        # Bias is stored unquantized, so it only needs the raw conversion,
        # not a scale/zero-point dequantization.
        b_tensor = get_tensor_from_qtensor(b_qtensor, dequant=False)
        _assign_attr(b_tensor, gm, b_attr_name)
        b_tensor_attr = gm.graph.get_attr(b_attr_name)
    else:
        b_tensor_attr = None

    return w_tensor_attr, b_tensor_attr
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def get_tensor_from_qtensor(
    qtensor: torch.Tensor, dequant: bool = True
) -> torch.Tensor:
    """Convert a (possibly) quantized tensor into a plain tensor, optionally
    dequantizing it using its recorded scale/zero-point."""
    # Manual conversion because qint8 is not used anymore.
    if qtensor.dtype in [torch.qint8, torch.quint8]:
        tensor = qtensor.int_repr()
    else:
        tensor = qtensor

    # Weights need dequantization with scaling and zero_point adjustment, but
    # bias does not need that.
    if dequant:
        qscheme = qtensor.qscheme()
        if qscheme == torch.per_channel_affine:
            scale, zero_point, axis = (
                qtensor.q_per_channel_scales(),
                qtensor.q_per_channel_zero_points(),
                qtensor.q_per_channel_axis(),
            )
        else:
            # Per-tensor quantization: single scale/zero-point, no axis.
            scale, zero_point, axis = (
                qtensor.q_scale(),  # type: ignore[assignment]
                qtensor.q_zero_point(),  # type: ignore[assignment]
                None,
            )
        dtype = tensor.dtype
        qmin, qmax = get_qmin_qmax(dtype)
        return get_dequantized(
            tensor, scale, zero_point, qmin, qmax, dtype, axis, qscheme
        )
    return tensor
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def insert_fused_activation_node(
    gm: torch.fx.GraphModule, opname: str, fx_node: torch.fx.Node
) -> torch.fx.Node:
    """Append the ReLU implied by fused quantized ops (``*_relu``).

    For plain ops the input node is returned unchanged; for fused-ReLU ops a
    new ``aten.relu`` node consuming ``fx_node`` is appended and returned.
    """
    fused_relu_ops = {
        "conv1d_relu",
        "conv2d_relu",
        "linear_relu",
        "add_relu",
        "mul_relu",
    }
    if opname in fused_relu_ops:
        fx_node = gm.graph.call_function(torch.ops.aten.relu, (fx_node,))
    return fx_node
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _conv1d_op_with_squeeze(
|
| 269 |
+
inp: torch.Tensor,
|
| 270 |
+
weight: torch.Tensor,
|
| 271 |
+
bias: Optional[torch.Tensor],
|
| 272 |
+
stride: List[int],
|
| 273 |
+
padding: List[int],
|
| 274 |
+
dilation: List[int],
|
| 275 |
+
groups: int,
|
| 276 |
+
) -> torch.Tensor:
|
| 277 |
+
# In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze
|
| 278 |
+
# operations before and after the conv2d operation to match the dimension of weights.
|
| 279 |
+
# Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/quantized/cpu/qconv.cpp#L1827 # noqa: B950
|
| 280 |
+
s_inp = torch.ops.aten.unsqueeze(inp, 2)
|
| 281 |
+
conv1d_res = torch.ops.aten.conv2d(
|
| 282 |
+
s_inp,
|
| 283 |
+
weight,
|
| 284 |
+
bias,
|
| 285 |
+
stride,
|
| 286 |
+
padding,
|
| 287 |
+
dilation,
|
| 288 |
+
groups,
|
| 289 |
+
)
|
| 290 |
+
uns_conv1d_res = torch.ops.aten.squeeze(conv1d_res, 2)
|
| 291 |
+
return uns_conv1d_res
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """Conv specific transformation function."""
    assert isinstance(node.target, torch._ops.OpOverload)
    opname = node.target._opname
    # Output qparams of the quantized conv, returned so the caller can
    # re-quantize the standard op's result.
    scale_node, zero_point_node = node.args[2], node.args[3]

    # conv1d is emulated via conv2d plus squeeze/unsqueeze (see
    # _conv1d_op_with_squeeze); conv2d variants map directly.
    op_f = (
        torch.ops.aten.conv2d
        if opname in ["conv2d", "conv2d_relu"]
        else _conv1d_op_with_squeeze
    )

    inp_node, param_node = node.args[0], node.args[1]
    assert isinstance(inp_node, torch.fx.Node)
    assert isinstance(param_node, torch.fx.Node)

    if param_node.op == "call_function":
        # Using Conv2dPrepackParam from conv_prepack.
        # We directly skip the packing call and inline weights and bias.
        w_node, b_node = param_node.args[0], param_node.args[1]
        assert isinstance(w_node, torch.fx.Node)
        assert b_node is None or isinstance(b_node, torch.fx.Node)
        (
            param_0,
            param_1,
        ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
            gm, w_node, b_node
        )
        # The remaining prepack args (stride/padding/...) transfer verbatim.
        op_res_node = gm.graph.call_function(
            op_f, (inp_node, param_0, param_1, *param_node.args[2:])
        )
    else:
        # Using ConvPrepackedParam.
        param = get_script_object(gm, param_node)
        (
            param_0,
            param_1,
        ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
            gm, param_node
        )  # type: ignore[assignment]
        # Hyper-parameters live on the packed ScriptObject and are read out
        # eagerly at transform time.
        op_res_node = gm.graph.call_function(
            op_f,
            (
                inp_node,
                param_0,
                param_1,
                param.stride(),  # type: ignore[attr-defined]
                param.padding(),  # type: ignore[attr-defined]
                param.dilation(),  # type: ignore[attr-defined]
                param.groups(),  # type: ignore[attr-defined]
            ),
        )
    return op_res_node, scale_node, zero_point_node
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def _transform_linear_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """Linear specific transformation function."""
    # Output qparams of the quantized linear, returned so the caller can
    # re-quantize the standard op's result.
    scale_node, zero_point_node = node.args[2], node.args[3]

    inp_node, param_node = node.args[0], node.args[1]
    assert isinstance(inp_node, torch.fx.Node)
    assert isinstance(param_node, torch.fx.Node)

    if param_node.op == "call_function":
        # Using LinearPrepackParam from linear_prepack.
        # We directly skip the packing call and inline weights and bias.
        w_node, b_node = param_node.args[0], param_node.args[1]
        assert isinstance(w_node, torch.fx.Node)
        assert b_node is None or isinstance(b_node, torch.fx.Node)
        (
            param_0,
            param_1,
        ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
            gm, w_node, b_node
        )
        op_res_node = gm.graph.call_function(
            torch.ops.aten.linear, (inp_node, param_0, param_1, *param_node.args[2:])
        )
    else:
        # Using LinearPackedParams.
        (
            param_0,
            param_1,
        ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
            gm, param_node
        )  # type: ignore[assignment]
        op_res_node = gm.graph.call_function(
            torch.ops.aten.linear, (inp_node, param_0, param_1)
        )
    return op_res_node, scale_node, zero_point_node
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _transform_op_where_last_two_arguments_are_scale_and_zero_point(
|
| 387 |
+
gm: torch.fx.GraphModule, node: torch.fx.Node
|
| 388 |
+
):
|
| 389 |
+
"""
|
| 390 |
+
This transformation function can be used for function where the last two
|
| 391 |
+
parameters are scale and zero point. Additionally, the function's parameters
|
| 392 |
+
do not need any unpacking.
|
| 393 |
+
"""
|
| 394 |
+
to_standard_op = {
|
| 395 |
+
"mul": torch.ops.aten.mul,
|
| 396 |
+
"mul_relu": torch.ops.aten.mul,
|
| 397 |
+
"add": torch.ops.aten.add,
|
| 398 |
+
"add_relu": torch.ops.aten.add,
|
| 399 |
+
"softmax": torch.ops.aten.softmax,
|
| 400 |
+
"cat": torch.ops.aten.cat,
|
| 401 |
+
"hardswish": torch.ops.aten.hardswish,
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
assert isinstance(node.target, torch._ops.OpOverload)
|
| 405 |
+
opname, args = node.target._opname, node.args
|
| 406 |
+
scale_node, zero_point_node = args[-2], args[-1]
|
| 407 |
+
op_res_node = gm.graph.call_function(to_standard_op[opname], tuple(args[:-2]))
|
| 408 |
+
return op_res_node, scale_node, zero_point_node
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def _transform_scalar_arithmetic(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """Transform scalar overload for basic arithmetic."""
    scalar_op_for = {
        "mul": torch.ops.aten.mul.Scalar,
        "add": torch.ops.aten.add.Scalar,
    }
    assert isinstance(node.target, torch._ops.OpOverload)
    replacement = scalar_op_for[node.target._opname]
    op_res_node = gm.graph.call_function(replacement, node.args)
    # Scalar overloads carry no qparams of their own; reuse the module-level
    # scale/zero-point recorded from earlier quantized ops.
    return op_res_node, _SCALE, _ZERO_POINT
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def _transform_prepacked_op(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """
    Transformation for functions under prepacked namespace, where they share
    the same handling logic that [...]OpContext contains all parameters.
    """
    assert isinstance(node.target, torch._ops.OpOverload)
    opname, args = node.target._opname, node.args
    op_f = None
    if opname == "conv2d_clamp_run":
        op_f = torch.ops.aten.conv2d
    elif opname == "linear_clamp_run":
        op_f = torch.ops.aten.linear
    else:
        raise RuntimeError(f"Invalid operator {opname}")

    # args[1] is the packed OpContext ScriptObject holding weight/bias (and,
    # for conv, the remaining hyper-parameters).
    assert isinstance(args[1], torch.fx.Node)
    so = get_script_object(gm, args[1])

    func_args = []
    func_args += [args[0]]
    func_args += so.unpack()[:2]  # type: ignore[attr-defined]
    if opname == "conv2d_clamp_run":
        # Recover stride/padding/dilation/groups from the packed context.
        func_args += torch.ops.prepacked.unpack_prepacked_sizes_conv2d(so)[2:]

    op_res_node = gm.graph.call_function(op_f, tuple(func_args))
    return op_res_node
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def _transform_batch_norm(gm: torch.fx.GraphModule, node: torch.fx.Node):
|
| 452 |
+
args = node.args
|
| 453 |
+
scale_node, zero_point_node = args[-2], args[-1]
|
| 454 |
+
op_res_node = gm.graph.call_function(
|
| 455 |
+
torch.ops.aten.native_batch_norm, (*args[:-3], False, 0.1, args[-3])
|
| 456 |
+
)
|
| 457 |
+
op_res_node = gm.graph.call_function(operator.getitem, (op_res_node, 0))
|
| 458 |
+
return op_res_node, scale_node, zero_point_node
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def fx_transform_quantized_op_to_standard_op(
    gm: torch.fx.GraphModule, node: torch.fx.Node
) -> torch.fx.Node:
    """Replace one ``quantized::*`` call with standard aten op(s) plus an
    explicit quantize/dequantize pair.

    Dispatches on ``opname.overload`` to a per-family transform that returns
    the standard-op node together with the output scale/zero_point nodes,
    appends the fused activation (e.g. relu for ``*_relu`` variants), then
    inserts q/dq nodes so downstream consumers see dequantized values.
    Also updates the module-level ``_SCALE``/``_ZERO_POINT`` so ops that lack
    explicit parameters (e.g. ``mul.Scalar``) can reuse the latest ones.

    Returns:
        The dequantize node that replaces *node*'s result.

    Raises:
        RuntimeError: if the quantized op has no registered transform.
    """
    global _SCALE, _ZERO_POINT, _INPUT_Q_DTYPE

    assert isinstance(node.target, torch._ops.OpOverload)
    opname, overload = node.target._opname, node.target._overloadname

    key = f"{opname}.{overload}"
    opname_to_transform_f = {
        "conv1d.new": _transform_conv_with_packedparam,
        "conv1d_relu.new": _transform_conv_with_packedparam,
        "conv1d.default": _transform_conv_with_packedparam,
        "conv1d_relu.default": _transform_conv_with_packedparam,
        "conv2d.new": _transform_conv_with_packedparam,
        "conv2d_relu.new": _transform_conv_with_packedparam,
        "conv2d.default": _transform_conv_with_packedparam,
        "conv2d_relu.default": _transform_conv_with_packedparam,
        "linear.default": _transform_linear_with_packedparam,
        "linear_relu.default": _transform_linear_with_packedparam,
        "add.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "add_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "mul.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "mul_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "softmax.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "cat.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "hardswish.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "batch_norm2d.default": _transform_batch_norm,
        "mul.Scalar": _transform_scalar_arithmetic,
        "add.Scalar": _transform_scalar_arithmetic,
    }

    # NOTE: `key` is already a str; the original wrapped it in a redundant
    # f-string (`f"{key}"`) at every use site.
    if key not in opname_to_transform_f:
        raise RuntimeError(f"Unsupported quantized op during transformation: {key}")

    op_res_node, scale_node, zero_point_node = opname_to_transform_f[key](gm, node)

    # Add fused activation layer.
    op_res_node = insert_fused_activation_node(gm, opname, op_res_node)
    # Remember the output quantization params for ops that omit their own.
    _SCALE, _ZERO_POINT = scale_node, zero_point_node

    # The program-wide quantization dtype must have been captured from the
    # initial aten.quantize_per_tensor call before any quantized op appears.
    assert _INPUT_Q_DTYPE is not None
    qmin_node, qmax_node = insert_qmin_qmax_node(gm, _INPUT_Q_DTYPE)
    q_fx_node = insert_quantized_node(
        gm,
        op_res_node,
        scale_node,
        zero_point_node,
        qmin_node,
        qmax_node,
        _INPUT_Q_DTYPE,
        torch.per_tensor_affine,
    )
    dq_fx_node = insert_dequantized_node(
        gm,
        q_fx_node,
        scale_node,
        zero_point_node,
        qmin_node,
        qmax_node,
        _INPUT_Q_DTYPE,
        None,
        torch.per_tensor_affine,
    )
    return dq_fx_node
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
    """
    Replace legacy quantized ops (aten.quantize_per_tensor, quantized.conv) with
    PT2 ops (quantize_decomposed.quantize_per_tensor, aten.conv).

    Before: x || -> aten.q || -> quantized.conv2d || -> quantized.linear || -> aten.dq || -> y

    After: x || -> qd.q -> qd.dq || -> aten.conv2d -> qd.q -> qd.dq || aten.linear -> qd.q -> qd.dq || -> y

    (qd == quantized_decomposed library, q = quantize, dq = dequantize)
                                ^
                                |
                getattr(w), getattr(b) from Conv2dParamPrepack

    During each iteration, the transformation spits out the transformed operator, its quantized output,
    and its dequantized value together. We did this because dequantization needs to use the
    scale and zero point parameters from the quantization to recover the approximate original value. After each
    iteration, the new dequantization node will be used as the input to the next node (e.g., dq2 -> linear).

    For operators like conv2d and linear, their weights and bias are packed in a quantized format in the ScriptObject.
    During the transformation, we unpack those objects, get their dequantized tensor, populate those
    as attributes to the module, and use getattr to access them.

    One exception in the transformation is conv_prepack and linear_prepack. Those calls pack
    weight and bias constant tensors into ScriptObject, which are then used by subsequent conv2d or linear calls.
    During transformation, we directly skip transforming conv_prepack or linear_prepack. We check whether ScriptObject to the
    quantized::conv2d or linear is from conv_prepack or linear_prepack. If it is, we then inline those parameters
    to the operator by converting them to a getattr fx.node.

    For prepacked::conv2d_clamp_run and prepacked::linear_clamp_run, we directly convert them to aten.conv2d and aten.linear
    without the need of doing de/quantization.

    Three global variables defined are _INPUT_Q_DTYPE, _SCALE, _ZERO_POINT. _INPUT_Q_DTYPE determines the de/quantization
    data type, which is the same across the entire program, but it only shows up in the very first quantization
    call. _SCALE and _ZERO_POINT are used only when operators do not have those specified. E.g., mul.Scalar.
    """

    global _INPUT_Q_DTYPE

    # Tracks whether any quantized/prepacked op was rewritten, so cleanup
    # (dead-code elimination + attribute scrubbing) only runs when needed.
    quantized = False

    # The most recent node producing a dequantized value; aten.dequantize
    # calls are redirected to it instead of being transformed themselves.
    last_quantized_node = None
    for node in gm.graph.nodes:
        if isinstance(node.target, OpOverload):
            # All replacement nodes are inserted just before the node they
            # replace, preserving topological order.
            with gm.graph.inserting_before(node):
                namespace, opname = node.target.namespace, node.target._opname
                if namespace == "quantized" and opname not in [
                    "conv_prepack",
                    "linear_prepack",
                ]:
                    # prepack calls are skipped here and inlined later by the
                    # per-op transforms (see docstring).
                    quantized = True
                    fx_node = fx_transform_quantized_op_to_standard_op(gm, node)
                    node.replace_all_uses_with(fx_node)
                    last_quantized_node = fx_node
                elif namespace == "prepacked":
                    quantized = True
                    fx_node = _transform_prepacked_op(gm, node)
                    node.replace_all_uses_with(fx_node)
                    last_quantized_node = fx_node
                elif namespace == "aten" and opname == "quantize_per_tensor":
                    inp_node, scale_node, zero_point_node, dtype_node = node.args
                    dtype_node = fx_enum_to_dtype(gm, dtype_node)
                    # First (and only) place the program-wide q-dtype appears.
                    _INPUT_Q_DTYPE = dtype_node
                    qmin_node, qmax_node = insert_qmin_qmax_node(gm, dtype_node)
                    q_fx_node = insert_quantized_node(
                        gm,
                        inp_node,
                        scale_node,
                        zero_point_node,
                        qmin_node,
                        qmax_node,
                        dtype_node,
                        torch.per_tensor_affine,
                    )
                    # Immediately dequantize so downstream standard ops
                    # consume float values.
                    dq_fx_node = insert_dequantized_node(
                        gm,
                        q_fx_node,
                        scale_node,
                        zero_point_node,
                        qmin_node,
                        qmax_node,
                        dtype_node,
                        None,
                        torch.per_tensor_affine,
                    )
                    node.replace_all_uses_with(dq_fx_node)
                    last_quantized_node = dq_fx_node
                elif namespace == "aten" and opname == "dequantize":
                    # The dq has already been emitted alongside the producing
                    # op; just redirect consumers to it.
                    assert last_quantized_node is not None
                    node.replace_all_uses_with(last_quantized_node)
                else:
                    last_quantized_node = node

    # Post-processing again to remove legacy ScriptObjects and quantized tensors
    # stored as attributes or in the buffer. This is used to clean up the GraphModule
    # to not trigger tracing errors like missing __obj_flatten__ functions.
    def _clean_attr(mod: torch.nn.Module):
        # Drop ScriptObject attributes and quantized-dtype buffers on every
        # submodule; they are unused after the rewrite and break retracing.
        for submod in mod.modules():
            attr_names_to_clean = set()
            for k, v in submod.__dict__.items():
                if isinstance(v, torch.ScriptObject):
                    attr_names_to_clean.add(k)
                if k == "_buffers":
                    buffer_name_to_clean = set()
                    for b_name, b_value in v.items():
                        if isinstance(b_value, torch.Tensor) and b_value.dtype in [
                            torch.qint8,
                            torch.quint8,
                        ]:
                            buffer_name_to_clean.add(b_name)
                    for b_name in buffer_name_to_clean:
                        v.pop(b_name, None)
            for attr_name in attr_names_to_clean:
                delattr(submod, attr_name)

    if quantized:
        """
        TODO: SetAttr + quantized ops will result incorrect program. This flag is used to temporarily
        bypass test cases.

        The deadcode elimination pass is needed to remove legacy quantized ops. Otherwise, retracing
        will throw errors. However, the current way of SetAttr does inplace update to attributes, so
        this pass regard them as dead code and remove them. Below is an example of GraphModule before
        and after the dead code elimination pass.

        class GraphModule(torch.nn.Module):
            def forward(self, x_1):
                # No stacktrace found for following nodes
                data = self.data;  data = None
                data_1 = self.data
                add_tensor = torch.ops.aten.add.Tensor(data_1, x_1, alpha = 1);  data_1 = None
                data_2 = self.data
                copy_ = torch_Tensor_copy_(data_2, add_tensor);  data_2 = add_tensor = copy_ = None
                data_3 = self.data
                add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1);  x_1 = data_3 = None
                return add_tensor_1

        class GraphModule(torch.nn.Module):
            def forward(self, x_1):
                # No stacktrace found for following nodes
                data_3 = self.data
                add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1);  x_1 = data_3 = None
                return add_tensor_1
        """
        gm.graph.eliminate_dead_code()
        _clean_attr(gm)
|
.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled
|
| 5 |
+
|
| 6 |
+
from ..utils import node_inline_, nodes_filter, nodes_first, nodes_map, sequential_split
|
| 7 |
+
from .replace_with_hop_pass_util import (
|
| 8 |
+
_replace_with_hop_helper,
|
| 9 |
+
_replace_with_hop_pass_helper,
|
| 10 |
+
_sequential_split_and_maybe_inline_subgraphs_helper,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _is_set_grad_enabled_node(node: torch.fx.Node):
|
| 15 |
+
return (
|
| 16 |
+
node
|
| 17 |
+
and node.op == "call_function"
|
| 18 |
+
and node.target == torch._C._set_grad_enabled
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _is_set_grad_enabled_sub_mod(node: torch.fx.Node, omit_if_same_with_ambient=False):
    """Check whether the submodule invoked by *node* opens with a
    ``torch._C._set_grad_enabled`` call.

    With ``omit_if_same_with_ambient`` set, a submodule whose grad-mode flag
    equals the current ``torch.is_grad_enabled()`` reports False, so the
    caller can inline it rather than wrap it in a HOP.
    """
    if node.op != "call_module":
        return False
    assert isinstance(node.target, str)
    subgm = getattr(node.graph.owning_module, node.target)
    # First node of the subgraph that is not an input placeholder.
    first_real = nodes_first(subgm.graph.nodes, lambda n: n.op != "placeholder")
    starts_with_toggle = (
        first_real is not None
        and first_real.op == "call_function"
        and first_real.target == torch._C._set_grad_enabled
    )
    if not starts_with_toggle:
        return False
    if omit_if_same_with_ambient:
        return first_real.args[0] != torch.is_grad_enabled()
    return True
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _replace_with_hop(node: torch.fx.Node):
    """Wrap the submodule called by *node* in ``wrap_with_set_grad_enabled``.

    The submodule is expected to contain exactly one ``_set_grad_enabled``
    call, which is erased once the HOP carries the grad-mode flag instead.
    """
    assert node.op == "call_module"
    owning_graph: torch.fx.Graph = node.graph
    owner: torch.fx.GraphModule = owning_graph.owning_module
    assert isinstance(node.target, str)
    inner_graph = getattr(owner, node.target).graph
    grad_toggles = nodes_filter(inner_graph.nodes, _is_set_grad_enabled_node)
    if not grad_toggles:
        return
    # Each split submodule holds at most one grad-mode toggle.
    assert len(grad_toggles) == 1
    toggle = grad_toggles[0]
    _replace_with_hop_helper(
        node, toggle, _is_set_grad_enabled_node, wrap_with_set_grad_enabled
    )
    inner_graph.erase_node(toggle)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _remove_set_grad_and_inline(node: torch.fx.Node):
    """Drop the submodule's ``_set_grad_enabled`` call(s) and inline its body.

    Used for split submodules whose grad mode matches the ambient one, so no
    HOP wrapper is needed.  *node* must be the ``call_module`` node that
    references the split submodule.
    """
    assert node.op == "call_module"
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)
    sub_graph = sub_gm.graph
    # Erase grad-mode toggles in place; all other nodes pass through unchanged.
    # NOTE(review): this mutates sub_graph while nodes_map traverses its node
    # list — presumably nodes_map tolerates in-flight erasure; confirm in ..utils.
    nodes_map(
        sub_graph.nodes,
        lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n,
    )
    # Splice the (now toggle-free) subgraph back into the parent graph.
    node_inline_(node)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _sequential_split_and_maybe_inline_subgraphs(
    gm: torch.fx.GraphModule, graph_signature
):
    """
    Helper function for replace_set_grad_with_hop_pass().
    Splits *gm* at every set_grad_enabled node; each resulting submodule is
    then either wrapped in a HOO subgraph or inlined back into the parent
    graph module.
    """
    if not any(_is_set_grad_enabled_node(n) for n in gm.graph.nodes):
        # No grad-mode toggles anywhere — nothing to rewrite.
        return gm, graph_signature

    # sequential_split returns a new graph module that could have different
    # output arg names; the shared helper fixes the graph signature.
    split_gm = sequential_split(gm, _is_set_grad_enabled_node)

    def _dispatch(node: torch.fx.Node):
        # Wrap only when the submodule's grad mode differs from the ambient one.
        if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
            _replace_with_hop(node)
        else:
            _remove_set_grad_and_inline(node)

    return _sequential_split_and_maybe_inline_subgraphs_helper(
        split_gm, graph_signature, _dispatch
    )
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
    """
    Entry point: split *gm* into sub-graph-modules around
    ``_set_grad_enabled`` calls and rewrite them as
    ``wrap_with_set_grad_enabled`` HOPs, recursing into every submodule.
    """
    return _replace_with_hop_pass_helper(
        gm, graph_signature, _sequential_split_and_maybe_inline_subgraphs
    )
|
.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from typing import Dict, Optional
|
| 3 |
+
import torch
|
| 4 |
+
from torch._ops import OpOverload, HigherOrderOperator
|
| 5 |
+
from torch._export.error import InternalError
|
| 6 |
+
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
|
| 13 |
+
torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def is_view_op(schema: torch._C.FunctionSchema) -> bool:
    """Return True when *schema*'s first argument aliases the output without
    being written to — i.e. the op is a non-mutating view."""
    if not schema.arguments:
        return False
    alias = schema.arguments[0].alias_info
    return alias is not None and not alias.is_write
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]:
    """Map an aten view op schema to its functional ``*_copy`` variant.

    Returns None for non-view or non-aten ops.  Raises InternalError when a
    view op lacks a registered ``_copy`` counterpart (packet or overload).
    """
    if not (is_view_op(schema) and schema.name.startswith("aten::")):
        return None
    base_name = schema.name.split("::")[1]
    # Empty overload name means the default overload.
    overload = schema.overload_name or "default"
    copy_name = base_name + "_copy"
    if not hasattr(torch.ops.aten, copy_name):
        raise InternalError(f"{schema.name} is missing a view_copy variant")

    copy_packet = getattr(torch.ops.aten, copy_name)
    if not hasattr(copy_packet, overload):
        raise InternalError(f"{schema.name} is missing a view_copy variant")

    return getattr(copy_packet, overload)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse):
    """
    Our backend expects pure functional operators.  For efficiency we keep
    view ops around while functionalizing the exported program; this pass
    swaps them for their view_copy variants for backends that need AOT
    memory planning.
    """

    def call_operator(self, op, args, kwargs, meta):
        # Known non-functional ops have a direct functional substitution.
        replacement = _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS.get(op)
        if replacement is not None:
            return super().call_operator(replacement, args, kwargs, meta)

        # Higher-order operators have no schema to inspect; pass through.
        if isinstance(op, HigherOrderOperator):
            return super().call_operator(op, args, kwargs, meta)

        view_copy_op = get_view_copy_of_view_op(op._schema)
        if view_copy_op is not None:
            return super().call_operator(view_copy_op, args, kwargs, meta)

        return super().call_operator(op, args, kwargs, meta)
|
.venv/lib/python3.11/site-packages/torch/_export/passes/replace_with_hop_pass_util.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
|
| 3 |
+
import contextlib
|
| 4 |
+
import copy
|
| 5 |
+
import operator
|
| 6 |
+
from typing import Callable
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch._ops import HigherOrderOperator
|
| 10 |
+
|
| 11 |
+
from ..utils import node_replace_, nodes_map
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _replace_with_hop_helper(
    node: torch.fx.Node,
    enter_block_node: torch.fx.Node,
    node_filter: Callable,
    wrap_hoo: HigherOrderOperator,
):
    """Replace the ``call_module`` *node* with a call to the HOP *wrap_hoo*.

    The submodule becomes a get_attr operand of the HOP call, whose leading
    args are taken from *enter_block_node* (e.g. the grad-mode flag).  Handles
    three output shapes of the submodule: tuple/list outputs, a single-Node
    output (rewritten to a singleton list plus a getitem), and no output node
    at all (the call is simply erased).
    """
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)

    def set_hoo_node_meta(call_func_node):
        # NOTE: closes over `output_args`, which is bound later (before any
        # call of this function) in the enclosing scope.
        call_func_node.meta["nn_module_stack"] = copy.copy(
            enter_block_node.meta.get("nn_module_stack", {})
        )
        call_func_node.meta["torch_fn"] = (
            f"{wrap_hoo.__name__}",
            f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}",
        )
        if isinstance(output_args, (tuple, list)):
            call_func_node.meta["val"] = tuple(arg.meta["val"] for arg in output_args)
        elif isinstance(output_args, torch.fx.Node):
            call_func_node.meta["val"] = (output_args.meta["val"],)

    with graph.inserting_before(node):
        get_attr_node = graph.get_attr(node.target)
        get_attr_node.meta["nn_module_stack"] = copy.copy(
            enter_block_node.meta.get("nn_module_stack", {})
        )
        output_node = next(iter(reversed(sub_gm.graph.nodes)), None)
        # Split_module pass intentially doesn't add output node
        # if the graph doesn't return anything.
        # TODO (tmanlaibaatar) Figure out if this is right behaviour
        # for split_module
        if isinstance(output_node, torch.fx.Node) and output_node.op != "output":
            output_node = None
        if output_node is not None:
            assert len(output_node.args) == 1
            output_args = output_node.args[0]
            enter_block_node_args = enter_block_node.args
            if isinstance(output_args, (tuple, list)):
                call_func_node = graph.call_function(
                    wrap_hoo,
                    (*enter_block_node_args, get_attr_node, *node.args),
                    {},
                )
                # Create the metadata
                set_hoo_node_meta(call_func_node)
                node_replace_(node, call_func_node)

                # Rename the name of getitem nodes to the actual name of its contents
                # for passing verifier and better readability, also propagate metadata
                for get_item_node in call_func_node.users.keys():
                    idx: int = get_item_node.args[1]  # type: ignore[assignment]
                    output_node = output_args[idx]
                    get_item_node._rename(output_node.name)
                    get_item_node.meta = output_node.meta

            elif isinstance(output_args, torch.fx.Node):
                call_func_node = graph.create_node(
                    "call_function",
                    wrap_hoo,
                    (*enter_block_node_args, get_attr_node, *node.args),
                    {},
                    output_args.name,
                )
                # Modify the subgraph to output a singleton list.
                output_node.args = ((output_args,),)
                # Add in an extra `getitem(wrap_hoo, 0)` node to the toplevel graph.
                get_item_node = graph.create_node(
                    "call_function",
                    operator.getitem,
                    (call_func_node, 0),
                    {},
                )
                # Create the metadata
                get_item_node.meta = output_args.meta
                set_hoo_node_meta(call_func_node)
                node_replace_(node, get_item_node)
            else:
                # BUGFIX: error message previously read
                # "repalce_with_hop_pass doesnt' support ..." (typos).
                raise NotImplementedError(
                    f"replace_with_hop_pass doesn't support output type {type(output_args)}"
                )
        else:
            # TODO (shangdiy): remove this line, since the export graph can be non-functional
            node.graph.erase_node(node)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _sequential_split_and_maybe_inline_subgraphs_helper(
    new_gm: torch.fx.GraphModule,
    graph_signature,
    maybe_inline_or_replace_with_hop: Callable[[torch.fx.Node], None],
):
    """
    Helper function for replacing graph nodes with higher order nodes.
    For each subgraph in `new_gm`, decides whether to construct a HOO subgraph, or inline the calls
    back into the parent graph module, depending on `maybe_inline_or_replace_with_hop`.
    """
    # new_gm is a new graph module that could have different output args names.
    # We need to fix the graph signature.
    replace_ctx = contextlib.nullcontext()
    new_signature = None
    if graph_signature is not None:
        # Cannot deep copy a real ScriptObject, which is referenced
        # in the FakeScriptObject. Copy should be good enough to guard
        # against accidental mutation to original graph_signature.
        new_signature = copy.copy(graph_signature)
        new_gm_out_node = next(reversed(new_gm.graph.find_nodes(op="output")))
        assert new_gm_out_node.op == "output" and len(new_gm_out_node.args[0]) == len(
            new_signature.output_specs
        )
        # Re-point each output spec at its (possibly renamed) output node.
        for arg_node, out_spec in zip(
            new_gm_out_node.args[0], new_signature.output_specs
        ):
            if arg_node is None:
                # A None graph output must correspond to a constant-None spec.
                assert out_spec.arg.value is None
            elif (
                isinstance(arg_node, torch.fx.Node)
                and out_spec.arg.name != arg_node.name
            ):
                out_spec.arg.name = arg_node.name

        # Keep the signature in sync with node replacements performed below.
        replace_ctx = new_gm._set_replace_hook(new_signature.get_replace_hook())  # type: ignore[assignment]

    with replace_ctx:
        # Snapshot the node list first: the callback may mutate the graph
        # (inlining or replacing call_module nodes) during traversal.
        nodes_map(
            list(new_gm.graph.nodes),
            lambda node: (
                maybe_inline_or_replace_with_hop(node)
                if node.op == "call_module"
                else node
            ),
        )
    new_gm.recompile()
    return new_gm, new_signature
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _replace_with_hop_pass_helper(
|
| 152 |
+
gm: torch.fx.GraphModule,
|
| 153 |
+
graph_signature,
|
| 154 |
+
sequential_split_and_maybe_inline_subgraphs: Callable,
|
| 155 |
+
):
|
| 156 |
+
"""
|
| 157 |
+
Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
|
| 158 |
+
then recursively call itself on each of the submodules.
|
| 159 |
+
"""
|
| 160 |
+
new_gm, new_signature = sequential_split_and_maybe_inline_subgraphs(
|
| 161 |
+
gm, graph_signature
|
| 162 |
+
)
|
| 163 |
+
# recursively call
|
| 164 |
+
for node in new_gm.graph.nodes:
|
| 165 |
+
if node.op == "get_attr":
|
| 166 |
+
subgm = getattr(new_gm, node.target)
|
| 167 |
+
if not isinstance(subgm, torch.fx.GraphModule):
|
| 168 |
+
continue
|
| 169 |
+
new_subgm, _ = _replace_with_hop_pass_helper(
|
| 170 |
+
subgm,
|
| 171 |
+
None,
|
| 172 |
+
sequential_split_and_maybe_inline_subgraphs,
|
| 173 |
+
)
|
| 174 |
+
setattr(new_gm, node.target, new_subgm)
|
| 175 |
+
|
| 176 |
+
new_gm.recompile()
|
| 177 |
+
new_gm.graph.lint()
|
| 178 |
+
return new_gm, new_signature
|