koichi12 commited on
Commit
7469295
·
verified ·
1 Parent(s): f681997

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/torch/_export/__pycache__/converter.cpython-311.pyc +0 -0
  3. .venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/torch/_export/__pycache__/tools.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py +1 -0
  26. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/torch/_export/passes/_node_metadata_hook.py +80 -0
  40. .venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +227 -0
  41. .venv/lib/python3.11/site-packages/torch/_export/passes/collect_tracepoints_pass.py +102 -0
  42. .venv/lib/python3.11/site-packages/torch/_export/passes/constant_folding.py +299 -0
  43. .venv/lib/python3.11/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py +94 -0
  44. .venv/lib/python3.11/site-packages/torch/_export/passes/lift_constants_pass.py +318 -0
  45. .venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py +27 -0
  46. .venv/lib/python3.11/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py +179 -0
  47. .venv/lib/python3.11/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py +673 -0
  48. .venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py +110 -0
  49. .venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py +65 -0
  50. .venv/lib/python3.11/site-packages/torch/_export/passes/replace_with_hop_pass_util.py +178 -0
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/converter.cpython-311.pyc ADDED
Binary file (82.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/error.cpython-311.pyc ADDED
Binary file (2.78 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-311.pyc ADDED
Binary file (24.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/pass_base.cpython-311.pyc ADDED
Binary file (27.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/tools.cpython-311.pyc ADDED
Binary file (7.09 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/utils.cpython-311.pyc ADDED
Binary file (43.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/verifier.cpython-311.pyc ADDED
Binary file (25.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/__pycache__/wrappers.cpython-311.pyc ADDED
Binary file (7.99 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (189 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/case.cpython-311.pyc ADDED
Binary file (8.44 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/gen_example.cpython-311.pyc ADDED
Binary file (1.36 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/__pycache__/logging.cpython-311.pyc ADDED
Binary file (1.74 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.38 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-311.pyc ADDED
Binary file (1.56 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-311.pyc ADDED
Binary file (1.27 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-311.pyc ADDED
Binary file (1.34 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-311.pyc ADDED
Binary file (1.11 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-311.pyc ADDED
Binary file (2.01 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/optional_input.cpython-311.pyc ADDED
Binary file (1.13 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/scalar_output.cpython-311.pyc ADDED
Binary file (1.52 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_for_loop.cpython-311.pyc ADDED
Binary file (1.1 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/static_if.cpython-311.pyc ADDED
Binary file (1.1 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/db/examples/__pycache__/type_reflection_method.cpython-311.pyc ADDED
Binary file (1.49 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (305 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/_node_metadata_hook.cpython-311.pyc ADDED
Binary file (4.17 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-311.pyc ADDED
Binary file (5.73 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-311.pyc ADDED
Binary file (14.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-311.pyc ADDED
Binary file (5.28 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/lift_constants_pass.cpython-311.pyc ADDED
Binary file (15.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-311.pyc ADDED
Binary file (1.63 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-311.pyc ADDED
Binary file (8.71 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-311.pyc ADDED
Binary file (28.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-311.pyc ADDED
Binary file (5.98 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-311.pyc ADDED
Binary file (3.96 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-311.pyc ADDED
Binary file (8.48 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/passes/_node_metadata_hook.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+
4
+ import torch
5
+ from torch.fx.graph_module import GraphModule
6
+
7
+
8
+ _EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook"
9
+
10
+
11
+ def _node_metadata_hook(node: torch.fx.Node, stack_trace: str) -> None:
12
+ """
13
+ Hook for adding the appropriate metadata to nodes that are created during a
14
+ pass using graph.create_node. An example of how to use it:
15
+
16
+ ```
17
+ with _set_node_metadata_hook(gm,
18
+ functools.partial(_node_metadata_hook, stack_trace="file")
19
+ ):
20
+ pass(gm)
21
+ ```
22
+
23
+ This hook should not work for all generic cases -- specifically it assumes
24
+ that nodes being added are only call_function nodes, and copies over the
25
+ first argument node's nn_module_stack.
26
+ """
27
+ assert node.op == "call_function" and callable(node.target)
28
+
29
+ arg_meta = [arg.meta for arg in node.args if isinstance(arg, torch.fx.Node)]
30
+ assert len(arg_meta) >= 1
31
+ arg_meta = arg_meta[0]
32
+
33
+ if (
34
+ isinstance(node.target, torch._ops.OpOverload)
35
+ and len(node.target._schema.returns) == 0
36
+ ):
37
+ node.meta["val"] = None
38
+ else:
39
+ fake_args = [
40
+ arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
41
+ for arg in node.args
42
+ ]
43
+ fake_res = node.target(*fake_args)
44
+ node.meta["val"] = fake_res
45
+
46
+ node.meta["stack_trace"] = stack_trace
47
+ node.meta["nn_module_stack"] = arg_meta.get(
48
+ "nn_module_stack",
49
+ {
50
+ _EMPTY_NN_MODULE_STACK_KEY: (
51
+ _EMPTY_NN_MODULE_STACK_KEY,
52
+ _EMPTY_NN_MODULE_STACK_KEY,
53
+ )
54
+ },
55
+ )
56
+ node.meta["torch_fn"] = (
57
+ f"{node.target.__name__}_0",
58
+ f"{node.target.__class__.__name__}.{node.target.__name__}",
59
+ )
60
+
61
+
62
+ @contextlib.contextmanager
63
+ def _set_node_metadata_hook(gm: torch.fx.GraphModule, f):
64
+ """
65
+ Takes a callable which will be called after we create a new node. The
66
+ callable takes the newly created node as input and returns None.
67
+ """
68
+ assert callable(f), "node_metadata_hook must be a callable."
69
+
70
+ # Add the hook to all submodules
71
+ for m in gm.modules():
72
+ if isinstance(m, GraphModule):
73
+ m._register_create_node_hook(f)
74
+ try:
75
+ yield
76
+ finally:
77
+ # Restore hook for all submodules
78
+ for m in gm.modules():
79
+ if isinstance(m, GraphModule):
80
+ m._unregister_create_node_hook(f)
.venv/lib/python3.11/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import math
3
+ import operator
4
+ import traceback
5
+ from functools import partial
6
+ from typing import Callable, Dict, List, NamedTuple, Set
7
+
8
+ import sympy
9
+
10
+ import torch
11
+ import torch.fx
12
+ from torch.utils._sympy.value_ranges import ValueRanges
13
+ from torch.utils._sympy.numbers import int_oo
14
+ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
15
+ from torch.fx.passes.infra.pass_base import PassBase, PassResult
16
+
17
+ __all__ = ["InputDim"]
18
+
19
+
20
+ class InputDim(NamedTuple):
21
+ input_name: str
22
+ dim: int
23
+
24
+
25
+ def _convert_to_int(val):
26
+ # Convert simple sympy Integers into concrete int
27
+ if val in (sympy.oo, int_oo):
28
+ return math.inf
29
+ if val in (-sympy.oo, -int_oo):
30
+ return -math.inf
31
+ if isinstance(val, sympy.Integer):
32
+ return int(val)
33
+ raise RuntimeError(
34
+ "Export constraints cannot be non-integer expressions"
35
+ )
36
+
37
+
38
+ def _convert_range_to_int(range: ValueRanges):
39
+ assert isinstance(range, ValueRanges)
40
+ min_val = _convert_to_int(range.lower)
41
+ max_val = _convert_to_int(range.upper)
42
+ return min_val, max_val
43
+
44
+
45
+ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
46
+ def __init__(
47
+ self,
48
+ range_constraints: Dict[sympy.Symbol, ValueRanges],
49
+ ):
50
+ super().__init__()
51
+ self.range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints
52
+ self._asserts_generated_unbacked_symbols: Set[sympy.Symbol] = set()
53
+ self.counter = 0
54
+
55
+ def _assert_range_constraint(self, node, lower, upper, assert_msg):
56
+ last_node = node
57
+ if lower > -math.inf:
58
+ last_node = self._insert_assert_async(last_node, operator.ge, node, lower, assert_msg)
59
+
60
+ if upper < math.inf:
61
+ last_node = self._insert_assert_async(last_node, operator.le, node, upper, assert_msg)
62
+
63
+ def _insert_assert_async(self, last_node, op, lower, upper, assert_msg):
64
+ """
65
+ Inserts assert_async call_function nodes in the graph. This function is
66
+ called **during** the interpreter-based pass.
67
+ """
68
+ self.counter += 1
69
+ graph = last_node.graph
70
+ with graph.inserting_after(last_node):
71
+ cmp = graph.call_function(op, (lower, upper), {})
72
+ with graph.inserting_after(cmp):
73
+ cmp_tensor = graph.call_function(torch.ops.aten.scalar_tensor.default, (cmp,), {})
74
+ with graph.inserting_after(cmp_tensor):
75
+ assert_async = graph.call_function(
76
+ torch.ops.aten._assert_async.msg,
77
+ (cmp_tensor, assert_msg),
78
+ {},
79
+ )
80
+ return assert_async
81
+
82
+ def call(self, graph_module) -> PassResult:
83
+ self.existing_inline_assertions = _get_existing_inline_assertions(
84
+ graph_module, self.range_constraints
85
+ )
86
+
87
+ for module in graph_module.modules():
88
+ if not isinstance(module, torch.fx.GraphModule):
89
+ continue
90
+ for node in module.graph.nodes:
91
+ if node.op != "call_function":
92
+ continue
93
+ if "val" not in node.meta:
94
+ continue
95
+
96
+ val = node.meta["val"]
97
+ # In general, we may have to deal the case such as: ret[1].shape[0].
98
+ # We need first find out what symbols require assertion, then we need to follow the path
99
+ # from ret to the symbol, construct the proxies along the way and construct the messages
100
+ # piece-wise at the same time.
101
+ #
102
+ # We use post-order traversal to collect all the proxies callbacks needed, construct
103
+ # the error message callbacks, and at the top-level traversal tree we execute all the callbacks.
104
+ # We need the callbacks because, in order to call the function to create a proxy for shape[0], we
105
+ # need the proxy for shape, which further requires the proxy for ret[1], etc.
106
+
107
+ def add_assertions(val):
108
+ call_backs: List[Callable] = []
109
+ messages: List[str] = []
110
+ if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
111
+ symbol = val.node.expr
112
+ if symbol in self.existing_inline_assertions:
113
+ return call_backs, messages
114
+ if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol):
115
+ if symbol in self._asserts_generated_unbacked_symbols:
116
+ return call_backs, messages
117
+ # We only care about unbacked symints for these inline
118
+ # constraints, which are prefixed with 'u'
119
+ constraint = self.range_constraints[symbol]
120
+ min_val, max_val = _convert_range_to_int(constraint)
121
+ assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
122
+ call_backs.append(
123
+ partial(self._assert_range_constraint, lower=min_val, upper=max_val)
124
+ )
125
+ messages.append(assert_msg)
126
+ self._asserts_generated_unbacked_symbols.add(symbol)
127
+
128
+ elif isinstance(val, torch.Tensor):
129
+ for i, sym in enumerate(val.shape):
130
+ cbs, msgs = add_assertions(sym)
131
+ for cb, msg in zip(cbs, msgs):
132
+ def sym_size_cb(node, assert_msg, dim):
133
+ with node.graph.inserting_after(node):
134
+ dim_node = module.graph.call_function(
135
+ torch.ops.aten.sym_size.int,
136
+ (node, dim),
137
+ {},
138
+ )
139
+ cb(node=dim_node, assert_msg=assert_msg)
140
+ call_backs.append(partial(sym_size_cb, dim=i))
141
+ messages.append(f".shape[{i}]" + msg)
142
+ return call_backs, messages
143
+
144
+ callbacks, messages = add_assertions(val)
145
+ for cb, msg in zip(callbacks, messages):
146
+ cb(node=node, assert_msg=f"{node}" + msg)
147
+
148
+ module.recompile()
149
+
150
+ # Sometimes this pass would return a wrong graph where we have mismatched
151
+ # node names in signature. Before we fix it, let's just skip it.
152
+ if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass:
153
+ return PassResult(graph_module, False)
154
+
155
+ # Populate the stack trace with dummy vals to respect IR
156
+ for node in graph_module.graph.nodes:
157
+ if not node.meta.get("stack_trace", None) and node.op not in ["placeholder", "output"]:
158
+ node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))
159
+ return PassResult(graph_module, True)
160
+
161
+
162
+ def _get_existing_inline_assertions(
163
+ graph_module: torch.fx.GraphModule,
164
+ range_constraints: Dict[sympy.Symbol, ValueRanges],
165
+ ) -> Dict[sympy.Symbol, ValueRanges]:
166
+ existing_inline_assertions: Dict[sympy.Symbol, ValueRanges] = {}
167
+
168
+ for module in graph_module.modules():
169
+ if not isinstance(module, torch.fx.GraphModule):
170
+ continue
171
+
172
+ # Find all the existing inline assertions. They will look something like:
173
+ # %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {})
174
+ # %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {})
175
+ # %_assert_scalar = call_function[target=torch.ops.aten._assert_scalar.default](args = (%scalar_tensor, "..."), kwargs = {})
176
+ for node in module.graph.nodes:
177
+ if node.target != torch.ops.aten._assert_scalar.default:
178
+ continue
179
+
180
+ compare_arg = node.args[0]
181
+ if not (
182
+ isinstance(compare_arg, torch.fx.Node) and
183
+ compare_arg.op == "call_function" and
184
+ compare_arg.target in (operator.le, operator.ge) and
185
+ len(compare_arg.args) == 2
186
+ ):
187
+ continue
188
+
189
+ compare_op = compare_arg.target
190
+ lhs, rhs = compare_arg.args
191
+
192
+ def maybe_get_symint(x):
193
+ if (
194
+ isinstance(x, torch.fx.Node) and
195
+ "val" in x.meta and
196
+ isinstance(x.meta["val"], torch.SymInt)
197
+ ):
198
+ return x.meta["val"].node.expr
199
+ return x
200
+
201
+ lhs = maybe_get_symint(lhs)
202
+ rhs = maybe_get_symint(rhs)
203
+
204
+ if compare_op == operator.ge:
205
+ lhs, rhs = rhs, lhs
206
+
207
+ if isinstance(lhs, sympy.Symbol) and isinstance(rhs, int):
208
+ symint = lhs
209
+ scalar = rhs
210
+ elif isinstance(rhs, sympy.Symbol) and isinstance(lhs, int):
211
+ symint = rhs
212
+ scalar = lhs
213
+ else:
214
+ continue
215
+
216
+ if symint not in range_constraints:
217
+ raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}")
218
+
219
+ previous_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf))
220
+
221
+ if symint is lhs:
222
+ bounds = ValueRanges(-math.inf, scalar)
223
+ else:
224
+ bounds = ValueRanges(scalar, math.inf)
225
+ existing_inline_assertions[symint] = previous_range & bounds
226
+
227
+ return existing_inline_assertions
.venv/lib/python3.11/site-packages/torch/_export/passes/collect_tracepoints_pass.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import operator
3
+
4
+ import torch
5
+ from torch.export.exported_program import ConstantArgument, TensorArgument
6
+ from torch.fx.passes.infra.pass_base import PassBase, PassResult
7
+
8
+
9
+ __all__ = ["CollectTracepointsPass"]
10
+
11
+
12
+ class CollectTracepointsPass(PassBase):
13
+ """
14
+ Performs constant folding and constant propagation.
15
+ """
16
+
17
+ def __init__(self, specs, sig) -> None:
18
+ super().__init__()
19
+ self.specs = specs
20
+ self.sig = sig
21
+
22
+ def call(self, gm):
23
+ def get_arg_spec(arg):
24
+ if isinstance(arg, torch.fx.Node):
25
+ if isinstance(arg.meta.get("val"), torch.Tensor):
26
+ return TensorArgument(name=arg.name)
27
+ else:
28
+ raise AssertionError(
29
+ "Symint input is not implemented yet for submodule call signature."
30
+ )
31
+ else:
32
+ return ConstantArgument(name="", value=arg)
33
+
34
+ for module in gm.modules():
35
+ if not isinstance(module, torch.fx.GraphModule):
36
+ continue
37
+ nn_module_stack = None
38
+ for node in module.graph.nodes:
39
+ if node.op != "call_function":
40
+ continue
41
+ if node.target == torch.ops.higher_order._export_tracepoint:
42
+ kind = node.kwargs["kind"]
43
+ if kind == "module_call_outputs":
44
+ nn_module_stack = node.meta["nn_module_stack"]
45
+ elif kind == "module_call_inputs":
46
+ nn_module_stack = None
47
+ else:
48
+ raise AssertionError(f"Unknown tracepoint kind: {kind}")
49
+ elif node.meta["nn_module_stack"] == nn_module_stack:
50
+ node.meta["nn_module_stack"].popitem()
51
+ else:
52
+ nn_module_stack = None
53
+ nn_module_stack = None
54
+ for node in reversed(module.graph.nodes):
55
+ if node.op != "call_function":
56
+ continue
57
+ if node.target == torch.ops.higher_order._export_tracepoint:
58
+ kind = node.kwargs["kind"]
59
+ if kind == "module_call_inputs":
60
+ nn_module_stack = node.meta["nn_module_stack"]
61
+ elif kind == "module_call_outputs":
62
+ nn_module_stack = None
63
+ else:
64
+ raise AssertionError(f"Unknown tracepoint kind: {kind}")
65
+ elif node.meta["nn_module_stack"] == nn_module_stack:
66
+ node.meta["nn_module_stack"].popitem()
67
+ else:
68
+ nn_module_stack = None
69
+ for module in gm.modules():
70
+ if not isinstance(module, torch.fx.GraphModule):
71
+ continue
72
+ for node in module.graph.nodes:
73
+ if node.op != "call_function":
74
+ continue
75
+ if node.target == torch.ops.higher_order._export_tracepoint:
76
+ for i, arg in enumerate(node.args):
77
+ kind = node.kwargs["kind"]
78
+ if kind == "module_call_inputs":
79
+ self.specs[node.kwargs["path"]].inputs.append(
80
+ get_arg_spec(arg)
81
+ )
82
+ elif kind == "module_call_outputs":
83
+ self.specs[node.kwargs["path"]].outputs.append(
84
+ get_arg_spec(arg)
85
+ )
86
+ else:
87
+ raise AssertionError(f"Unknown tracepoint kind: {kind}")
88
+ if isinstance(arg, torch.fx.Node):
89
+ for user in node.users:
90
+ assert user.op == "call_function"
91
+ assert user.target == operator.getitem
92
+ assert isinstance(user.args[1], int)
93
+ if user.args[1] == i:
94
+ user.replace_all_uses_with(arg)
95
+ self.sig.replace_all_uses(user.name, arg.name)
96
+ break
97
+ users = list(node.users)
98
+ for user in users:
99
+ assert len(user.users) == 0
100
+ gm.graph.erase_node(user)
101
+ gm.graph.erase_node(node)
102
+ return PassResult(gm, True)
.venv/lib/python3.11/site-packages/torch/_export/passes/constant_folding.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import collections
3
+ from collections import defaultdict
4
+ from typing import Any, Callable, Dict, Optional
5
+
6
+ import torch
7
+ import torch.utils._pytree as pytree
8
+
9
+
10
+ aten = torch.ops.aten
11
+
12
+ # We would like to split modules into two subgraphs for runtime weight updates to work correctly.
13
+ # The use case and more information could be found at:
14
+ # https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
15
+ META_TAG = "MODULE_TYPE"
16
+ MODULE_TAG = "_MAIN_MODULE"
17
+ CONST_MODULE_TAG = "_CONST_MODULE"
18
+
19
+
20
+ def replace_node_with_constant(gm, node, constant, name=None):
21
+ g = gm.graph
22
+
23
+ if name:
24
+ qualname = name
25
+ else:
26
+ if not hasattr(gm, "_frozen_param_count"):
27
+ gm._frozen_param_count = 0
28
+ i = gm._frozen_param_count
29
+
30
+ while True:
31
+ qualname = f"_frozen_param{i}"
32
+ if not hasattr(gm, qualname):
33
+ break
34
+ i += 1
35
+
36
+ gm._frozen_param_count = i + 1
37
+
38
+ with g.inserting_before(node):
39
+ new_input_node = g.create_node("get_attr", qualname, (), {})
40
+ node.replace_all_uses_with(new_input_node)
41
+ new_input_node.meta.update(node.meta)
42
+ g.erase_node(node)
43
+
44
+ # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
45
+ gm.register_buffer(qualname, constant)
46
+ setattr(gm, qualname, constant)
47
+
48
+
49
+ class ConstantFolder(torch.fx.Interpreter):
50
+ def __init__(
51
+ self,
52
+ gm,
53
+ skip_constructors=False,
54
+ ):
55
+ super().__init__(gm)
56
+ self.node_replacements: Dict[torch.fx.Node, Any] = {}
57
+ self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
58
+ self.unknown_value = object()
59
+ self.skip_constructors: bool = skip_constructors
60
+
61
+ # overwrite this to deallocate env values if their only remaining use
62
+ # is the output
63
+ self.user_to_last_uses = self.node_to_last_non_output_use()
64
+
65
+ def is_impure(self, node: torch.fx.node.Node):
66
+ if (
67
+ node.target == torch.ops.prims.convert_element_type.default
68
+ and node.args[0].op == "get_attr" # type: ignore[union-attr]
69
+ and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr]
70
+ and node.args[1] == torch.bfloat16
71
+ ):
72
+ # For int8_weight -> dq -> bf16_weight
73
+ return True
74
+ if node.target in [
75
+ torch.ops.quantized_decomposed.dequantize_per_channel.default,
76
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
77
+ torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
78
+ ]:
79
+ # For the pattern fp32_weight -> q -> dq
80
+ # We only folding fp32_weight -> q
81
+ # int8_weight and leave dq in graph to be fused
82
+ return True
83
+ return False
84
+
85
+ def node_to_last_non_output_use(self):
86
+ last_non_output_use = collections.defaultdict(list)
87
+ seen_uses = set()
88
+ output_node = next(iter(reversed(self.module.graph.nodes)))
89
+
90
+ for node in reversed(self.module.graph.nodes):
91
+ if node.target == "output":
92
+ continue
93
+
94
+ def add_use(inp):
95
+ if inp in seen_uses:
96
+ return
97
+
98
+ seen_uses.add(inp)
99
+ last_non_output_use[node].append(inp)
100
+
101
+ # In-place is fine since we don't mutate
102
+ pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))
103
+
104
+ # if this node is only used in output, we want to gc it right away
105
+ if len(node.users) == 1 and output_node in node.users:
106
+ last_non_output_use[node].append(node)
107
+
108
+ return last_non_output_use
109
+
110
    def run_node(self, node):
        """Interpret one node, returning either its concrete value (a foldable
        constant) or `self.unknown_value` to mark it as non-foldable.

        Also records foldable results via `add_node_replacement` and drops
        replacements whose every use has itself been replaced.
        """
        if node.target == "output":
            # because we remove nodes from env on last non output use,
            # re-define them now or we'll get error in interpreter
            def set_env(arg):
                self.env[arg] = self.unknown_value

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        # We need to do this weird thing because in cases where flattened_inputs
        # contains a ScriptObject, equality checking results in a type error if
        # the types are different.
        if any(
            type(self.unknown_value) == type(input_) and self.unknown_value == input_
            for input_ in flattened_inputs
        ):
            return self.unknown_value

        # TODO - fix errors with this
        if (
            node.op == "call_function"
            and node.target == aten._efficientzerotensor.default
        ):
            return self.unknown_value

        # TODO - constant folding triton kernel returns the inputs -- fix this
        if (
            node.op == "call_function"
            and node.name == "triton_kernel_wrapper_functional_proxy"
        ):
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them already
        # and turning into tensor would result in an additional global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and node.op != "get_attr"
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or on inputs which we did not make constant
        # (ops tagged nondeterministic_seeded would not replay deterministically if folded).
        if (
            isinstance(node.target, torch._ops.OpOverload)
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            return self.unknown_value

        out = super().run_node(node)

        if node.op != "get_attr" and isinstance(out, torch.Tensor):
            # Meta tensors carry no data, so there is nothing to fold.
            if out.device.type == "meta":
                return out

            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            # If every use of an earlier replacement has now itself been
            # replaced, that intermediate constant is no longer needed.
            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out
191
+
192
+ def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
193
+ return True
194
+
195
+ def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
196
+ self.node_replacements[node] = tensor
197
+
198
+ def run(self):
199
+ env = {}
200
+ for n in self.module.graph.find_nodes(op="placeholder"):
201
+ env[n] = self.unknown_value
202
+ return super().run(initial_env=env)
203
+
204
+
205
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    """Constant-fold `gm` in place.

    Runs `ConstantFolder` to find foldable nodes, replaces each with a
    constant attribute (subject to the optional `constraint_fn` filter), and
    then deletes attributes that no longer have any users.
    """
    # Fake/functional dispatch modes must be disabled so real tensors are computed.
    with torch.utils._python_dispatch._disable_current_modes():
        cf = ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node, constant in cf.node_replacements.items():
            # constraint_fn lets the caller veto folding of individual nodes.
            if constraint_fn is not None and not constraint_fn(node):
                continue
            replace_node_with_constant(gm, node, constant)

        erased_params = []
        # Get all attr users by looking up the graph instead from node.users, because in this case
        # _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor.

        # opcode         name                 target            args                        kwargs
        # -------------  -------------------  ----------------  --------------------------  --------
        # placeholder    arg0_1               arg0              ()                          {}
        # get_attr       _tensor_constant0    state             ()                          {}
        # call_function  add                  aten.add.Tensor   (arg0_1, _tensor_constant0) {}
        # get_attr       _tensor_constant0_1  state             ()                          {}
        # call_function  add_                 aten.add_.Tensor  (_tensor_constant0_1, 1)    {}
        # output         output               output            ([add],)                    {}

        get_attr_node_users = defaultdict(list)
        for node in gm.graph.nodes:
            if node.op == "get_attr":
                get_attr_node_users[node.target].extend(node.users.keys())
        for node in gm.graph.find_nodes(op="get_attr"):
            if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0:
                if hasattr(gm, node.target):
                    delattr(gm, node.target)
                erased_params.append(node)
        # Erase after the scan so we don't mutate the node list mid-iteration.
        for node in erased_params:
            gm.graph.erase_node(node)

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
243
+
244
+
245
def constant_graph_tag(gm: torch.fx.GraphModule):
    """Tag every node's meta with CONST_MODULE_TAG or MODULE_TAG.

    Nodes that are attributes, foldable replacements, or inputs consumed by a
    replacement are marked as belonging to the constant (foldable) part; all
    other nodes are marked as belonging to the main module.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        cf = ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node in gm.graph.nodes:
            if (
                node.op == "get_attr"
                or node in cf.node_replacements
                or node in cf.replaced_uses
            ):
                node.meta[META_TAG] = CONST_MODULE_TAG
            else:
                node.meta[META_TAG] = MODULE_TAG
259
+
260
+
261
def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Construct a GraphModule which corresponds to the part which could be
    constant folded in provided gm.
    """

    constant_graph_tag(gm)
    # We rewrite the tags, if it's a constant being directly consumed, without
    # any folding opportunity, we keep it in main gm.
    for node in gm.graph.find_nodes(op="get_attr"):
        used_to_fold = False
        for u in node.users:
            if u.meta[META_TAG] == CONST_MODULE_TAG:
                used_to_fold = True
                break
        if not used_to_fold:
            node.meta[META_TAG] = MODULE_TAG

    new_graph = torch.fx.Graph()

    # Copy every const-tagged node into the new graph, remapping its inputs.
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    output_nodes = []
    for node in gm.graph.nodes:
        if node.meta[META_TAG] == MODULE_TAG:
            continue

        new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
        node_remapping[node] = new_node

        # A const node consumed by the main module becomes an output of the
        # extracted constant graph.
        for user in node.users:
            if user.meta[META_TAG] == MODULE_TAG:
                output_nodes.append(new_node)
                break

    new_graph.output(tuple(output_nodes))
    new_graph.lint()
    new_gm = torch.fx.GraphModule(gm, new_graph)

    return new_gm
.venv/lib/python3.11/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Dict, Optional, Tuple, List
3
+
4
+ import torch
5
+ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument
6
+ from torch._export.pass_infra.node_metadata import NodeMetadata
7
+ from torch._export.pass_infra.proxy_value import ProxyValue
8
+ from torch._ops import OpOverload
9
+
10
aten = torch.ops.aten

# Map each non-functional side-effectful op to its functional counterpart,
# which takes and returns an explicit `dep_token` instead of relying on
# hidden side effects.
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: Dict[OpOverload, OpOverload] = {
    aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
    aten._assert_async.msg: aten._functional_assert_async.msg,
}
16
+
17
+
18
class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
    """
    Functionalize ops with side effect in graph module by replacing the op with
    functional version of it. A new dependency token (`dep_token`) will be
    created and propagated through functional ops to output.
    For example:
    ```
    def f(x):
        sym_constrain_range(x.shape[0], min=1, max=3)
        return x.add(3)
    ```
    Will be transformed to:
    ```
    def f(x):
        dep_token0 = _make_dep_token()
        dep_token1 = _functional_sym_constrain_range(
            x.shape[0], min=1, max=3, dep_token=dep_token0
        )

        return x.add(3), dep_token1
    ```
    """

    def __init__(self) -> None:
        super().__init__()
        # Current dependency token threaded through the functionalized ops.
        self._dep_token: Optional[ProxyValue] = None
        # Counter used to give each new dep_token a readable, stable name.
        self._next_dep_token_index: Optional[int] = None

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Early return if no non-functional assertions.
        if not any(
            n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS
            for n in graph_module.graph.nodes
        ):
            return PassResult(graph_module=graph_module, modified=False)

        # Work on a copy; the input graph module is left untouched.
        gm = copy.deepcopy(graph_module)
        self._dep_token = None
        self._next_dep_token_index = None
        return super().call(gm)

    def call_operator(
        self,
        op: OpOverload,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
            return super().call_operator(op, args, kwargs, meta)

        # Lazily create the initial dep_token at the first side-effectful op.
        if self._dep_token is None:
            self._dep_token = super().call_operator(
                aten._make_dep_token,
                args=(),
                kwargs={},
                meta=self._create_dummy_node_metadata(),
            )
            self._dep_token.node.name = "dep_token0"
            self._next_dep_token_index = 1

        # Thread the current token through the functional variant; its result
        # becomes the new current token.
        self._dep_token = super().call_operator(
            _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op],
            args=args,
            kwargs={**kwargs, "dep_token": self._dep_token},
            meta=meta,
        )
        assert self._next_dep_token_index is not None
        self._dep_token.node.name = f"dep_token{self._next_dep_token_index}"
        self._next_dep_token_index += 1

        return self._dep_token

    def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
        assert self._dep_token is not None

        # Append the final dep_token to the outputs so the dependency chain
        # is observable and cannot be dead-code-eliminated.
        return super().output(results=(*results, self._dep_token), meta=meta)  # type: ignore[arg-type]
.venv/lib/python3.11/site-packages/torch/_export/passes/lift_constants_pass.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import collections
3
+ import warnings
4
+ from typing import Any, Dict, List, Union
5
+
6
+ import torch
7
+ from torch._export.verifier import SpecViolationError
8
+ from torch._guards import detect_fake_mode
9
+ from torch._library.fake_class_registry import FakeScriptObject
10
+ from torch._subclasses.fake_tensor import unset_fake_temporarily
11
+ from torch.export.exported_program import (
12
+ ArgumentSpec,
13
+ CustomObjArgument,
14
+ ExportGraphSignature,
15
+ InputKind,
16
+ InputSpec,
17
+ TensorArgument,
18
+ )
19
+
20
+
21
+ class ConstantAttrMap(collections.abc.MutableMapping):
22
+ """A mapping class that understands how to use module constants (tensors,
23
+ ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally,
24
+ but ScriptObjects are stored by hash, because different torch.ScriptObjects can point to
25
+ the same underlying value (but we guarantee that they will `hash()` to the same value
26
+ if that's the case).
27
+ """
28
+
29
+ def __init__(self) -> None:
30
+ # Underlying dict that we use to implement this mapping.
31
+ self._constant_attrs: Dict[
32
+ Union[int, torch.Tensor, FakeScriptObject], List[Any]
33
+ ] = {}
34
+ # Map from the hash(ScriptObject) to the ScriptObject itself. Used for
35
+ # APIs like `__iter__` that should look like they're returning the
36
+ # original ScriptObjects.
37
+ self._script_object_map: Dict[int, torch.ScriptObject] = {}
38
+
39
+ def __getitem__(
40
+ self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
41
+ ) -> Any:
42
+ real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
43
+ assert isinstance(real_key, (int, torch.Tensor, FakeScriptObject))
44
+ return self._constant_attrs[real_key]
45
+
46
+ def __setitem__(self, key: Union[torch.Tensor, torch.ScriptObject], value):
47
+ # we shouldn't actually call this, should go to add() instead to handle aliasing
48
+ raise NotImplementedError(
49
+ """Directly setting values for ConstantAttrMap is not supported, please use add(key, value) instead.
50
+ The same key can be mapped to multiple values, for handling constant aliasing."""
51
+ )
52
+
53
+ def add(
54
+ self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject], value: Any
55
+ ) -> None:
56
+ if isinstance(key, torch.ScriptObject):
57
+ if hash(key) not in self._constant_attrs:
58
+ self._constant_attrs[hash(key)] = []
59
+ self._constant_attrs[hash(key)].append(value)
60
+ self._script_object_map[hash(key)] = key
61
+ elif isinstance(key, (torch.Tensor, FakeScriptObject)):
62
+ if key not in self._constant_attrs:
63
+ self._constant_attrs[key] = []
64
+ self._constant_attrs[key].append(value)
65
+ else:
66
+ raise TypeError(
67
+ f"Expected key to be a tensor or ScriptObject, got {type(key)}"
68
+ )
69
+
70
+ def __delitem__(self, key):
71
+ real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
72
+
73
+ del self._constant_attrs[real_key]
74
+
75
+ def __iter__(self):
76
+ for key in self._constant_attrs:
77
+ if isinstance(key, int):
78
+ yield self._script_object_map[key]
79
+ else:
80
+ yield key
81
+
82
+ def __len__(self):
83
+ return len(self._constant_attrs)
84
+
85
+ def __contains__(self, key: object) -> bool:
86
+ real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
87
+ return real_key in self._constant_attrs
88
+
89
+
90
def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str:
    """Qualify `constant_name` with the FQN of the module that originally
    used the constant, taken from the innermost entry of the node's
    `nn_module_stack` metadata, so the state-dict location stays consistent
    with the eager module."""
    module_stack = node.meta["nn_module_stack"]
    if not module_stack:
        # No owning module recorded; the constant lives at the top level.
        return constant_name
    innermost_fqn = list(module_stack.values())[-1][0]
    return f"{innermost_fqn}.{constant_name}" if innermost_fqn else constant_name
101
+
102
+
103
def _get_first_fqn(
    const_attrs: ConstantAttrMap,
    key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],
) -> Any:
    """Return the first FQN recorded for `key`, or None if there is none."""
    candidates = const_attrs.get(key)
    if not candidates:
        return None
    return candidates[0]
109
+
110
+
111
def lift_constants_pass(
    gm: torch.fx.GraphModule,
    graph_signature: ExportGraphSignature,
    constant_attrs: ConstantAttrMap,
) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]]:
    """
    Takes a graph module, graph signature, and modifies them implace to lift any
    constants (tensors or custom classes) as inputs to the graph. Returns a
    dictionary of names to constants.

    Arguments:
        gm (torch.fx.GraphModule): The graph module containing the graph and constants to lift.
        graph_signature (ExportGraphSignature): This graph signature will be
            mutated to add additional CONSTANT_TENSOR and CUSTOM_OBJ inputs.
        constant_attrs (ConstantAttr): A mapping from a constant value to its
            fully-qualified path in `gm`. This is used to maintain consistent
            location of constants between the original module and the exported
            version.

    Returns:
        A dictionary of fqn => constant value.
    """
    all_constants: Dict[
        str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject]
    ] = {}

    inputs = graph_signature.input_specs
    # Counters continue from any constants already present in the signature,
    # so generated names (lifted_custom_N / lifted_tensor_N) stay unique.
    num_custom_obj = sum(
        input_specs.kind == InputKind.CUSTOM_OBJ for input_specs in inputs
    )
    num_tensor_constants = sum(
        input_specs.kind == InputKind.CONSTANT_TENSOR for input_specs in inputs
    )

    fake_mode = detect_fake_mode(
        tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
    )

    # Lifted placeholders are inserted just before the first user input, so
    # track both that node and its positional index in the signature.
    first_user_input_loc, first_user_input = 0, None
    for node in gm.graph.nodes:
        if node.op == "placeholder" and node.name in graph_signature.user_inputs:
            first_user_input = node
            break
        first_user_input_loc += 1

    lifted_objs = ConstantAttrMap()
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            constant_val = getattr(gm, node.target)
            if constant_val in lifted_objs:
                # We already lifted this constant elsewhere. Just rewrite uses
                # of this get_attr to point to the already-existing placeholder
                # node.
                const_placeholder_node = _get_first_fqn(lifted_objs, constant_val)
                node.replace_all_uses_with(const_placeholder_node)
                gm.graph.erase_node(node)
                continue

            # For ScriptObject, Tensor and FakeScriptObject constants:
            # First check if the constant was an attribute on some module by
            # consulting `constant_attrs` map. If it is, use the fqn that keeps
            # its location consistent with the eager module.
            #
            # If it's not in the `constant_attrs` map, that means it's an inline
            # constant (e.g. x + torch.tensor(0)), and thus did not have a
            # specific location in the eager module. In that case, just generate
            # some name and attach it to the module in which it was used.
            if isinstance(constant_val, (torch.ScriptObject, FakeScriptObject)):
                constant_kind = InputKind.CUSTOM_OBJ
                constant_fqn = _get_first_fqn(constant_attrs, constant_val)
                if constant_fqn is not None:
                    constant_name = constant_fqn.replace(".", "_")
                else:
                    constant_name = f"lifted_custom_{num_custom_obj}"
                    constant_fqn = get_constant_fqn(node, constant_name)
                    num_custom_obj += 1
            elif isinstance(constant_val, torch.Tensor):
                # Remove the parameterness of constant_val
                if isinstance(constant_val, torch.nn.Parameter):
                    warnings.warn(
                        f"{node.target} created when tracing {node.meta['stack_trace']} is a parameter. But"
                        f"it's not registered with register_parameter(). export will treat it as a constant tensor"
                    )
                    # We get the real data out of the parameter by disabling the surrounding fake mode.
                    with unset_fake_temporarily():
                        constant_val = constant_val.data
                constant_kind = InputKind.CONSTANT_TENSOR
                constant_fqn = _get_first_fqn(constant_attrs, constant_val)
                if constant_fqn is not None:
                    constant_name = constant_fqn.replace(".", "_")
                else:
                    constant_name = f"lifted_tensor_{num_tensor_constants}"
                    constant_fqn = get_constant_fqn(node, constant_name)
                    num_tensor_constants += 1
            elif isinstance(constant_val, torch.fx.GraphModule):
                # Submodules are not constants; leave their get_attr alone.
                continue
            elif "LoweredBackendModule" in type(constant_val).__name__:
                continue
            else:
                raise SpecViolationError(
                    f"getattr node {node} referencing unsupported type {type(constant_val)}"
                )

            with gm.graph.inserting_before(first_user_input):
                # Insert the constant node before the first user input
                const_placeholder_node = gm.graph.placeholder(constant_name)
                # match target name with its node name in case there is name collision
                # and suffix is added to node name in fx
                const_placeholder_node.target = const_placeholder_node.name

                for k, v in node.meta.items():
                    const_placeholder_node.meta[k] = v

                # Once the FQN has been used, remove nn_module_stack, stack_trace
                const_placeholder_node.meta.pop("nn_module_stack")
                const_placeholder_node.meta.pop("stack_trace", None)

                input_spec_arg: ArgumentSpec
                if isinstance(constant_val, torch.Tensor):
                    if fake_mode is not None:
                        const_placeholder_node.meta["val"] = fake_mode.from_tensor(
                            constant_val, static_shapes=True
                        )
                        const_placeholder_node.meta["val"].constant = constant_val
                    else:
                        const_placeholder_node.meta["val"] = constant_val
                    input_spec_arg = TensorArgument(name=const_placeholder_node.name)
                elif isinstance(constant_val, torch._C.ScriptObject):
                    class_fqn = constant_val._type().qualified_name()  # type: ignore[attr-defined]
                    const_placeholder_node.meta["val"] = CustomObjArgument(
                        constant_fqn, class_fqn
                    )
                    input_spec_arg = CustomObjArgument(
                        name=const_placeholder_node.name, class_fqn=class_fqn
                    )
                elif isinstance(constant_val, FakeScriptObject):
                    class_fqn = constant_val.script_class_name
                    const_placeholder_node.meta["val"] = CustomObjArgument(
                        constant_fqn, class_fqn, constant_val
                    )
                    input_spec_arg = CustomObjArgument(
                        name=const_placeholder_node.name,
                        class_fqn=class_fqn,
                        fake_val=constant_val,
                    )
                else:
                    raise SpecViolationError(
                        f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}"
                    )

                lifted_objs.add(constant_val, const_placeholder_node)
                node.replace_all_uses_with(const_placeholder_node)
                gm.graph.erase_node(node)

                # Add the constant as a buffer to the graph signature
                graph_signature.input_specs.insert(
                    first_user_input_loc,
                    InputSpec(
                        kind=constant_kind,
                        arg=input_spec_arg,
                        target=constant_fqn,
                    ),
                )
                # Aliased constants: record every FQN known for this value.
                if constant_val in constant_attrs:
                    for fqn in constant_attrs[constant_val]:
                        all_constants[fqn] = constant_val
                else:
                    all_constants[constant_fqn] = constant_val
                first_user_input_loc += 1

    return all_constants
282
+
283
+
284
def rewrite_script_object_meta(
    gm: torch.fx.GraphModule,
) -> Dict[str, Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],]:
    """When tracing, we produce a graph with FakeScriptObject in the
    meta["val"].

    For now, we rewrite meta["val"] to be a placeholder CustomObjArgument,
    and return the original objects keyed by node name.
    """
    constants: Dict[
        str,
        Union[
            torch.Tensor,
            torch.ScriptObject,
            FakeScriptObject,
        ],
    ] = {}
    for node in gm.graph.nodes:
        if "val" not in node.meta:
            continue

        old_meta = node.meta["val"]

        if isinstance(old_meta, torch.ScriptObject):
            class_fqn = old_meta._type().qualified_name()  # type: ignore[attr-defined]
            new_meta = CustomObjArgument(node.name, class_fqn)
            constants[node.name] = old_meta
            node.meta["val"] = new_meta

        elif isinstance(old_meta, FakeScriptObject):
            class_fqn = old_meta.script_class_name  # type: ignore[attr-defined]
            # FakeScriptObject keeps its fake value on the argument.
            new_meta = CustomObjArgument(node.name, class_fqn, old_meta)
            constants[node.name] = old_meta
            node.meta["val"] = new_meta

    return constants
.venv/lib/python3.11/site-packages/torch/_export/passes/remove_runtime_assertions.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch
3
+ from torch.fx.passes.infra.pass_base import PassBase, PassResult
4
+
5
+
6
+ class _RemoveRuntimeAssertionsPass(PassBase):
7
+ """
8
+ Remove runtime assertions inserted by the
9
+ _AddRuntimeAssertionsForInlineConstraintsPass.
10
+ """
11
+
12
+ def call(self, graph_module) -> PassResult:
13
+ modified = False
14
+ for module in graph_module.modules():
15
+ if not isinstance(module, torch.fx.GraphModule):
16
+ continue
17
+ for node in module.graph.nodes:
18
+ if node.target == torch.ops.aten._assert_async.msg:
19
+ assert_async_node = node
20
+ if len(assert_async_node.users) > 0:
21
+ continue
22
+ module.graph.erase_node(assert_async_node)
23
+ # the upstream scalar_tensor <- {le, ge} <- sym_size
24
+ # linear chain of nodes of nodes is removed by the
25
+ # downstream dead code elimination
26
+ modified = True
27
+ return PassResult(graph_module, modified)
.venv/lib/python3.11/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from typing import List
3
+
4
+ import torch
5
+ from torch._higher_order_ops.wrap import wrap_with_autocast
6
+
7
+ from ..utils import node_inline_, nodes_filter, nodes_first, sequential_split
8
+ from .replace_with_hop_pass_util import (
9
+ _replace_with_hop_helper,
10
+ _replace_with_hop_pass_helper,
11
+ _sequential_split_and_maybe_inline_subgraphs_helper,
12
+ )
13
+
14
+
15
+ def _is_autocast_node(node: torch.fx.Node):
16
+ return (
17
+ node
18
+ and node.op == "call_function"
19
+ and node.target
20
+ in [
21
+ torch.amp.autocast_mode._enter_autocast,
22
+ torch.amp.autocast_mode._exit_autocast,
23
+ ]
24
+ )
25
+
26
+
27
+ def _is_enter_autocast_node(node: torch.fx.Node):
28
+ return (
29
+ node
30
+ and node.op == "call_function"
31
+ and node.target == torch.amp.autocast_mode._enter_autocast
32
+ )
33
+
34
+
35
+ def _is_exit_autocast_node(node: torch.fx.Node):
36
+ return (
37
+ node
38
+ and node.op == "call_function"
39
+ and node.target == torch.amp.autocast_mode._exit_autocast
40
+ )
41
+
42
+
43
def _is_autocast_sub_mod(node: torch.fx.Node):
    """
    Check if the first non-placeholder node is `torch.amp.autocast_mode._enter_autocast`.
    """
    if node.op == "call_module":
        assert isinstance(node.target, str)
        # Look up the called submodule on the graph's owning module.
        subgm = getattr(node.graph.owning_module, node.target)
        first_non_ph = nodes_first(
            subgm.graph.nodes, lambda node: node.op != "placeholder"
        )
        if (
            first_non_ph
            and first_non_ph.op == "call_function"
            and first_non_ph.target == torch.amp.autocast_mode._enter_autocast
        ):
            # TODO: check if current auto-cast type is the same as the args of
            # _enter_autocast. If so, return False, i.e. do not create a submodule.
            return True
    return False
62
+
63
+
64
def _check_valid_autocast_block(enter_autocast_node, exit_autocast_node):
    """Sanity-check that `exit_autocast_node` closes `enter_autocast_node`:
    the exit op's first argument must be the matching enter node."""
    assert _is_enter_autocast_node(enter_autocast_node)
    assert _is_exit_autocast_node(exit_autocast_node)
    assert exit_autocast_node.args[0] == enter_autocast_node
68
+
69
+
70
def _replace_with_hop(node: torch.fx.Node):
    """Rewrite the autocast submodule called by `node` into a
    `wrap_with_autocast` higher-order op, removing the raw
    enter/exit autocast nodes from the subgraph."""
    assert node.op == "call_module"
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)
    sub_graph = sub_gm.graph
    autocast_nodes = nodes_filter(sub_graph.nodes, _is_autocast_node)
    if len(autocast_nodes) > 0:
        assert len(autocast_nodes) > 1  # need at least an enter node and an exist node
        # The submodule begins with an enter and ends with the matching exit.
        enter_autocast_node = autocast_nodes[0]
        exit_autocast_node = autocast_nodes[-1]
        _check_valid_autocast_block(enter_autocast_node, exit_autocast_node)

        _replace_with_hop_helper(
            node, enter_autocast_node, _is_autocast_node, wrap_with_autocast
        )
        # The autocast markers are now implied by the HOP; drop them.
        sub_graph.erase_node(exit_autocast_node)
        sub_graph.erase_node(enter_autocast_node)
89
+
90
+
91
def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    split_autocast creates a new graph module that splits the input graph module into multiple submodules
    based on the `_enter_autocast` and `_exit_autocast` nodes. It doesn't mutate the input graph module.

    Nodes between the **outer-most** `_enter_autocast` and `_exit_autocast(_enter_autocast)` are splitted
    into a submodule. Nested autocast regions are not splitted.
    `_enter_autocast` and `_exit_autocast(_enter_autocast)` nodes are in the submodule as well.

    Below is an example of splitting. A, B, C, D, E are blocks of non-autocast nodes in the original graph
    module. Nodes marked with the same number are grouped into the same submodule.
    A   # 0
    enter_autocast  # 1
    B   # 1
    exit_autocast  # 1
    C   # 2
    enter_autocast  # 3
    D   # 3
    exit_autocast  # 3
    E   # 4
    """
    # Stack of currently-open `_enter_autocast` nodes.
    enter_autocast_node_stack: List[torch.fx.Node] = []
    # Set when the outermost region closes, so the next node starts a new submodule.
    first_node_after_outer_most_exit: bool = False

    def node_call_back(node: torch.fx.Node):
        # Returns True when `node` should be the first node of a new submodule.
        nonlocal enter_autocast_node_stack, first_node_after_outer_most_exit
        if first_node_after_outer_most_exit or (
            len(enter_autocast_node_stack) == 0 and _is_enter_autocast_node(node)
        ):
            assert len(enter_autocast_node_stack) == 0
            first_node_after_outer_most_exit = False
            if _is_enter_autocast_node(node):
                enter_autocast_node_stack.append(node)
            return True
        if _is_exit_autocast_node(node):
            assert len(enter_autocast_node_stack) > 0
            last_enter_autocast_node = enter_autocast_node_stack.pop()
            assert node.args[0] == last_enter_autocast_node
            if len(enter_autocast_node_stack) == 0:
                # next node should be in the next submodule since
                # autocast block ends
                first_node_after_outer_most_exit = True
        return False

    return sequential_split(gm, node_call_back)
136
+
137
+
138
def _sequential_split_and_maybe_inline_subgraphs(
    gm: torch.fx.GraphModule, graph_signature
):
    """
    Helper function for replace_autocast_with_hop_pass().
    Split the graph module into multiple subgraphs based on the autocast nodes.
    For each subgraph, decides whether to construct a HOO subgraph, or inline the calls
    back into the parent graph module.
    Nodes between `_enter_autocast` and `_exit_autocast(_enter_autocast)` are considered
    as a subgraph.
    """
    # Nothing to do when the graph contains no autocast markers.
    need_replacing = any(_is_autocast_node(node) for node in gm.graph.nodes)
    if not need_replacing:
        return gm, graph_signature

    # split_autocast returns a new graph module that could have different output
    # args names. We need to fix the graph signature in `_sequential_split_and_maybe_inline_subgraphs_helper`.
    new_gm = _split_autocast(gm)

    def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
        # Autocast submodules become HOPs; everything else is inlined back.
        if _is_autocast_sub_mod(node):
            _replace_with_hop(node)
        else:
            assert node.op == "call_module"
            assert isinstance(node.target, str)
            node_inline_(node)

    return _sequential_split_and_maybe_inline_subgraphs_helper(
        new_gm, graph_signature, _maybe_inline_or_replace_with_hop
    )
168
+
169
+
170
def replace_autocast_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
    """
    Replace autocast regions in `gm` with higher-order ops: split the module
    around autocast boundaries via
    `_sequential_split_and_maybe_inline_subgraphs`, then let
    `_replace_with_hop_pass_helper` recurse into each resulting submodule.
    """
    return _replace_with_hop_pass_helper(
        gm, graph_signature, _sequential_split_and_maybe_inline_subgraphs
    )
.venv/lib/python3.11/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py ADDED
@@ -0,0 +1,673 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ import operator
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.export._trace
8
+ from torch._ops import OpOverload
9
+ from torch.ao.quantization.fx._decomposed import (
10
+ dequantize_per_channel,
11
+ dequantize_per_tensor,
12
+ quantize_per_tensor,
13
+ )
14
+ from torch.ao.quantization.utils import calculate_qmin_qmax
15
+ from torch.fx.graph_module import _assign_attr
16
+
17
+
18
+ log = logging.getLogger(__name__)
19
+
20
+ # Those values will need to be carried over multiple operators.
21
+ _INPUT_Q_DTYPE: Optional[Union[torch.dtype, torch.fx.Node]] = None
22
+ _SCALE: Optional[Union[float, torch.fx.Node]] = None
23
+ _ZERO_POINT: Optional[Union[float, torch.fx.Node]] = None
24
+
25
+
26
+ def int_to_valid_dtype(val: int) -> torch.dtype:
27
+ from torch._export.converter import _TORCH_ENUM_TO_DTYPE # No circular import.
28
+
29
+ if isinstance(val, torch.dtype):
30
+ return val
31
+ dtype = _TORCH_ENUM_TO_DTYPE[val]
32
+ if dtype == torch.quint8:
33
+ return torch.uint8
34
+ elif dtype == torch.qint8:
35
+ return torch.int8
36
+ return dtype
37
+
38
+
39
+ def fx_enum_to_dtype(gm: torch.fx.GraphModule, val: int) -> torch.fx.Node:
40
+ return gm.graph.call_function(int_to_valid_dtype, (val,))
41
+
42
+
43
+ def insert_quantized_node(
44
+ gm: torch.fx.GraphModule,
45
+ val_node: torch.fx.Node,
46
+ scale_node: Union[float, torch.fx.Node],
47
+ zero_point_node: Union[float, torch.fx.Node],
48
+ qmin_node: Union[float, int, torch.fx.Node],
49
+ qmax_node: Union[float, int, torch.fx.Node],
50
+ dtype_node: Union[torch.dtype, torch.fx.Node],
51
+ qscheme: Optional[torch.qscheme],
52
+ ) -> torch.fx.Node:
53
+ return gm.graph.call_function(
54
+ quantize_per_tensor,
55
+ (
56
+ val_node,
57
+ scale_node,
58
+ zero_point_node,
59
+ qmin_node,
60
+ qmax_node,
61
+ dtype_node,
62
+ ),
63
+ )
64
+
65
+
66
+ def get_dequantized(
67
+ val: torch.Tensor,
68
+ scale: Union[float, torch.Tensor],
69
+ zero_point: Union[float, torch.Tensor],
70
+ qmin: Union[float, int],
71
+ qmax: Union[float, int],
72
+ dtype: torch.dtype,
73
+ axis: Optional[int],
74
+ qscheme: Optional[torch.qscheme],
75
+ ) -> torch.Tensor:
76
+ if qscheme is torch.per_tensor_affine:
77
+ return dequantize_per_tensor(
78
+ val,
79
+ scale,
80
+ zero_point,
81
+ qmin,
82
+ qmax,
83
+ dtype,
84
+ )
85
+ elif qscheme is torch.per_channel_affine:
86
+ return dequantize_per_channel(
87
+ val,
88
+ scale,
89
+ zero_point,
90
+ axis,
91
+ qmin,
92
+ qmax,
93
+ dtype,
94
+ )
95
+ else:
96
+ raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")
97
+
98
+
99
+ def insert_dequantized_node(
100
+ gm: torch.fx.GraphModule,
101
+ val_node: torch.fx.Node,
102
+ scale_node: Union[float, torch.fx.Node],
103
+ zero_point_node: Union[float, torch.fx.Node],
104
+ qmin_node: Union[float, int, torch.fx.Node],
105
+ qmax_node: Union[float, int, torch.fx.Node],
106
+ dtype_node: Union[torch.dtype, torch.fx.Node],
107
+ axis_node: Optional[Union[int, torch.fx.Node]],
108
+ qscheme: Optional[torch.qscheme],
109
+ ) -> torch.fx.Node:
110
+ if qscheme is torch.per_tensor_affine:
111
+ return gm.graph.call_function(
112
+ dequantize_per_tensor,
113
+ (
114
+ val_node,
115
+ scale_node,
116
+ zero_point_node,
117
+ qmin_node,
118
+ qmax_node,
119
+ dtype_node,
120
+ ),
121
+ )
122
+ elif qscheme is torch.per_channel_affine:
123
+ return gm.graph.call_function(
124
+ dequantize_per_channel,
125
+ (
126
+ val_node,
127
+ scale_node,
128
+ zero_point_node,
129
+ axis_node,
130
+ qmin_node,
131
+ qmax_node,
132
+ dtype_node,
133
+ ),
134
+ )
135
+ else:
136
+ raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")
137
+
138
+
139
+ def get_qmin_qmax(dtype: torch.dtype) -> Tuple[Union[int, float], Union[int, float]]:
140
+ return calculate_qmin_qmax(None, None, False, dtype, False) # type: ignore[arg-type]
141
+
142
+
143
+ def insert_qmin_qmax_node(
144
+ gm: torch.fx.GraphModule, dtype_node: Union[torch.dtype, torch.fx.Node]
145
+ ) -> Tuple[torch.fx.Node, torch.fx.Node]:
146
+ q_min_max_node = gm.graph.call_function(
147
+ calculate_qmin_qmax, (None, None, False, dtype_node, False)
148
+ )
149
+ qmin_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 0))
150
+ qmax_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 1))
151
+ return qmin_node, qmax_node
152
+
153
+
154
+ def get_script_object(
155
+ gm: torch.nn.Module, node: torch.fx.Node
156
+ ) -> torch._C.ScriptObject:
157
+ assert isinstance(node, torch.fx.Node)
158
+ assert node.op == "get_attr"
159
+ attr_name = node.target
160
+ assert isinstance(attr_name, str)
161
+
162
+ mod = gm
163
+ for attr in attr_name.split("."):
164
+ mod = getattr(mod, attr)
165
+ assert isinstance(mod, torch._C.ScriptObject)
166
+ return mod
167
+
168
+
169
+ def insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
170
+ gm: torch.fx.GraphModule,
171
+ param_node: torch.fx.Node,
172
+ ) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]:
173
+ """Directly inline tensor from a get_attr fx node."""
174
+ mod = get_script_object(gm, param_node)
175
+ w_qtensor, b_qtensor = mod.unpack() # type: ignore[attr-defined]
176
+ w_attr_name, b_attr_name = (
177
+ f"dequantized_{param_node.target}_w",
178
+ f"dequantized_{param_node.target}_b",
179
+ )
180
+ return insert_weight_and_bias_get_attr_node(
181
+ gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name
182
+ )
183
+
184
+
185
+ def insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
186
+ gm: torch.fx.GraphModule,
187
+ get_attr_to_weight_node: torch.fx.Node,
188
+ get_attr_to_bias_node: Optional[torch.fx.Node],
189
+ ) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]:
190
+ assert isinstance(get_attr_to_weight_node.target, str)
191
+ w_qtensor = getattr(gm, get_attr_to_weight_node.target)
192
+ w_attr_name = f"dequantized_{get_attr_to_weight_node.target}_w"
193
+
194
+ if get_attr_to_bias_node is not None:
195
+ assert isinstance(get_attr_to_bias_node.target, str)
196
+ b_qtensor = getattr(gm, get_attr_to_bias_node.target)
197
+ b_attr_name = f"dequantized_{get_attr_to_bias_node.target}_b"
198
+ else:
199
+ b_qtensor, b_attr_name = None, ""
200
+
201
+ return insert_weight_and_bias_get_attr_node(
202
+ gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name
203
+ )
204
+
205
+
206
+ def insert_weight_and_bias_get_attr_node(
207
+ gm: torch.fx.GraphModule,
208
+ w_qtensor: torch.Tensor,
209
+ b_qtensor: Optional[torch.Tensor],
210
+ w_attr_name: str,
211
+ b_attr_name: str,
212
+ ) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]:
213
+ w_tensor = get_tensor_from_qtensor(w_qtensor)
214
+ _assign_attr(w_tensor, gm, w_attr_name)
215
+ w_tensor_attr = gm.graph.get_attr(w_attr_name)
216
+
217
+ if b_qtensor is not None:
218
+ b_tensor = get_tensor_from_qtensor(b_qtensor, dequant=False)
219
+ _assign_attr(b_tensor, gm, b_attr_name)
220
+ b_tensor_attr = gm.graph.get_attr(b_attr_name)
221
+ else:
222
+ b_tensor_attr = None
223
+
224
+ return w_tensor_attr, b_tensor_attr
225
+
226
+
227
+ def get_tensor_from_qtensor(
228
+ qtensor: torch.Tensor, dequant: bool = True
229
+ ) -> torch.Tensor:
230
+ # Manual conversion because qint8 is not used anymore.
231
+ if qtensor.dtype in [torch.qint8, torch.quint8]:
232
+ tensor = qtensor.int_repr()
233
+ else:
234
+ tensor = qtensor
235
+
236
+ # Weights need dequantization with scaling and zero_point adjustment, but
237
+ # bias does not need that.
238
+ if dequant:
239
+ qscheme = qtensor.qscheme()
240
+ if qscheme == torch.per_channel_affine:
241
+ scale, zero_point, axis = (
242
+ qtensor.q_per_channel_scales(),
243
+ qtensor.q_per_channel_zero_points(),
244
+ qtensor.q_per_channel_axis(),
245
+ )
246
+ else:
247
+ scale, zero_point, axis = (
248
+ qtensor.q_scale(), # type: ignore[assignment]
249
+ qtensor.q_zero_point(), # type: ignore[assignment]
250
+ None,
251
+ )
252
+ dtype = tensor.dtype
253
+ qmin, qmax = get_qmin_qmax(dtype)
254
+ return get_dequantized(
255
+ tensor, scale, zero_point, qmin, qmax, dtype, axis, qscheme
256
+ )
257
+ return tensor
258
+
259
+
260
+ def insert_fused_activation_node(
261
+ gm: torch.fx.GraphModule, opname: str, fx_node: torch.fx.Node
262
+ ) -> torch.fx.Node:
263
+ if opname in ["conv1d_relu", "conv2d_relu", "linear_relu", "add_relu", "mul_relu"]:
264
+ fx_node = gm.graph.call_function(torch.ops.aten.relu, (fx_node,))
265
+ return fx_node
266
+
267
+
268
+ def _conv1d_op_with_squeeze(
269
+ inp: torch.Tensor,
270
+ weight: torch.Tensor,
271
+ bias: Optional[torch.Tensor],
272
+ stride: List[int],
273
+ padding: List[int],
274
+ dilation: List[int],
275
+ groups: int,
276
+ ) -> torch.Tensor:
277
+ # In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze
278
+ # operations before and after the conv2d operation to match the dimension of weights.
279
+ # Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/quantized/cpu/qconv.cpp#L1827 # noqa: B950
280
+ s_inp = torch.ops.aten.unsqueeze(inp, 2)
281
+ conv1d_res = torch.ops.aten.conv2d(
282
+ s_inp,
283
+ weight,
284
+ bias,
285
+ stride,
286
+ padding,
287
+ dilation,
288
+ groups,
289
+ )
290
+ uns_conv1d_res = torch.ops.aten.squeeze(conv1d_res, 2)
291
+ return uns_conv1d_res
292
+
293
+
294
+ def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node):
295
+ """Conv specfic transformation function."""
296
+ assert isinstance(node.target, torch._ops.OpOverload)
297
+ opname = node.target._opname
298
+ scale_node, zero_point_node = node.args[2], node.args[3]
299
+
300
+ op_f = (
301
+ torch.ops.aten.conv2d
302
+ if opname in ["conv2d", "conv2d_relu"]
303
+ else _conv1d_op_with_squeeze
304
+ )
305
+
306
+ inp_node, param_node = node.args[0], node.args[1]
307
+ assert isinstance(inp_node, torch.fx.Node)
308
+ assert isinstance(param_node, torch.fx.Node)
309
+
310
+ if param_node.op == "call_function":
311
+ # Using Conv2dPrepackParam from conv_prepack.
312
+ # We directly skip the packing call and inline weights and bias.
313
+ w_node, b_node = param_node.args[0], param_node.args[1]
314
+ assert isinstance(w_node, torch.fx.Node)
315
+ assert b_node is None or isinstance(b_node, torch.fx.Node)
316
+ (
317
+ param_0,
318
+ param_1,
319
+ ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
320
+ gm, w_node, b_node
321
+ )
322
+ op_res_node = gm.graph.call_function(
323
+ op_f, (inp_node, param_0, param_1, *param_node.args[2:])
324
+ )
325
+ else:
326
+ # Using ConvPrepackedParam.
327
+ param = get_script_object(gm, param_node)
328
+ (
329
+ param_0,
330
+ param_1,
331
+ ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
332
+ gm, param_node
333
+ ) # type: ignore[assignment]
334
+ op_res_node = gm.graph.call_function(
335
+ op_f,
336
+ (
337
+ inp_node,
338
+ param_0,
339
+ param_1,
340
+ param.stride(), # type: ignore[attr-defined]
341
+ param.padding(), # type: ignore[attr-defined]
342
+ param.dilation(), # type: ignore[attr-defined]
343
+ param.groups(), # type: ignore[attr-defined]
344
+ ),
345
+ )
346
+ return op_res_node, scale_node, zero_point_node
347
+
348
+
349
+ def _transform_linear_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node):
350
+ """Linear specfic transformation function."""
351
+ scale_node, zero_point_node = node.args[2], node.args[3]
352
+
353
+ inp_node, param_node = node.args[0], node.args[1]
354
+ assert isinstance(inp_node, torch.fx.Node)
355
+ assert isinstance(param_node, torch.fx.Node)
356
+
357
+ if param_node.op == "call_function":
358
+ # Using LinearPrepackParam from linear_prepack.
359
+ # We directly skip the packing call and inline weights and bias.
360
+ w_node, b_node = param_node.args[0], param_node.args[1]
361
+ assert isinstance(w_node, torch.fx.Node)
362
+ assert b_node is None or isinstance(b_node, torch.fx.Node)
363
+ (
364
+ param_0,
365
+ param_1,
366
+ ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
367
+ gm, w_node, b_node
368
+ )
369
+ op_res_node = gm.graph.call_function(
370
+ torch.ops.aten.linear, (inp_node, param_0, param_1, *param_node.args[2:])
371
+ )
372
+ else:
373
+ # Using LinearPackedParams.
374
+ (
375
+ param_0,
376
+ param_1,
377
+ ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
378
+ gm, param_node
379
+ ) # type: ignore[assignment]
380
+ op_res_node = gm.graph.call_function(
381
+ torch.ops.aten.linear, (inp_node, param_0, param_1)
382
+ )
383
+ return op_res_node, scale_node, zero_point_node
384
+
385
+
386
+ def _transform_op_where_last_two_arguments_are_scale_and_zero_point(
387
+ gm: torch.fx.GraphModule, node: torch.fx.Node
388
+ ):
389
+ """
390
+ This transformation function can be used for function where the last two
391
+ parameters are scale and zero point. Additionally, the function's parameters
392
+ do not need any unpacking.
393
+ """
394
+ to_standard_op = {
395
+ "mul": torch.ops.aten.mul,
396
+ "mul_relu": torch.ops.aten.mul,
397
+ "add": torch.ops.aten.add,
398
+ "add_relu": torch.ops.aten.add,
399
+ "softmax": torch.ops.aten.softmax,
400
+ "cat": torch.ops.aten.cat,
401
+ "hardswish": torch.ops.aten.hardswish,
402
+ }
403
+
404
+ assert isinstance(node.target, torch._ops.OpOverload)
405
+ opname, args = node.target._opname, node.args
406
+ scale_node, zero_point_node = args[-2], args[-1]
407
+ op_res_node = gm.graph.call_function(to_standard_op[opname], tuple(args[:-2]))
408
+ return op_res_node, scale_node, zero_point_node
409
+
410
+
411
+ def _transform_scalar_arithmetic(gm: torch.fx.GraphModule, node: torch.fx.Node):
412
+ """Transform scalar overload for basic arithmetic."""
413
+ to_standard_op = {
414
+ "mul": torch.ops.aten.mul.Scalar,
415
+ "add": torch.ops.aten.add.Scalar,
416
+ }
417
+ assert isinstance(node.target, torch._ops.OpOverload)
418
+ opname, args = node.target._opname, node.args
419
+ op_res_node = gm.graph.call_function(to_standard_op[opname], args)
420
+ return op_res_node, _SCALE, _ZERO_POINT
421
+
422
+
423
+ def _transform_prepacked_op(gm: torch.fx.GraphModule, node: torch.fx.Node):
424
+ """
425
+ Transformation for functions under prepacked namespace, where they share
426
+ the same handling logic that [...]OpContext contains all parameters.
427
+ """
428
+ assert isinstance(node.target, torch._ops.OpOverload)
429
+ opname, args = node.target._opname, node.args
430
+ op_f = None
431
+ if opname == "conv2d_clamp_run":
432
+ op_f = torch.ops.aten.conv2d
433
+ elif opname == "linear_clamp_run":
434
+ op_f = torch.ops.aten.linear
435
+ else:
436
+ raise RuntimeError(f"Invalid operator {opname}")
437
+
438
+ assert isinstance(args[1], torch.fx.Node)
439
+ so = get_script_object(gm, args[1])
440
+
441
+ func_args = []
442
+ func_args += [args[0]]
443
+ func_args += so.unpack()[:2] # type: ignore[attr-defined]
444
+ if opname == "conv2d_clamp_run":
445
+ func_args += torch.ops.prepacked.unpack_prepacked_sizes_conv2d(so)[2:]
446
+
447
+ op_res_node = gm.graph.call_function(op_f, tuple(func_args))
448
+ return op_res_node
449
+
450
+
451
+ def _transform_batch_norm(gm: torch.fx.GraphModule, node: torch.fx.Node):
452
+ args = node.args
453
+ scale_node, zero_point_node = args[-2], args[-1]
454
+ op_res_node = gm.graph.call_function(
455
+ torch.ops.aten.native_batch_norm, (*args[:-3], False, 0.1, args[-3])
456
+ )
457
+ op_res_node = gm.graph.call_function(operator.getitem, (op_res_node, 0))
458
+ return op_res_node, scale_node, zero_point_node
459
+
460
+
461
+ def fx_transform_quantized_op_to_standard_op(
462
+ gm: torch.fx.GraphModule, node: torch.fx.Node
463
+ ) -> torch.fx.Node:
464
+ global _SCALE, _ZERO_POINT, _INPUT_Q_DTYPE
465
+
466
+ assert isinstance(node.target, torch._ops.OpOverload)
467
+ opname, overload = node.target._opname, node.target._overloadname
468
+
469
+ key = f"{opname}.{overload}"
470
+ opname_to_transform_f = {
471
+ "conv1d.new": _transform_conv_with_packedparam,
472
+ "conv1d_relu.new": _transform_conv_with_packedparam,
473
+ "conv1d.default": _transform_conv_with_packedparam,
474
+ "conv1d_relu.default": _transform_conv_with_packedparam,
475
+ "conv2d.new": _transform_conv_with_packedparam,
476
+ "conv2d_relu.new": _transform_conv_with_packedparam,
477
+ "conv2d.default": _transform_conv_with_packedparam,
478
+ "conv2d_relu.default": _transform_conv_with_packedparam,
479
+ "linear.default": _transform_linear_with_packedparam,
480
+ "linear_relu.default": _transform_linear_with_packedparam,
481
+ "add.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
482
+ "add_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
483
+ "mul.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
484
+ "mul_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
485
+ "softmax.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
486
+ "cat.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
487
+ "hardswish.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
488
+ "batch_norm2d.default": _transform_batch_norm,
489
+ "mul.Scalar": _transform_scalar_arithmetic,
490
+ "add.Scalar": _transform_scalar_arithmetic,
491
+ }
492
+
493
+ if f"{key}" not in opname_to_transform_f:
494
+ raise RuntimeError(f"Unsupported quantized op during transformation: {key}")
495
+
496
+ op_res_node, scale_node, zero_point_node = opname_to_transform_f[f"{key}"](gm, node)
497
+
498
+ # Add fused activation layer.
499
+ op_res_node = insert_fused_activation_node(gm, opname, op_res_node)
500
+ _SCALE, _ZERO_POINT = scale_node, zero_point_node
501
+
502
+ assert _INPUT_Q_DTYPE is not None
503
+ qmin_node, qmax_node = insert_qmin_qmax_node(gm, _INPUT_Q_DTYPE)
504
+ q_fx_node = insert_quantized_node(
505
+ gm,
506
+ op_res_node,
507
+ scale_node,
508
+ zero_point_node,
509
+ qmin_node,
510
+ qmax_node,
511
+ _INPUT_Q_DTYPE,
512
+ torch.per_tensor_affine,
513
+ )
514
+ dq_fx_node = insert_dequantized_node(
515
+ gm,
516
+ q_fx_node,
517
+ scale_node,
518
+ zero_point_node,
519
+ qmin_node,
520
+ qmax_node,
521
+ _INPUT_Q_DTYPE,
522
+ None,
523
+ torch.per_tensor_affine,
524
+ )
525
+ return dq_fx_node
526
+
527
+
528
+ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
529
+ """
530
+ Replace legacy quantized ops (aten.quantize_per_tensor, quantized.conv) with
531
+ PT2 ops (quantize_decomposed.quantize_per_tensor, aten.conv).
532
+
533
+ Before: x || -> aten.q || -> quantized.conv2d || -> quantized.linear || -> aten.dq || -> y
534
+
535
+ After: x || -> qd.q -> qd.dq || -> aten.conv2d -> qd.q -> qd.dq || aten.linear -> qd.q -> qd.dq || -> y
536
+
537
+ (qd == quantized_decomposed library, q = quantize, dq = dequantize)
538
+ ^
539
+ |
540
+ getattr(w), getattr(b) from Conv2dParamPrepack
541
+
542
+ During each iteration, the transformation spits out the transformed operator, its quantized output,
543
+ and its dequantized value together. We did this because dequantization need to use the
544
+ scale and zero point parameters from the quantization to recover the approximate original value. After each
545
+ iteration, the new dequantization node will be used as the input to the next node (e.g., dq2 -> linear).
546
+
547
+ For operators like conv2d and linear, their weights and bias are packed in a quantized format in the ScriptObject.
548
+ During the transformation, we unpack those objects, get their dequantized tensor, populate those
549
+ as attributes to the module, and use getattr to access them.
550
+
551
+ One exception in the transformation is conv_prepack and linear_prepack. Those calls pack
552
+ weight and bias constant tensors into ScriptObject, which are then used by subsequent conv2d or linear calls.
553
+ During transformation, we directly skip transforming conv_prepack or linear_prepack. We check whether ScriptObject to the
554
+ quantized::conv2d or linear is from conv_prepack or linear_prepack. If it is, we then inline those parameters
555
+ to the operator by converting them to a getattr fx.node.
556
+
557
+ For prepacked::conv2d_clamp_run and prepacked::linear_clamp_run, we directly convert them to aten.conv2d and aten.linear
558
+ without the need of doing de/quantization.
559
+
560
+ Three global variables defined are _INPUT_Q_DTYPE, _SCALE, _ZERO_POINT. _INPUT_Q_DTYPE determines the de/quantization
561
+ data type, which is the same across the entire program, but it only shows up in the very first quantization
562
+ call. _SCALE and _ZERO_POINT are used only when operators do not have those specified. E.g., mul.Scalar.
563
+ """
564
+
565
+ global _INPUT_Q_DTYPE
566
+
567
+ quantized = False
568
+
569
+ last_quantized_node = None
570
+ for node in gm.graph.nodes:
571
+ if isinstance(node.target, OpOverload):
572
+ with gm.graph.inserting_before(node):
573
+ namespace, opname = node.target.namespace, node.target._opname
574
+ if namespace == "quantized" and opname not in [
575
+ "conv_prepack",
576
+ "linear_prepack",
577
+ ]:
578
+ quantized = True
579
+ fx_node = fx_transform_quantized_op_to_standard_op(gm, node)
580
+ node.replace_all_uses_with(fx_node)
581
+ last_quantized_node = fx_node
582
+ elif namespace == "prepacked":
583
+ quantized = True
584
+ fx_node = _transform_prepacked_op(gm, node)
585
+ node.replace_all_uses_with(fx_node)
586
+ last_quantized_node = fx_node
587
+ elif namespace == "aten" and opname == "quantize_per_tensor":
588
+ inp_node, scale_node, zero_point_node, dtype_node = node.args
589
+ dtype_node = fx_enum_to_dtype(gm, dtype_node)
590
+ _INPUT_Q_DTYPE = dtype_node
591
+ qmin_node, qmax_node = insert_qmin_qmax_node(gm, dtype_node)
592
+ q_fx_node = insert_quantized_node(
593
+ gm,
594
+ inp_node,
595
+ scale_node,
596
+ zero_point_node,
597
+ qmin_node,
598
+ qmax_node,
599
+ dtype_node,
600
+ torch.per_tensor_affine,
601
+ )
602
+ dq_fx_node = insert_dequantized_node(
603
+ gm,
604
+ q_fx_node,
605
+ scale_node,
606
+ zero_point_node,
607
+ qmin_node,
608
+ qmax_node,
609
+ dtype_node,
610
+ None,
611
+ torch.per_tensor_affine,
612
+ )
613
+ node.replace_all_uses_with(dq_fx_node)
614
+ last_quantized_node = dq_fx_node
615
+ elif namespace == "aten" and opname == "dequantize":
616
+ assert last_quantized_node is not None
617
+ node.replace_all_uses_with(last_quantized_node)
618
+ else:
619
+ last_quantized_node = node
620
+
621
+ # Post-processing again to remove legacy ScriptObjects and quantizated tensors
622
+ # stored as attributes or in the buffer. This is used to clean up the GraphModule
623
+ # to not trigger tracing errors like missing __obj_flatten__ functions.
624
+ def _clean_attr(mod: torch.nn.Module):
625
+ for submod in mod.modules():
626
+ attr_names_to_clean = set()
627
+ for k, v in submod.__dict__.items():
628
+ if isinstance(v, torch.ScriptObject):
629
+ attr_names_to_clean.add(k)
630
+ if k == "_buffers":
631
+ buffer_name_to_clean = set()
632
+ for b_name, b_value in v.items():
633
+ if isinstance(b_value, torch.Tensor) and b_value.dtype in [
634
+ torch.qint8,
635
+ torch.quint8,
636
+ ]:
637
+ buffer_name_to_clean.add(b_name)
638
+ for b_name in buffer_name_to_clean:
639
+ v.pop(b_name, None)
640
+ for attr_name in attr_names_to_clean:
641
+ delattr(submod, attr_name)
642
+
643
+ if quantized:
644
+ """
645
+ TODO: SetAttr + quantized ops will result incorrect program. This flag is used to temporarily
646
+ bypass test cases.
647
+
648
+ The deadcode elimination pass is needed to remove legacy quantized ops. Otherwise, retracing
649
+ will throw errors. However, the current way of SetAttr does inplace update to attributes, so
650
+ this pass regard them as dead code and remove them. Below is an example of GraphModule before
651
+ and after the dead code elimination pass.
652
+
653
+ class GraphModule(torch.nn.Module):
654
+ def forward(self, x_1):
655
+ # No stacktrace found for following nodes
656
+ data = self.data; data = None
657
+ data_1 = self.data
658
+ add_tensor = torch.ops.aten.add.Tensor(data_1, x_1, alpha = 1); data_1 = None
659
+ data_2 = self.data
660
+ copy_ = torch_Tensor_copy_(data_2, add_tensor); data_2 = add_tensor = copy_ = None
661
+ data_3 = self.data
662
+ add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None
663
+ return add_tensor_1
664
+
665
+ class GraphModule(torch.nn.Module):
666
+ def forward(self, x_1):
667
+ # No stacktrace found for following nodes
668
+ data_3 = self.data
669
+ add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None
670
+ return add_tensor_1
671
+ """
672
+ gm.graph.eliminate_dead_code()
673
+ _clean_attr(gm)
.venv/lib/python3.11/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+
3
+ import torch
4
+ from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled
5
+
6
+ from ..utils import node_inline_, nodes_filter, nodes_first, nodes_map, sequential_split
7
+ from .replace_with_hop_pass_util import (
8
+ _replace_with_hop_helper,
9
+ _replace_with_hop_pass_helper,
10
+ _sequential_split_and_maybe_inline_subgraphs_helper,
11
+ )
12
+
13
+
14
+ def _is_set_grad_enabled_node(node: torch.fx.Node):
15
+ return (
16
+ node
17
+ and node.op == "call_function"
18
+ and node.target == torch._C._set_grad_enabled
19
+ )
20
+
21
+
22
+ def _is_set_grad_enabled_sub_mod(node: torch.fx.Node, omit_if_same_with_ambient=False):
23
+ if node.op == "call_module":
24
+ assert isinstance(node.target, str)
25
+ subgm = getattr(node.graph.owning_module, node.target)
26
+ first_non_ph = nodes_first(
27
+ subgm.graph.nodes, lambda node: node.op != "placeholder"
28
+ )
29
+ if (
30
+ first_non_ph
31
+ and first_non_ph.op == "call_function"
32
+ and first_non_ph.target == torch._C._set_grad_enabled
33
+ ):
34
+ return (
35
+ first_non_ph.args[0] != torch.is_grad_enabled()
36
+ if omit_if_same_with_ambient
37
+ else True
38
+ )
39
+ return False
40
+
41
+
42
+ def _replace_with_hop(node: torch.fx.Node):
43
+ assert node.op == "call_module"
44
+ graph: torch.fx.Graph = node.graph
45
+ gm: torch.fx.GraphModule = graph.owning_module
46
+ assert isinstance(node.target, str)
47
+ sub_gm = getattr(gm, node.target)
48
+ sub_graph = sub_gm.graph
49
+ set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node)
50
+ if len(set_grad_nodes) > 0:
51
+ assert len(set_grad_nodes) == 1
52
+ set_grad_node = set_grad_nodes[0]
53
+ _replace_with_hop_helper(
54
+ node, set_grad_node, _is_set_grad_enabled_node, wrap_with_set_grad_enabled
55
+ )
56
+ sub_graph.erase_node(set_grad_node)
57
+
58
+
59
+ def _remove_set_grad_and_inline(node: torch.fx.Node):
60
+ assert node.op == "call_module"
61
+ graph: torch.fx.Graph = node.graph
62
+ gm: torch.fx.GraphModule = graph.owning_module
63
+ assert isinstance(node.target, str)
64
+ sub_gm = getattr(gm, node.target)
65
+ sub_graph = sub_gm.graph
66
+ nodes_map(
67
+ sub_graph.nodes,
68
+ lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n,
69
+ )
70
+ node_inline_(node)
71
+
72
+
73
+ def _sequential_split_and_maybe_inline_subgraphs(
74
+ gm: torch.fx.GraphModule, graph_signature
75
+ ):
76
+ """
77
+ Helper function for replace_set_grad_with_hop_pass().
78
+ Split the graph module into multiple subgraphs based on the set_grad_enabled nodes.
79
+ For each subgraph, decides whether to construct a HOO subgraph, or inline the calls
80
+ back into the parent graph module.
81
+ """
82
+ need_replacing = any(_is_set_grad_enabled_node(node) for node in gm.graph.nodes)
83
+ if not need_replacing:
84
+ return gm, graph_signature
85
+
86
+ # sequential_split returns a new graph module that could have different output
87
+ # args names. We need to fix the graph signature.
88
+ new_gm = sequential_split(gm, _is_set_grad_enabled_node)
89
+
90
+ def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
91
+ if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
92
+ _replace_with_hop(node)
93
+ else:
94
+ _remove_set_grad_and_inline(node)
95
+
96
+ return _sequential_split_and_maybe_inline_subgraphs_helper(
97
+ new_gm, graph_signature, _maybe_inline_or_replace_with_hop
98
+ )
99
+
100
+
101
+ def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
102
+ """
103
+ Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
104
+ then recursively call itself on each of the submodules.
105
+ """
106
+ return _replace_with_hop_pass_helper(
107
+ gm,
108
+ graph_signature,
109
+ _sequential_split_and_maybe_inline_subgraphs,
110
+ )
.venv/lib/python3.11/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from typing import Dict, Optional
3
+ import torch
4
+ from torch._ops import OpOverload, HigherOrderOperator
5
+ from torch._export.error import InternalError
6
+ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
7
+
8
+
9
+ __all__ = ["ReplaceViewOpsWithViewCopyOpsPass"]
10
+
11
+
12
+ _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: Dict[OpOverload, OpOverload] = {
13
+ torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
14
+ }
15
+
16
+
17
+ def is_view_op(schema: torch._C.FunctionSchema) -> bool:
18
+ if len(schema.arguments) == 0:
19
+ return False
20
+ alias_info = schema.arguments[0].alias_info
21
+ return (alias_info is not None) and (not alias_info.is_write)
22
+
23
+
24
+ def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]:
25
+ if is_view_op(schema) and schema.name.startswith("aten::"):
26
+ view_op_name = schema.name.split("::")[1]
27
+ view_op_overload = (
28
+ schema.overload_name
29
+ if schema.overload_name != ""
30
+ else "default"
31
+ )
32
+ view_copy_op_name = view_op_name + "_copy"
33
+ if not hasattr(torch.ops.aten, view_copy_op_name):
34
+ raise InternalError(f"{schema.name} is missing a view_copy variant")
35
+
36
+ view_copy_op_overload_packet = getattr(torch.ops.aten, view_copy_op_name)
37
+
38
+ if not hasattr(view_copy_op_overload_packet, view_op_overload):
39
+ raise InternalError(f"{schema.name} is missing a view_copy variant")
40
+
41
+ return getattr(view_copy_op_overload_packet, view_op_overload)
42
+
43
+ return None
44
+
45
+
46
+ class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse):
47
+ """
48
+ Our backend expects pure functional operators. For efficiency
49
+ purposes, we keep view ops around while functionalizing the exported
50
+ program. This pass replaces view ops with view copy ops for backends that
51
+ need AOT memory planning.
52
+ """
53
+ def call_operator(self, op, args, kwargs, meta):
54
+ if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
55
+ return super().call_operator(
56
+ (_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]), args, kwargs, meta
57
+ )
58
+
59
+ if isinstance(op, HigherOrderOperator):
60
+ return super().call_operator(op, args, kwargs, meta)
61
+
62
+ if view_copy_op := get_view_copy_of_view_op(op._schema):
63
+ return super().call_operator(view_copy_op, args, kwargs, meta)
64
+
65
+ return super().call_operator(op, args, kwargs, meta)
.venv/lib/python3.11/site-packages/torch/_export/passes/replace_with_hop_pass_util.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+
3
+ import contextlib
4
+ import copy
5
+ import operator
6
+ from typing import Callable
7
+
8
+ import torch
9
+ from torch._ops import HigherOrderOperator
10
+
11
+ from ..utils import node_replace_, nodes_map
12
+
13
+
14
def _replace_with_hop_helper(
    node: torch.fx.Node,
    enter_block_node: torch.fx.Node,
    node_filter: Callable,
    wrap_hoo: HigherOrderOperator,
):
    """Replace a ``call_module`` node with a call to the higher-order op ``wrap_hoo``.

    ``node`` must be a call_module node whose target names a child GraphModule of
    the owning module. The child module is passed to ``wrap_hoo`` via a fresh
    get_attr node, together with ``enter_block_node.args`` and ``node``'s own
    args. Metadata (nn_module_stack, torch_fn, val) is copied from
    ``enter_block_node`` / the subgraph outputs onto the new nodes so downstream
    passes and the verifier still see consistent meta.

    NOTE(review): ``node_filter`` is accepted but never used in this body —
    confirm whether callers rely on it or it can be dropped.
    """
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)

    # Closure over `output_args` (assigned below, before any call): stamps the
    # new HOO call node with module-stack/torch_fn meta and a "val" tuple that
    # mirrors the subgraph's output values.
    def set_hoo_node_meta(call_func_node):
        call_func_node.meta["nn_module_stack"] = copy.copy(
            enter_block_node.meta.get("nn_module_stack", {})
        )
        call_func_node.meta["torch_fn"] = (
            f"{wrap_hoo.__name__}",
            f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}",
        )
        if isinstance(output_args, (tuple, list)):
            call_func_node.meta["val"] = tuple(arg.meta["val"] for arg in output_args)
        elif isinstance(output_args, torch.fx.Node):
            call_func_node.meta["val"] = (output_args.meta["val"],)

    with graph.inserting_before(node):
        get_attr_node = graph.get_attr(node.target)
        get_attr_node.meta["nn_module_stack"] = copy.copy(
            enter_block_node.meta.get("nn_module_stack", {})
        )
        # The output node, if present, is the last node of the subgraph.
        output_node = next(iter(reversed(sub_gm.graph.nodes)), None)
        # Split_module pass intentionally doesn't add an output node
        # if the graph doesn't return anything.
        # TODO (tmanlaibaatar) Figure out if this is right behaviour
        # for split_module
        if isinstance(output_node, torch.fx.Node) and output_node.op != "output":
            output_node = None
        if output_node is not None:
            assert len(output_node.args) == 1
            output_args = output_node.args[0]
            enter_block_node_args = enter_block_node.args
            if isinstance(output_args, (tuple, list)):
                # Multi-output subgraph: the HOO call returns a tuple, and the
                # existing getitem users of `node` are retargeted below.
                call_func_node = graph.call_function(
                    wrap_hoo,
                    (*enter_block_node_args, get_attr_node, *node.args),
                    {},
                )
                # Create the metadata
                set_hoo_node_meta(call_func_node)
                node_replace_(node, call_func_node)

                # Rename the name of getitem nodes to the actual name of its contents
                # for passing verifier and better readability, also propagate metadata
                for get_item_node in call_func_node.users.keys():
                    idx: int = get_item_node.args[1]  # type: ignore[assignment]
                    output_node = output_args[idx]
                    get_item_node._rename(output_node.name)
                    get_item_node.meta = output_node.meta

            elif isinstance(output_args, torch.fx.Node):
                # Single-output subgraph: normalize to a 1-tuple output and
                # insert a getitem so callers see the same value as before.
                call_func_node = graph.create_node(
                    "call_function",
                    wrap_hoo,
                    (*enter_block_node_args, get_attr_node, *node.args),
                    {},
                    output_args.name,
                )
                # Modify the subgraph to output a singleton list.
                output_node.args = ((output_args,),)
                # Add in an extra `getitem(wrap_hoo, 0)` node to the toplevel graph.
                get_item_node = graph.create_node(
                    "call_function",
                    operator.getitem,
                    (call_func_node, 0),
                    {},
                )
                # Create the metadata
                get_item_node.meta = output_args.meta
                set_hoo_node_meta(call_func_node)
                node_replace_(node, get_item_node)
            else:
                # NOTE(review): message below has typos ("repalce", "doesnt'")
                # — left as-is here; fixing it changes a runtime string.
                raise NotImplementedError(
                    f"repalce_with_hop_pass doesnt' support output type {type(output_args)}"
                )
        else:
            # Subgraph returns nothing: drop the call_module node entirely.
            # TODO (shangdiy): remove this line, since the export graph can be non-functional
            node.graph.erase_node(node)
100
+
101
+
102
def _sequential_split_and_maybe_inline_subgraphs_helper(
    new_gm: torch.fx.GraphModule,
    graph_signature,
    maybe_inline_or_replace_with_hop: Callable[[torch.fx.Node], None],
):
    """
    Helper function for replacing graph nodes with higher order nodes.
    For each subgraph in `new_gm`, decides whether to construct a HOO subgraph, or inline the calls
    back into the parent graph module, depending on `maybe_inline_or_replace_with_hop`.

    Returns the (mutated) `new_gm` and a shallow-copied signature whose output
    spec names were re-synced with the graph's output node, or None when no
    `graph_signature` was given.
    """
    # new_gm is a new graph module that could have different output args names.
    # We need to fix the graph signature.
    replace_ctx = contextlib.nullcontext()
    new_signature = None
    if graph_signature is not None:
        # Cannot deep copy a real ScriptObject, which is referenced
        # in the FakeScriptObject. Copy should be good enough to guard
        # against accidental mutation to original graph_signature.
        new_signature = copy.copy(graph_signature)
        # Last output node of the graph; asserted to match the spec count 1:1.
        new_gm_out_node = next(reversed(new_gm.graph.find_nodes(op="output")))
        assert new_gm_out_node.op == "output" and len(new_gm_out_node.args[0]) == len(
            new_signature.output_specs
        )
        for arg_node, out_spec in zip(
            new_gm_out_node.args[0], new_signature.output_specs
        ):
            if arg_node is None:
                assert out_spec.arg.value is None
            elif (
                isinstance(arg_node, torch.fx.Node)
                and out_spec.arg.name != arg_node.name
            ):
                # Re-sync spec names with the (possibly renamed) output nodes.
                out_spec.arg.name = arg_node.name

        # Keep the signature updated while nodes are replaced below.
        replace_ctx = new_gm._set_replace_hook(new_signature.get_replace_hook())  # type: ignore[assignment]

    with replace_ctx:
        # Only call_module nodes are candidates for inlining / HOO replacement;
        # all other nodes pass through unchanged.
        nodes_map(
            list(new_gm.graph.nodes),
            lambda node: (
                maybe_inline_or_replace_with_hop(node)
                if node.op == "call_module"
                else node
            ),
        )
    new_gm.recompile()
    return new_gm, new_signature
149
+
150
+
151
+ def _replace_with_hop_pass_helper(
152
+ gm: torch.fx.GraphModule,
153
+ graph_signature,
154
+ sequential_split_and_maybe_inline_subgraphs: Callable,
155
+ ):
156
+ """
157
+ Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
158
+ then recursively call itself on each of the submodules.
159
+ """
160
+ new_gm, new_signature = sequential_split_and_maybe_inline_subgraphs(
161
+ gm, graph_signature
162
+ )
163
+ # recursively call
164
+ for node in new_gm.graph.nodes:
165
+ if node.op == "get_attr":
166
+ subgm = getattr(new_gm, node.target)
167
+ if not isinstance(subgm, torch.fx.GraphModule):
168
+ continue
169
+ new_subgm, _ = _replace_with_hop_pass_helper(
170
+ subgm,
171
+ None,
172
+ sequential_split_and_maybe_inline_subgraphs,
173
+ )
174
+ setattr(new_gm, node.target, new_subgm)
175
+
176
+ new_gm.recompile()
177
+ new_gm.graph.lint()
178
+ return new_gm, new_signature