koichi12 commited on
Commit
1bc1bad
·
verified ·
1 Parent(s): 7469295

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .venv/lib/python3.11/site-packages/torch/_export/error.py +56 -0
  3. .venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py +0 -0
  4. .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/aoti_schema.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc +3 -0
  9. .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torch/_export/serde/aoti_schema.py +15 -0
  11. .venv/lib/python3.11/site-packages/torch/_export/serde/dynamic_shapes.py +321 -0
  12. .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/closure.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc +0 -0
  15. .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc +0 -0
  16. .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torch/fx/__init__.py +89 -0
  22. .venv/lib/python3.11/site-packages/torch/fx/__init__.pyi +15 -0
  23. .venv/lib/python3.11/site-packages/torch/fx/_compatibility.py +36 -0
  24. .venv/lib/python3.11/site-packages/torch/fx/_lazy_graph_module.py +185 -0
  25. .venv/lib/python3.11/site-packages/torch/fx/_pytree.py +103 -0
  26. .venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py +1290 -0
  27. .venv/lib/python3.11/site-packages/torch/fx/_utils.py +63 -0
  28. .venv/lib/python3.11/site-packages/torch/fx/annotate.py +32 -0
  29. .venv/lib/python3.11/site-packages/torch/fx/config.py +6 -0
  30. .venv/lib/python3.11/site-packages/torch/fx/experimental/__init__.py +0 -0
  31. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_config.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/debug.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/torch/fx/experimental/_backward_state.py +27 -0
  50. .venv/lib/python3.11/site-packages/torch/fx/experimental/_config.py +88 -0
.gitattributes CHANGED
@@ -126,3 +126,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
126
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
127
  .venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
128
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
 
 
126
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
127
  .venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
128
  .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_ops.so.9 filter=lfs diff=lfs merge=lfs -text
129
+ .venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/torch/_export/error.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+
4
class ExportErrorType(Enum):
    """Categories of failures that are attributed to the user during export."""

    # Invalid inputs were handed to the tracer or to another public-facing API.
    INVALID_INPUT_TYPE = 1

    # The user's model returned values of a kind that is not supported.
    INVALID_OUTPUT_TYPE = 2

    # The generated IR does not conform to the Export IR Specification.
    VIOLATION_OF_SPEC = 3

    # The user's code relies on types or functionality that is not supported.
    NOT_SUPPORTED = 4

    # The user's code lacks details needed to trace and export successfully,
    # e.g. a decorator/annotation the user is asked to put on their model.
    MISSING_PROPERTY = 5

    # An API was used before its required initialization step.
    UNINITIALIZED = 6
23
+
24
+
25
def internal_assert(pred: bool, assert_msg: str) -> None:
    """
    exir's replacement for the built-in ``assert`` statement.

    Does nothing when ``pred`` is truthy; otherwise raises ``InternalError``
    carrying ``assert_msg``. It exists only so that our own error type is
    thrown while call sites keep an assert-like shape.
    """
    if pred:
        return
    raise InternalError(assert_msg)
34
+
35
+
36
class InternalError(Exception):
    """
    Signals that an internal invariant of the EXIR stack was violated.

    The message should prompt users to report a bug to the developers and
    should expose the original error message.
    """

    def __init__(self, message: str) -> None:
        # Exactly one message argument is accepted and forwarded unchanged.
        super().__init__(message)
45
+
46
+
47
+ class ExportError(Exception):
48
+ """
49
+ This type of exception is raised for errors that are directly caused by the user
50
+ code. In general, user errors happen during model authoring, tracing, using our public
51
+ facing APIs, and writing graph passes.
52
+ """
53
+
54
+ def __init__(self, error_code: ExportErrorType, message: str) -> None:
55
+ prefix = f"[{error_code}]: "
56
+ super().__init__(prefix + message)
.venv/lib/python3.11/site-packages/torch/_export/serde/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/aoti_schema.cpython-311.pyc ADDED
Binary file (1.02 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/dynamic_shapes.cpython-311.pyc ADDED
Binary file (15.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema.cpython-311.pyc ADDED
Binary file (17.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/schema_check.cpython-311.pyc ADDED
Binary file (16.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/serialize.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc131857ed1d25d734bce65ed9c8acad8c38ffb2614c7fcf51f2cbfebac196a1
3
+ size 164473
.venv/lib/python3.11/site-packages/torch/_export/serde/__pycache__/union.cpython-311.pyc ADDED
Binary file (5.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_export/serde/aoti_schema.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+
4
+ from torch._export.serde.schema import Node
5
+
6
+
7
+ @dataclass
8
+ class ExternKernelNode:
9
+ name: str
10
+ node: Node
11
+
12
+
13
@dataclass
class ExternKernelNodes:
    """Container holding a list of ``ExternKernelNode`` entries."""

    # The wrapped entries, in list order.
    nodes: List[ExternKernelNode]
.venv/lib/python3.11/site-packages/torch/_export/serde/dynamic_shapes.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch._dynamo.exc import UserError, UserErrorType
6
+ from torch.export.dynamic_shapes import (
7
+ _check_dynamic_shapes,
8
+ _DerivedDim,
9
+ _Dim,
10
+ _DimHint,
11
+ _tree_map_with_path,
12
+ Dim,
13
+ )
14
+ from torch.utils._pytree import tree_map
15
+
16
+ from .serialize import _dataclass_to_dict
17
+
18
+
19
@dataclasses.dataclass
class RootDim:
    """
    Serializable description of a _Dim object.
    """

    # Lower bound of the dim's range.
    min: int
    # Upper bound of the dim's range; None is used when there is no int bound.
    max: Optional[int]
    # String forms of the dims derived from this root dim.
    derived: List[str]
28
+
29
+
30
+ @dataclasses.dataclass
31
+ class DynamicShapesSpec:
32
+ """
33
+ This stores a dynamic_shapes spec for de/serialization.
34
+ """
35
+
36
+ dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None]
37
+ dims: Dict[str, RootDim]
38
+
39
+
40
+ def _postprocess_serialized_shapes(
41
+ dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
42
+ dims: Dict[str, Dict[str, Union[int, List[str], None]]],
43
+ to_dict: Optional[bool] = False,
44
+ ) -> Union[DynamicShapesSpec, Dict[str, Any]]:
45
+ """
46
+ Sorts dims and dumps to dictionary format.
47
+ """
48
+ from torch.utils._sympy.numbers import int_oo
49
+
50
+ dims = {
51
+ k: RootDim(
52
+ min=v["min"], # type: ignore[arg-type]
53
+ max=None if v["max"] is int_oo else v["max"], # type: ignore[arg-type]
54
+ derived=sorted(v["derived"]), # type: ignore[arg-type]
55
+ )
56
+ for k, v in sorted(dims.items())
57
+ }
58
+ spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims)
59
+ if to_dict:
60
+ return _dataclass_to_dict(spec)
61
+ else:
62
+ return spec
63
+
64
+
65
+ def _dump_dynamic_shapes(
66
+ dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
67
+ args: Tuple[Any],
68
+ kwargs: Optional[Dict[str, Any]] = None,
69
+ to_dict: Optional[bool] = False,
70
+ ) -> Union[DynamicShapesSpec, Dict[str, Any]]:
71
+ """
72
+ Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec.
73
+ Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims".
74
+ Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones).
75
+
76
+ dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export():
77
+ - Each tensor input is represented with a list of values, non-tensor inputs with None.
78
+ - dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings.
79
+ - static dimensions are represented with ints.
80
+
81
+ dims: A dictionary mapping each symbol name to the min/max range and derived dim names.
82
+
83
+ For example:
84
+ ```
85
+ dx = Dim("dx", min=4, max=16)
86
+ dy = dx + 1
87
+
88
+ inputs = (
89
+ [
90
+ torch.randn(4, 4),
91
+ torch.randn(5, 4),
92
+ ],
93
+ torch.randn(4),
94
+ torch.randn(4, 4),
95
+ "hello",
96
+ )
97
+ dynamic_shapes = {
98
+ "a": [
99
+ (dx, 4),
100
+ (dy, 4),
101
+ ],
102
+ "b": (Dim.STATIC,),
103
+ "c": None,
104
+ "d": None,
105
+ }
106
+ out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True)
107
+ ```
108
+ would generate the following output:
109
+ ```
110
+ {
111
+ 'dynamic_shapes': (
112
+ [
113
+ ['dx', 4],
114
+ ['dx + 1', 4],
115
+ ],
116
+ ['_DimHint.STATIC'],
117
+ ['_DimHint.STATIC', '_DimHint.STATIC'],
118
+ None,
119
+ ),
120
+ 'dims': {
121
+ 'dx': {
122
+ 'min': 4,
123
+ 'max': 16,
124
+ 'derived': ['dx + 1'],
125
+ },
126
+ },
127
+ }
128
+ ```
129
+ """
130
+ dims: Dict[str, Dict[str, Any]] = {}
131
+
132
+ def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def]
133
+ """
134
+ Helps standardize the dynamic_shapes tree structure we serialize,
135
+ returning lists for each tensor shape, handling tensor-level Nones.
136
+ """
137
+ if not isinstance(tensor, torch.Tensor):
138
+ return None
139
+ if shape is None:
140
+ return [Dim.STATIC] * len(tensor.shape) # type: ignore[attr-defined]
141
+
142
+ out = []
143
+ if isinstance(shape, dict):
144
+ for i, s in enumerate(tensor.shape):
145
+ out.append(s if shape.get(i) is None else shape.get(i))
146
+ else:
147
+ assert isinstance(shape, (tuple, list))
148
+ for i, s in enumerate(tensor.shape):
149
+ out.append(s if shape[i] is None else shape[i])
150
+ return out
151
+
152
+ def _track_dim_from_dims(
153
+ val: Union[None, int, _DimHint, _Dim]
154
+ ) -> Union[None, int, str]:
155
+ """
156
+ Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
157
+ """
158
+ if val is None or isinstance(val, int): # non-tensor input or static
159
+ return val
160
+ if isinstance(val, _DimHint): # store enum as string
161
+ return val.__class__.__name__ + "." + val.name
162
+
163
+ assert isinstance(val, _Dim)
164
+
165
+ # track root dim
166
+ root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined]
167
+ if root.__name__ not in dims:
168
+ dims[root.__name__] = {
169
+ "min": root.min,
170
+ "max": root.max,
171
+ "derived": set(),
172
+ }
173
+
174
+ # track derived dims
175
+ if isinstance(val, _DerivedDim):
176
+ dims[root.__name__]["derived"].add(val.__name__)
177
+
178
+ return val.__name__
179
+
180
+ if dynamic_shapes is None:
181
+ return {"dynamic_shapes": None, "dims": {}}
182
+
183
+ # convert to tuple of specs, for each arg/kwarg
184
+ kwargs = kwargs or {}
185
+ if isinstance(dynamic_shapes, dict):
186
+ dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment]
187
+ dynamic_shapes = tuple(dynamic_shapes)
188
+ combined_args = tuple(args) + tuple(kwargs.values())
189
+
190
+ # run same check when we're processing shapes for export - is this too lazy?
191
+ _check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes) # type: ignore[arg-type]
192
+
193
+ tree_shapes = _tree_map_with_path(
194
+ _standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs"
195
+ )
196
+ serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes)
197
+ return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict)
198
+
199
+
200
+ def _load_dynamic_shapes(
201
+ spec: Union[DynamicShapesSpec, Dict[str, Any]],
202
+ from_dict: Optional[bool] = False,
203
+ ) -> Union[Dict[str, Any], Tuple[Any], List[Any], None]:
204
+ """
205
+ Utility function for dynamic shapes serialization.
206
+ Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export().
207
+ """
208
+ import sympy
209
+
210
+ from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence
211
+
212
+ if from_dict:
213
+ if not isinstance(spec, dict):
214
+ raise UserError(
215
+ UserErrorType.INVALID_INPUT,
216
+ f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}",
217
+ )
218
+ if sorted(spec.keys()) != ["dims", "dynamic_shapes"]:
219
+ raise UserError(
220
+ UserErrorType.INVALID_INPUT,
221
+ "With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, "
222
+ f"instead found {spec.keys()}",
223
+ )
224
+ dims = {}
225
+ for k, v in spec["dims"].items():
226
+ if not isinstance(k, str):
227
+ raise UserError(
228
+ UserErrorType.INVALID_INPUT,
229
+ f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}",
230
+ )
231
+ if sorted(v.keys()) != ["derived", "max", "min"]:
232
+ raise UserError(
233
+ UserErrorType.INVALID_INPUT,
234
+ f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, "
235
+ f"instead found {v.keys()}",
236
+ )
237
+ if not isinstance(v["min"], int):
238
+ raise UserError(
239
+ UserErrorType.INVALID_INPUT,
240
+ f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}",
241
+ )
242
+ if not isinstance(v["max"], int) or v["max"] is None:
243
+ raise UserError(
244
+ UserErrorType.INVALID_INPUT,
245
+ f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}",
246
+ )
247
+ if not isinstance(v["derived"], list) or any(
248
+ not isinstance(d, str) for d in v["derived"]
249
+ ):
250
+ raise UserError(
251
+ UserErrorType.INVALID_INPUT,
252
+ "Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, "
253
+ f"got {k}: {v['derived']}",
254
+ )
255
+ dims[k] = RootDim(**v)
256
+ dynamic_shapes = spec["dynamic_shapes"]
257
+ else:
258
+ if not isinstance(spec, DynamicShapesSpec):
259
+ raise UserError(
260
+ UserErrorType.INVALID_INPUT,
261
+ f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}",
262
+ )
263
+ dims = spec.dims
264
+ dynamic_shapes = spec.dynamic_shapes
265
+
266
+ if dynamic_shapes is None:
267
+ return None
268
+
269
+ dim_cache = {}
270
+ for name, info in dims.items():
271
+ symbol = sympy.sympify(name)
272
+ if not isinstance(symbol, sympy.Symbol):
273
+ raise UserError(
274
+ UserErrorType.INVALID_INPUT,
275
+ f"Expected `spec['dims']` keys to be symbols, got {name}",
276
+ )
277
+ dim_cache[name] = Dim(name, min=info.min, max=info.max) # cache root dim
278
+ for _expr in info.derived:
279
+ expr = sympy.sympify(_expr)
280
+ if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols:
281
+ raise UserError(
282
+ UserErrorType.INVALID_INPUT,
283
+ f"Expected derived expressions in to have {name} as the only free symbol, got {expr}",
284
+ )
285
+ if not _is_supported_equivalence(expr):
286
+ raise UserError(
287
+ UserErrorType.INVALID_INPUT,
288
+ f"Expected derived expressions to be linear expressions, got {expr}",
289
+ )
290
+ modulus, remainder = sympy.polys.polytools.div(expr, symbol)
291
+ ddim = dim_cache[name]
292
+ if modulus != 1:
293
+ ddim = int(modulus) * ddim
294
+ if remainder != 0:
295
+ ddim = ddim + int(remainder)
296
+ dim_cache[_expr] = ddim # cache derived dims
297
+
298
+ def deserialize_shape(
299
+ val: Union[None, int, str]
300
+ ) -> Union[None, int, _Dim, _DimHint]:
301
+ if val is None or isinstance(val, int):
302
+ return val
303
+ elif val == "_DimHint.AUTO":
304
+ return _DimHint.AUTO
305
+ elif val == "_DimHint.STATIC":
306
+ return _DimHint.STATIC
307
+ if not isinstance(val, str):
308
+ raise UserError(
309
+ UserErrorType.INVALID_INPUT,
310
+ "Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, "
311
+ f" or derived expressions, got {val}",
312
+ )
313
+ if val not in dim_cache:
314
+ raise UserError(
315
+ UserErrorType.INVALID_INPUT,
316
+ "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, "
317
+ f"got {val} which is not in {dims.keys()}",
318
+ )
319
+ return dim_cache[val]
320
+
321
+ return tree_map(deserialize_shape, dynamic_shapes)
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.21 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/closure.cpython-311.pyc ADDED
Binary file (8.08 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/computation.cpython-311.pyc ADDED
Binary file (1.57 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/config.cpython-311.pyc ADDED
Binary file (1.15 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/debug.cpython-311.pyc ADDED
Binary file (1.31 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/device_context.cpython-311.pyc ADDED
Binary file (1.66 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (1.39 kB). View file
 
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-311.pyc ADDED
Binary file (1.06 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/__init__.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r'''
2
+ FX is a toolkit for developers to use to transform ``nn.Module``
3
+ instances. FX consists of three main components: a **symbolic tracer,**
4
+ an **intermediate representation**, and **Python code generation**. A
5
+ demonstration of these components in action:
6
+
7
+ ::
8
+
9
+ import torch
10
+ # Simple module for demonstration
11
+ class MyModule(torch.nn.Module):
12
+ def __init__(self) -> None:
13
+ super().__init__()
14
+ self.param = torch.nn.Parameter(torch.rand(3, 4))
15
+ self.linear = torch.nn.Linear(4, 5)
16
+
17
+ def forward(self, x):
18
+ return self.linear(x + self.param).clamp(min=0.0, max=1.0)
19
+
20
+ module = MyModule()
21
+
22
+ from torch.fx import symbolic_trace
23
+ # Symbolic tracing frontend - captures the semantics of the module
24
+ symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
25
+
26
+ # High-level intermediate representation (IR) - Graph representation
27
+ print(symbolic_traced.graph)
28
+ """
29
+ graph():
30
+ %x : [num_users=1] = placeholder[target=x]
31
+ %param : [num_users=1] = get_attr[target=param]
32
+ %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
33
+ %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
34
+ %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
35
+ return clamp
36
+ """
37
+
38
+ # Code generation - valid Python code
39
+ print(symbolic_traced.code)
40
+ """
41
+ def forward(self, x):
42
+ param = self.param
43
+ add = x + param; x = param = None
44
+ linear = self.linear(add); add = None
45
+ clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
46
+ return clamp
47
+ """
48
+
49
+ The **symbolic tracer** performs "symbolic execution" of the Python
50
+ code. It feeds fake values, called Proxies, through the code. Operations
51
+ on these Proxies are recorded. More information about symbolic tracing
52
+ can be found in the :func:`symbolic_trace` and :class:`Tracer`
53
+ documentation.
54
+
55
+ The **intermediate representation** is the container for the operations
56
+ that were recorded during symbolic tracing. It consists of a list of
57
+ Nodes that represent function inputs, callsites (to functions, methods,
58
+ or :class:`torch.nn.Module` instances), and return values. More information
59
+ about the IR can be found in the documentation for :class:`Graph`. The
60
+ IR is the format on which transformations are applied.
61
+
62
+ **Python code generation** is what makes FX a Python-to-Python (or
63
+ Module-to-Module) transformation toolkit. For each Graph IR, we can
64
+ create valid Python code matching the Graph's semantics. This
65
+ functionality is wrapped up in :class:`GraphModule`, which is a
66
+ :class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a
67
+ ``forward`` method generated from the Graph.
68
+
69
+ Taken together, this pipeline of components (symbolic tracing ->
70
+ intermediate representation -> transforms -> Python code generation)
71
+ constitutes the Python-to-Python transformation pipeline of FX. In
72
+ addition, these components can be used separately. For example,
73
+ symbolic tracing can be used in isolation to capture a form of
74
+ the code for analysis (and not transformation) purposes. Code
75
+ generation can be used for programmatically generating models, for
76
+ example from a config file. There are many uses for FX!
77
+
78
+ Several example transformations can be found at the
79
+ `examples <https://github.com/pytorch/examples/tree/master/fx>`__
80
+ repository.
81
+ '''
82
+
83
+ from .graph_module import GraphModule
84
+ from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta
85
+ from .graph import Graph, CodeGen
86
+ from .node import Node, map_arg, has_side_effect
87
+ from .proxy import Proxy
88
+ from .interpreter import Interpreter as Interpreter, Transformer as Transformer
89
+ from .subgraph_rewriter import replace_pattern
.venv/lib/python3.11/site-packages/torch/fx/__init__.pyi ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.fx._symbolic_trace import (
2
+ symbolic_trace as symbolic_trace,
3
+ Tracer as Tracer,
4
+ wrap as wrap,
5
+ )
6
+ from torch.fx.graph import Graph as Graph
7
+ from torch.fx.graph_module import GraphModule as GraphModule
8
+ from torch.fx.interpreter import Interpreter as Interpreter, Transformer as Transformer
9
+ from torch.fx.node import (
10
+ has_side_effect as has_side_effect,
11
+ map_arg as map_arg,
12
+ Node as Node,
13
+ )
14
+ from torch.fx.proxy import Proxy as Proxy
15
+ from torch.fx.subgraph_rewriter import replace_pattern as replace_pattern
.venv/lib/python3.11/site-packages/torch/fx/_compatibility.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Callable, TypeVar
2
+ import textwrap
3
+
4
+ _BACK_COMPAT_OBJECTS : Dict[Any, None] = {}
5
+ _MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {}
6
+
7
+ _T = TypeVar("_T")
8
+
9
+ def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]:
10
+ if is_backward_compatible:
11
+
12
+ def mark_back_compat(fn: _T) -> _T:
13
+ docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
14
+ docstring += """
15
+ .. note::
16
+ Backwards-compatibility for this API is guaranteed.
17
+ """
18
+ fn.__doc__ = docstring
19
+ _BACK_COMPAT_OBJECTS.setdefault(fn)
20
+ _MARKED_WITH_COMPATIBILITY.setdefault(fn)
21
+ return fn
22
+
23
+ return mark_back_compat
24
+ else:
25
+
26
+ def mark_not_back_compat(fn: _T) -> _T:
27
+ docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '')
28
+ docstring += """
29
+ .. warning::
30
+ This API is experimental and is *NOT* backward-compatible.
31
+ """
32
+ fn.__doc__ = docstring
33
+ _MARKED_WITH_COMPATIBILITY.setdefault(fn)
34
+ return fn
35
+
36
+ return mark_not_back_compat
.venv/lib/python3.11/site-packages/torch/fx/_lazy_graph_module.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from contextlib import contextmanager
3
+
4
+ from torch.fx import GraphModule
5
+ from torch.fx.graph_module import (
6
+ _format_import_block,
7
+ reduce_graph_module,
8
+ reduce_package_graph_module,
9
+ )
10
+ from torch.package import PackageExporter, sys_importer
11
+
12
+ from ._compatibility import compatibility
13
+
14
+
15
+ _use_lazy_graph_module_flag = False
16
+ _force_skip_lazy_graph_module_flag = False
17
+
18
+
19
@compatibility(is_backward_compatible=False)
@contextmanager
def _force_skip_lazy_graph_module():
    """
    Skip using lazy graph module disregarding the setting of _use_lazy_graph_module.
    Use to skip _LazyGraphModule when testing inductor torchscript related backend.

    torch.jit.script a _LazyGraphModule results in following error:
        https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69
    """
    global _force_skip_lazy_graph_module_flag
    # Remember the previous value so nested uses restore correctly on exit.
    saved = _force_skip_lazy_graph_module_flag
    _force_skip_lazy_graph_module_flag = True
    try:
        yield
    finally:
        _force_skip_lazy_graph_module_flag = saved
36
+
37
+
38
@compatibility(is_backward_compatible=False)
@contextmanager
def _use_lazy_graph_module(should_use: bool):
    """
    Temporarily set the module-level "use lazy graph module" flag.

    Inside the context the flag is ``should_use and not
    _force_skip_lazy_graph_module_flag``; the previous value is restored on exit.
    """
    global _use_lazy_graph_module_flag
    saved = _use_lazy_graph_module_flag
    _use_lazy_graph_module_flag = (
        should_use and not _force_skip_lazy_graph_module_flag
    )
    try:
        yield
    finally:
        _use_lazy_graph_module_flag = saved
50
+
51
+
52
@compatibility(is_backward_compatible=False)
def _get_graph_module_cls():
    """Return _LazyGraphModule when the lazy flag is set, otherwise GraphModule."""
    if _use_lazy_graph_module_flag:
        return _LazyGraphModule
    return GraphModule
55
+
56
+
57
def _make_graph_module(*args, graph_module_cls=None, **kwargs):
    """
    Instantiate a graph module from ``args`` / ``kwargs``.

    ``graph_module_cls`` selects the class to call; when it is None the class
    returned by ``_get_graph_module_cls()`` is used.
    """
    target_cls = (
        _get_graph_module_cls() if graph_module_cls is None else graph_module_cls
    )
    return target_cls(*args, **kwargs)
62
+
63
+
64
@compatibility(is_backward_compatible=False)
class _LazyGraphModule(GraphModule):
    """
    The main difference between _LazyGraphModule and GraphModule is how recompile happens.
    GraphModule will do a 'recompile' call to generate python code and the forward method when it's
    constructed. Later on if the graph gets updated, the recompile method can be called again to
    refresh the saved python code and forward method.

    However in some cases, especially in inductor, the recompilation can be a waste since we never
    check the python code for the graph module or call its forward method. A few more concrete
    examples regarding pattern matching fx passes in inductor:
    1. some passes will update the graph to be compiled and then call recompile on the GraphModule.
    2. some passes will trace a small pattern function to search for it in the graph being compiled
       and replace the match with the traced graph of a replacement function. The pattern graph and
       replacement graph are quite small but there is a large number of them. Doing
       GraphModule.recompile for them in GraphModule.__init__ is also a waste of time.

    However simply skipping the GraphModule.recompile call in these scenarios is also dangerous.
    People may want to check the python code or call the GraphModule's forward method for debugging
    purposes.

    The way _LazyGraphModule solves it is: the recompile method is overridden to just mark the
    need for recompilation without doing the actual recompilation. Later on, if people really
    access the compiled python code or call the GraphModule's forward method, the real
    recompilation is done.
    """

    @classmethod
    def from_graphmodule(cls, gm: GraphModule):
        """Return ``gm`` itself if it is already lazy, else a new _LazyGraphModule built from it."""
        already_lazy = isinstance(gm, _LazyGraphModule)
        return gm if already_lazy else _LazyGraphModule(gm, gm.graph)

    @staticmethod
    def force_recompile(gm):
        """
        Sometimes we need to force a recompile as a workaround
        - we want to do the real recompilation before symbolic_trace to avoid error:
            https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259

        Does nothing when ``gm`` is not a _LazyGraphModule.
        """
        if isinstance(gm, _LazyGraphModule):
            gm.real_recompile()

    def real_recompile(self):
        """Run the real recompilation, but only if one is still pending."""
        if not self._needs_recompile():
            return
        self._real_recompile()

    @classmethod
    def _needs_recompile(cls):
        # Recompilation is pending while `forward` is still the lazy placeholder.
        return cls.forward is cls._lazy_forward

    def _lazy_forward(self, *args, **kwargs):
        # Call self.real_recompile() rather than self._real_recompile() here.
        # The _lazy_forward method may be saved and called repeatedly.
        # Calling self.real_recompile makes sure recompilation is skipped if
        # it has already been done.
        self.real_recompile()
        assert not self._needs_recompile()

        # call `__call__` rather than 'forward' since recompilation may
        # install a wrapper for `__call__` to provide a customized error
        # message.
        return self(*args, **kwargs)

    forward = _lazy_forward

    # TODO: we should handle __reduce_deploy__ the same way as __reduce_package__,
    # or __reduce__ by calling _real_recompile. But I don't find a good way
    # to test __reduce_deploy__ out. Also it's very unlikely that LazyGraphModule
    # will be used in torch::deploy. So it's skipped for now.

    def __reduce_package__(self, exporter: PackageExporter):
        """
        Follow GraphModule.__reduce__ but call 'self._real_recompile' rather
        than 'self.recompile' since for a _LazyGraphModule, self.recompile just
        marks the need of recompilation and does not return the PythonCode object.
        """
        python_code = self._real_recompile()

        # Everything in __dict__ except the graph is handed to the reducer,
        # plus the name of this class.
        state = self.__dict__.copy()
        state["_graphmodule_cls_name"] = self.__class__.__name__
        del state["_graph"]

        generated_module_name = f"fx-generated._{exporter.get_unique_id()}"
        import_block = _format_import_block(python_code.globals, exporter.importer)
        exporter.save_source_string(generated_module_name, import_block + self.code)
        return (
            reduce_package_graph_module,
            (state, generated_module_name),
        )

    def __reduce__(self):
        """
        Follow GraphModule.__reduce__ but call 'self._real_recompile' rather
        than 'self.recompile' since for a _LazyGraphModule, self.recompile just
        marks the need of recompilation and does not return the PythonCode object.
        """
        python_code = self._real_recompile()
        state = self.__dict__.copy()
        del state["_graph"]
        import_block = _format_import_block(python_code.globals, sys_importer)
        return (reduce_graph_module, (state, import_block))

    def _real_recompile(self):
        # Delegates to the parent class's recompile and returns its result.
        return super().recompile()

    @classmethod
    def recompile(cls):
        # Only mark the need for recompilation: reinstall the lazy placeholder.
        cls.forward = cls._lazy_forward

    @property
    def code(self) -> str:
        # Make sure any pending recompilation has happened before reading the code.
        self.real_recompile()
        return super().code

    def __str__(self) -> str:
        """
        str(GraphModule) will access the _code attribute. Make sure recompile
        happens so _code attribute is available.
        """
        self.real_recompile()
        return super().__str__()
.venv/lib/python3.11/site-packages/torch/fx/_pytree.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from collections import namedtuple
3
+ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type
4
+
5
+ import torch.return_types
6
+ from torch.utils._pytree import PyTree, TreeSpec
7
+
8
+
9
# Callable taking a pytree node and its TreeSpec and returning the node's
# children as a list.
FlattenFuncSpec = Callable[[PyTree, TreeSpec], List]
# Callable taking a pytree node and its TreeSpec and returning whether the
# node matches the spec exactly.
FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool]

# Node type -> flatten function; filled by register_pytree_flatten_spec and
# looked up by ``spec.type`` in tree_flatten_spec.
SUPPORTED_NODES: Dict[Type[Any], FlattenFuncSpec] = {}
# Node type -> optional exact-match predicate (None means "no check").
SUPPORTED_NODES_EXACT_MATCH: Dict[Type[Any], Optional[FlattenFuncExactMatchSpec]] = {}
14
+
15
+
16
def register_pytree_flatten_spec(
    cls: Type[Any],
    flatten_fn_spec: FlattenFuncSpec,
    flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None,
) -> None:
    """
    Register how nodes of type ``cls`` are flattened against a TreeSpec.

    Args:
        cls: Node type used as the registry key.
        flatten_fn_spec: Function returning the node's children for a spec.
        flatten_fn_exact_match_spec: Optional predicate used when
            ``tree_flatten_spec`` is called with ``exact_structural_match``.

    Registering the same ``cls`` again overwrites the previous entries.
    """
    SUPPORTED_NODES[cls] = flatten_fn_spec
    SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec
23
+
24
+
25
+ def tree_flatten_spec(
26
+ pytree: PyTree,
27
+ spec: TreeSpec,
28
+ exact_structural_match=False,
29
+ ) -> List[Any]:
30
+ if spec.is_leaf():
31
+ return [pytree]
32
+ if spec.type not in SUPPORTED_NODES:
33
+ raise RuntimeError(
34
+ f"{type(pytree)} does not have a flatten_fn_spec associated with it. Please register one with "
35
+ "torch.fx._pytree.register_pytree_flatten_spec. If you have serialized your model, make "
36
+ "sure that any custom pytrees have been registered before loading it.",
37
+ )
38
+ flatten_fn_spec = SUPPORTED_NODES[spec.type]
39
+ child_pytrees = flatten_fn_spec(pytree, spec)
40
+ if exact_structural_match:
41
+ flatten_fn_exact_match_spec = SUPPORTED_NODES_EXACT_MATCH[spec.type]
42
+ if flatten_fn_exact_match_spec and not flatten_fn_exact_match_spec(
43
+ pytree,
44
+ spec,
45
+ ):
46
+ raise RuntimeError(f"Cannot flatten pytree {pytree}, given spec: {spec}")
47
+ result = []
48
+ for child, child_spec in zip(child_pytrees, spec.children_specs):
49
+ flat = tree_flatten_spec(child, child_spec, exact_structural_match)
50
+ result += flat
51
+ return result
52
+
53
+
54
+ def _dict_flatten_spec(d: Dict[Any, Any], spec: TreeSpec) -> List[Any]:
55
+ return [d[k] for k in spec.context]
56
+
57
+
58
+ def _list_flatten_spec(d: List[Any], spec: TreeSpec) -> List[Any]:
59
+ return [d[i] for i in range(spec.num_children)]
60
+
61
+
62
+ def _tuple_flatten_spec(d: Tuple[Any], spec: TreeSpec) -> List[Any]:
63
+ return [d[i] for i in range(spec.num_children)]
64
+
65
+
66
+ def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> List[Any]:
67
+ return [d[i] for i in range(spec.num_children)]
68
+
69
+
70
+ def _dict_flatten_spec_exact_match(d: Dict[Any, Any], spec: TreeSpec) -> bool:
71
+ return len(d) == spec.num_children
72
+
73
+
74
+ def _list_flatten_spec_exact_match(d: List[Any], spec: TreeSpec) -> bool:
75
+ return len(d) == spec.num_children
76
+
77
+
78
+ def _tuple_flatten_spec_exact_match(d: Tuple[Any], spec: TreeSpec) -> bool:
79
+ return len(d) == spec.num_children
80
+
81
+
82
+ def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool:
83
+ return len(d) == spec.num_children
84
+
85
+
86
# Import-time registration of the flatten functions defined above.
register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match)
register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match)
register_pytree_flatten_spec(
    tuple,
    _tuple_flatten_spec,
    _tuple_flatten_spec_exact_match,
)
# Every type in torch.return_types.all_return_types reuses the tuple handlers.
for return_type in torch.return_types.all_return_types:
    register_pytree_flatten_spec(
        return_type,
        _tuple_flatten_spec,
        _tuple_flatten_spec_exact_match,
    )
# NOTE(review): the registry key here is the ``collections.namedtuple``
# factory itself, not a class - presumably this matches the ``type`` that
# TreeSpec reports for namedtuple nodes; confirm against torch.utils._pytree.
register_pytree_flatten_spec(
    namedtuple,  # type: ignore[arg-type]
    _namedtuple_flatten_spec,
    _namedtuple_flatten_spec_exact_match,
)
.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py ADDED
@@ -0,0 +1,1290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import builtins
3
+ import copy
4
+ import contextlib
5
+ import functools
6
+ import inspect
7
+ import math
8
+ import os
9
+ import warnings
10
+ import collections
11
+ from itertools import chain
12
+ from types import CodeType, FunctionType, ModuleType
13
+ from typing import (
14
+ Any,
15
+ Callable,
16
+ Dict,
17
+ List,
18
+ NamedTuple,
19
+ Optional,
20
+ Set,
21
+ Tuple,
22
+ Type,
23
+ Union,
24
+ )
25
+
26
+ import torch
27
+ import torch.utils._pytree as pytree
28
+ from torch._C import ScriptObject # type: ignore[attr-defined]
29
+ from torch._library.fake_class_registry import FakeScriptObject
30
+
31
+ from ._compatibility import compatibility
32
+ from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
33
+ from .graph_module import GraphModule
34
+ from ._lazy_graph_module import _make_graph_module
35
+ from .node import Argument, base_types, map_aggregate
36
+ from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager
37
+
38
# Bitmask of the code-object flags that mark a ``*args`` and/or ``**kwargs``
# parameter.
HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS

# These need to run in global scope to handle nested calls correctly
_orig_module_call: Callable = torch.nn.Module.__call__
_orig_module_getattr: Callable = torch.nn.Module.__getattr__

# Classes created with ProxyableClassMeta; the dict is used as a set
# (values are always None).
_proxyable_classes: Dict[Type, None] = {}

# Set to True by Tracer.trace while a trace is in progress; read through
# is_fx_tracing().
_is_fx_tracing_flag = False
47
+
48
+
49
def is_fx_tracing():
    """Return the current value of the module-level ``_is_fx_tracing_flag``."""
    return _is_fx_tracing_flag
51
+
52
@compatibility(is_backward_compatible=True)
class ProxyableClassMeta(type):
    """
    Metaclass that makes construction of a Python class symbolically
    traceable. For example::

        import torch
        import torch.fx

        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
            s = x.add(TensorPair(y, y))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = torch.randn(5, 3)
        ref_out = use_tensor_pair_ctor(x, y)

        traced = torch.fx.symbolic_trace(use_tensor_pair_ctor)
        print(traced.code)
        '''
        def forward(self, x : __main___TensorPair, y : torch.Tensor):
            tensor_pair = __main___TensorPair(y, y); y = None
            add = x.add(tensor_pair); tensor_pair = None
            mul = add.mul(x); add = x = None
            return mul
        '''

    As the example shows, constructing a class (``TensorPair``) that uses
    ``ProxyableClassMeta`` as its metaclass is recorded during symbolic
    tracing.
    """

    def __init__(cls, name, bases, attrs):
        # Remember every class built with this metaclass.
        _proxyable_classes.setdefault(cls)
        super().__init__(name, bases, attrs)

    def __call__(cls, *args, **kwargs):
        instance = cls.__new__(cls)  # type: ignore[call-overload]

        if is_fx_tracing():
            proxies = []

            def collect(value):
                if isinstance(value, Proxy):
                    proxies.append(value)

            map_aggregate(args, collect)
            map_aggregate(kwargs, collect)

            if proxies:
                # At least one argument is a Proxy: record the constructor
                # call in the graph instead of running ``__init__``.
                return proxies[0].tracer.create_proxy(
                    "call_function", cls, args, kwargs
                )

        # Not tracing, or no Proxy among the arguments: construct normally.
        cls.__init__(instance, *args, **kwargs)  # type: ignore[misc]
        return instance
124
+
125
+
126
+ def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
127
+ co = fn.__code__
128
+ co_flags = co.co_flags & ~HAS_VARSTUFF
129
+ co_args: tuple
130
+ if hasattr(co, "co_qualname"):
131
+ # Python-3.11+ code signature
132
+ co_args = (
133
+ nargs,
134
+ 0,
135
+ 0,
136
+ co.co_nlocals,
137
+ co.co_stacksize,
138
+ co_flags,
139
+ co.co_code,
140
+ co.co_consts,
141
+ co.co_names,
142
+ co.co_varnames,
143
+ co.co_filename,
144
+ co.co_name,
145
+ co.co_qualname, # type: ignore[attr-defined]
146
+ co.co_firstlineno,
147
+ co.co_lnotab,
148
+ co.co_exceptiontable, # type: ignore[attr-defined]
149
+ co.co_freevars,
150
+ co.co_cellvars,
151
+ )
152
+ elif hasattr(co, "co_posonlyargcount"):
153
+ co_args = (
154
+ nargs,
155
+ 0,
156
+ 0,
157
+ co.co_nlocals,
158
+ co.co_stacksize,
159
+ co_flags,
160
+ co.co_code,
161
+ co.co_consts,
162
+ co.co_names,
163
+ co.co_varnames,
164
+ co.co_filename,
165
+ co.co_name,
166
+ co.co_firstlineno,
167
+ co.co_lnotab,
168
+ co.co_freevars,
169
+ co.co_cellvars,
170
+ )
171
+ else:
172
+ co_args = (
173
+ nargs,
174
+ 0,
175
+ co.co_nlocals,
176
+ co.co_stacksize,
177
+ co_flags,
178
+ co.co_code,
179
+ co.co_consts,
180
+ co.co_names,
181
+ co.co_varnames,
182
+ co.co_filename,
183
+ co.co_name,
184
+ co.co_firstlineno,
185
+ co.co_lnotab,
186
+ co.co_freevars,
187
+ co.co_cellvars,
188
+ )
189
+ new_code = CodeType(*co_args) # type: ignore[arg-type]
190
+ return FunctionType(
191
+ new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__
192
+ )
193
+
194
+ # we need to insert placeholder nodes for *args and **kwargs
195
+ # we can't call this function normally, otherwise it would try to unpack them
196
+ # instead, let's make python think that args and kwargs are normal variables
197
+
198
+
199
@compatibility(is_backward_compatible=False)
class PHBase:
    """
    Object representing an input placeholder to `concrete_args`
    """

    def __repr__(self):
        # Fixed representation, independent of instance state.
        return "PH"


# Module-level placeholder instance; ``create_args_for_root`` compares
# concrete args against it (``concrete_arg != PH``).
PH = PHBase()
210
+
211
+
212
@compatibility(is_backward_compatible=False)
class PHWithMeta(PHBase):
    """
    Object representing an input placeholder to `concrete_args` that also
    carries a user-supplied key.
    """
    def __init__(self, ph_key: Optional[str] = None):
        super().__init__()

        # Provide a key for the user to identify the placeholder node during analysis
        self.ph_key = ph_key
222
+
223
+
224
+ def _transfer_attrs(fr, to):
225
+ for attr_name in dir(fr):
226
+ attr_val = getattr(fr, attr_name)
227
+ if (
228
+ not callable(attr_val)
229
+ and not attr_name.startswith("__")
230
+ and not hasattr(to, attr_name)
231
+ ):
232
+ setattr(to, attr_name, attr_val)
233
+
234
+
235
+ @compatibility(is_backward_compatible=True)
236
+ class Tracer(TracerBase):
237
+ # Reference: https://github.com/pytorch/pytorch/issues/54354
238
+ # The first line of this docstring overrides the one Sphinx generates for the
239
+ # documentation. We need it so that Sphinx doesn't leak `math`s path from the
240
+ # build environment (e.g. `<module 'math' from '/leaked/path').
241
+
242
+ """Tracer(autowrap_modules=(math,), autowrap_functions=())
243
+
244
+ ``Tracer`` is the class that implements the symbolic tracing functionality
245
+ of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
246
+ to ``Tracer().trace(m)``.
247
+
248
+ Tracer can be subclassed to override various behaviors of the tracing
249
+ process. The different behaviors that can be overridden are described
250
+ in the docstrings of the methods on this class.
251
+ """
252
+
253
+ # Not checking BC on this API because the default value for `autowrap_modules`
254
+ # includes the local filepath to the `math` module, which would jitter
255
+ # across machines.
256
@compatibility(is_backward_compatible=True)
def __init__(
    self,
    autowrap_modules: Tuple[ModuleType] = (math,),
    autowrap_functions: Tuple[Callable, ...] = (),
    param_shapes_constant: bool = False,
) -> None:
    # The first line of the class docstring overrides this signature in the
    # generated documentation; keep the two in sync when changing either.

    """
    Construct a Tracer object.

    Args:

        autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
            Python modules whose functions should be wrapped automatically
            without needing to use fx.wrap(). Backward-compatibility for
            this parameter is guaranteed.

        autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
            Python functions that should be wrapped automatically without
            needing to use fx.wrap(). Backward compatibility for this
            parameter is guaranteed.

        param_shapes_constant (bool): When this flag is set, calls to shape,
            size and a few other shape-like attributes of a module's parameter
            are evaluated directly, rather than returning a new Proxy value
            for an attribute access. Backward compatibility for this parameter
            is guaranteed.
    """

    super().__init__()

    # ids of the functions to wrap eagerly when seen while tracing: every
    # public callable of the autowrap modules (this covers both
    # `math.sqrt()` and `from math import sqrt`), plus the explicitly
    # listed functions.
    wrapped_ids: Set[int] = set()
    for module in autowrap_modules:
        for name, value in module.__dict__.items():
            if not name.startswith("_") and callable(value):
                wrapped_ids.add(id(value))
    for func in autowrap_functions:
        wrapped_ids.add(id(func))
    self._autowrap_function_ids: Set[int] = wrapped_ids

    # Python modules to apply autowrap to at the start, in addition to
    # modules we see while tracing
    self._autowrap_search: List[ModuleType] = list(autowrap_modules)
    self.param_shapes_constant = param_shapes_constant

    self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None
    self.root_module_name: str = ""
    # Maps the containing module's name to the operator name
    self.scope = Scope("", None)
    # Records the module call stack
    self.module_stack = collections.OrderedDict()
    # Mapping of node name to module scope
    self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
313
+
314
# Class-level map from name prefix to the next index get_fresh_qualname
# tries; missing prefixes start at 0.
_qualname_counter: Dict[str, int] = collections.defaultdict(int)
315
+
316
@compatibility(is_backward_compatible=True)
def get_fresh_qualname(self, prefix: str) -> str:
    """
    Return ``prefix`` followed by an index such that the result is not
    already an attribute of ``self.root``.

    If ``<prefix>0`` is free, the counter for the prefix is reset and that
    name is returned. This is a little hacky (it does not cover every
    case), but the exact numbering is a niceness issue, not a correctness
    one. Otherwise the search resumes from the stored counter, which is
    then advanced past the returned index.
    """
    first_candidate = f"{prefix}0"
    if not hasattr(self.root, first_candidate):
        self._qualname_counter[prefix] = 0
        return first_candidate

    index = self._qualname_counter[prefix]
    candidate = f"{prefix}{index}"
    index += 1
    while hasattr(self.root, candidate):
        candidate = f"{prefix}{index}"
        index += 1
    self._qualname_counter[prefix] = index

    return candidate
341
+
342
@compatibility(is_backward_compatible=True)
def create_arg(self, a: Any) -> "Argument":
    """
    A method to specify the behavior of tracing when preparing values to
    be used as arguments to nodes in the ``Graph``.

    By default, the behavior includes:

    #. Iterate through collection types (e.g. tuple, list, dict) and recursively
       call ``create_arg`` on the elements.
    #. Given a Proxy object, return a reference to the underlying IR ``Node``
    #. Given a non-Proxy Tensor object, emit IR for various cases:

        * For a Parameter, emit a ``get_attr`` node referring to that Parameter
        * For a non-Parameter Tensor, store the Tensor away in a special
          attribute and emit a ``get_attr`` node referring to that attribute.

    This method can be overridden to support more types.

    Args:

        a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.


    Returns:

        The value ``a`` converted into the appropriate ``Argument``

    Raises:

        NameError: If ``a`` is a Parameter that is not among
            ``self.root.named_parameters()``.
    """
    # The base tracer is used to construct Graphs when there is no associated
    # module hierarchy, so it can never create parameter references.
    # The default tracer adds the ability to refer to parameters when
    # tracing modules.
    if isinstance(a, torch.nn.Parameter):
        for n, p in self.root.named_parameters():
            if a is p:
                return self.create_node("get_attr", n, (), {})
        raise NameError("parameter is not a member of this module")
    elif isinstance(a, torch.Tensor):
        # A tensor that is a registered buffer of the root; if none
        # matches, fall through to the constant-tensor handling below.
        for n_, p_ in self.root.named_buffers():
            if a is p_:
                return self.create_node("get_attr", n_, (), {})
    elif isinstance(a, torch.nn.Module):
        for n_, p_ in self.root.named_modules():
            if a is p_:
                return self.create_node("get_attr", n_, (), {})
    # For NamedTuple instances that appear literally as args, we emit
    # a node to construct the NamedTuple and use that Node as the argument.
    if isinstance(a, tuple) and hasattr(a, "_fields"):
        args = tuple(self.create_arg(elem) for elem in a)
        return self.create_node("call_function", a.__class__, args, {})

    # Tensors do not have a reliable string repr() from which they can be
    # constructed (and we probably don't want to rely on that, either), so
    # for any constant Tensor values we encounter, first search for if they
    # are an attribute of some module in the module hierarchy. If so, emit
    # a get_attr to retrieve that tensor. Otherwise, we'll store away the
    # tensor value into a special attribute on the Module s.t. we can
    # retrieve it with a get_attr.
    if isinstance(a, (torch.Tensor, ScriptObject, FakeScriptObject)):
        qualname: Optional[str] = self.tensor_attrs.get(a)

        # Tensor was not found in the Module hierarchy, stow it away in a
        # special attribute and set the qualname to refer to that
        if not qualname:
            base_name = "_tensor_constant" if isinstance(a, torch.Tensor) else "_torchbind_obj"
            qualname = self.get_fresh_qualname(base_name)
            assert isinstance(qualname, str)
            self.tensor_attrs[a] = qualname
            setattr(self.root, qualname, a)

        return self.create_node("get_attr", qualname, (), {})

    if type(a) in _proxyable_classes:
        # This is an instance of a proxyable class for which we did not
        # witness its construction. Intern this as a constant attribute

        # TODO: binary search
        qualname = self.get_fresh_qualname(f"_{a.__class__.__name__}_constant_")
        assert isinstance(qualname, str)
        setattr(self.root, qualname, a)

        return self.create_node("get_attr", qualname, (), {})

    # Everything else is handled by the parent class.
    return super().create_arg(a)
426
+
427
@compatibility(is_backward_compatible=True)
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
    """
    Decide whether ``m`` is a "leaf" module.

    Leaf modules are the atomic units that appear in the IR, referenced by
    ``call_module`` calls. By default a module is a leaf when its class
    is defined under ``torch.nn`` or ``torch.ao.nn`` and it is not a
    ``torch.nn.Sequential``. All other modules are traced through and
    their constituent ops are recorded. Override this method to change
    that policy.

    Args:

        m (Module): The module being queried about
        module_qualified_name (str): The path to root of this module. For example,
            if you have a module hierarchy where submodule ``foo`` contains
            submodule ``bar``, which contains submodule ``baz``, that module will
            appear with the qualified name ``foo.bar.baz`` here. The default
            implementation does not use it.
    """
    in_torch_nn = m.__module__.startswith(("torch.nn", "torch.ao.nn"))
    return in_torch_nn and not isinstance(m, torch.nn.Sequential)
451
+
452
@compatibility(is_backward_compatible=True)
def path_of_module(self, mod: torch.nn.Module) -> str:
    """
    Find the qualified name of ``mod`` in the Module hierarchy of ``root``.
    For example, if ``root`` has a submodule named ``foo``, which has a
    submodule named ``bar``, passing ``bar`` into this function will return
    the string "foo.bar".

    Args:

        mod (Module): The ``Module`` to retrieve the qualified name for.

    Raises:

        NameError: If ``mod`` cannot be found.
    """
    known_paths = self.submodule_paths
    if known_paths:
        # Fast path: direct lookup in the stored module -> path mapping.
        found = known_paths.get(mod)
        if found is None:
            raise NameError("module is not installed as a submodule")
        assert isinstance(found, str)
        return found

    # Fallback when no paths were stored: scan the hierarchy by identity.
    for name, candidate in self.root.named_modules():
        if candidate is mod:
            return name
    raise NameError("module is not installed as a submodule")
478
+
479
@compatibility(is_backward_compatible=True)
def call_module(
    self,
    m: torch.nn.Module,
    forward: Callable[..., Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Define what this ``Tracer`` does when it encounters a call to an
    ``nn.Module`` instance.

    By default, ``is_leaf_module`` decides: for a leaf module a
    ``call_module`` node referring to ``m`` is emitted in the ``Graph``;
    otherwise the ``Module`` is called normally and the operations in its
    ``forward`` function are traced.

    This method can be overridden to--for example--create nested traced
    GraphModules, or any other behavior you would want while tracing across
    ``Module`` boundaries.

    Args:

        m (Module): The module for which a call is being emitted
        forward (Callable): The forward() method of the ``Module`` to be invoked
        args (Tuple): args of the module callsite
        kwargs (Dict): kwargs of the module callsite

    Return:

        The return value from the Module call. In the case that a ``call_module``
        node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
        value was returned from the ``Module`` invocation.
    """
    qualified_name = self.path_of_module(m)
    with ScopeContextManager(self.scope, Scope(qualified_name, type(m))) as _scope:
        # module_stack is an ordered dict, so inserting an entry and later
        # popping the last one behaves like push/pop on a list.
        self.module_stack[_scope.module_path] = (qualified_name, _scope.module_type)
        if self.is_leaf_module(m, qualified_name):
            ret_val = self.create_proxy("call_module", qualified_name, args, kwargs)
        else:
            ret_val = forward(*args, **kwargs)
        key, _ = self.module_stack.popitem(last=True)
        assert key == _scope.module_path, f" Unexpected key {key}"

    return ret_val
526
+
527
@compatibility(is_backward_compatible=False)
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
    """
    Define what this ``Tracer`` does when an attribute is read from an
    ``nn.Module`` instance.

    By default a proxy is returned for parameters of the root (and for its
    buffers when ``self.proxy_buffer_attributes`` is set). The proxy is
    stored in ``parameter_proxy_cache`` so that later lookups reuse it
    rather than creating a new one. Any other value is returned unchanged.

    This method can be overridden to --for example-- not return proxies when
    querying parameters.

    Args:

        attr (str): The name of the attribute being queried
        attr_val (Any): The value of the attribute
        parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies

    Return:

        The return value from the getattr call.
    """

    def lookup_proxy(value, named_tensors, cache):
        # Find ``value`` by identity; create (or reuse) its get_attr proxy.
        for name, tensor in named_tensors:
            if value is not tensor:
                continue
            if name not in cache:
                extra = {}
                # Only pass ``proxy_factory_fn`` when this tracer's
                # ``create_proxy`` accepts it.
                if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
                    if self.param_shapes_constant:
                        extra["proxy_factory_fn"] = lambda node: ParameterProxy(
                            self, node, name, value
                        )
                    else:
                        extra["proxy_factory_fn"] = None
                cache[name] = self.create_proxy("get_attr", name, (), {}, **extra)  # type: ignore[arg-type]
            return cache[name]
        return None

    if isinstance(attr_val, torch.nn.Parameter):
        param_proxy = lookup_proxy(
            attr_val, self.root.named_parameters(), parameter_proxy_cache
        )
        if param_proxy is not None:
            return param_proxy

    if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
        buffer_proxy = lookup_proxy(
            attr_val, self.root.named_buffers(), parameter_proxy_cache
        )
        if buffer_proxy is not None:
            return buffer_proxy

    return attr_val
588
+
589
+ # This method will be refactored
590
@compatibility(is_backward_compatible=False)
def create_args_for_root(self, root_fn, is_module, concrete_args=None):
    """
    Create ``placeholder`` nodes corresponding to the signature of the ``root``
    Module. This method introspects root's signature and emits those
    nodes accordingly, also supporting ``*args`` and ``**kwargs``.

    Args:

        root_fn: The function whose signature is introspected.
        is_module: When true, the first argument is treated as ``self`` and
            ``self.root`` is passed for it.
        concrete_args: Optional tuple or mapping of concrete values; a
            tuple is matched to the argument names by position.

    Returns:

        A ``(fn, args)`` pair: the function to call (possibly patched or
        wrapped in a flattening wrapper) and the arguments to call it with.

    Raises:

        RuntimeError: If ``is_module`` is set but the function has no
            named arguments, or a ``concrete_args`` tuple has the wrong length.
    """
    # In some cases, a function or method has been decorated with a wrapper
    # defined via ``functools.wraps``. In this case, the outer code object
    # will likely not contain the actual parameters we care about, so unwrap
    # the function to get to the innermost callable.
    fn_for_analysis = inspect.unwrap(root_fn)
    co = fn_for_analysis.__code__
    total_args = co.co_argcount + co.co_kwonlyargcount
    orig_args = list(co.co_varnames)
    # Consumed step by step below: self, the named args, then *args/**kwargs.
    names_iter = iter(co.co_varnames)
    args: List[Any] = []
    skip_arg_idx = 0
    if is_module:
        if total_args == 0:
            raise RuntimeError(
                "``self`` argument cannot be part of *args expansion!"
            )
        skip_arg_idx = 1
        next(names_iter)  # skip self
        args.append(self.root)

    sig = inspect.signature(fn_for_analysis)


    # This covers the very specific case where we are passing in flat
    # concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
    # In this case, just take the concrete_args and pass them through.
    name_idx = 0
    if isinstance(concrete_args, tuple) and \
            len(concrete_args) > 0 and \
            (co.co_flags & HAS_VARSTUFF) and \
            total_args == 1:
        for concrete_arg in concrete_args:
            out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
            if isinstance(concrete_arg, PHBase):
                if concrete_arg != PH:
                    # Transfer attrs in the case where you're using a placeholder other
                    # than the singleton PH (PH has no attributes to transfer).
                    # Proxies were created out of the placeholders.
                    # Transfer any metadata (put on the placeholders in the form of
                    # attributes set by the user) from the placeholder to the
                    # underlying nodes (the proxy is unwrapped by the user, but
                    # the metadata should hold).
                    _transfer_attrs(fr=concrete_arg, to=out.node)
            args.append(out)
            name_idx += 1
        return root_fn, args

    # Names of the remaining positional and keyword-only arguments.
    arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
    if isinstance(concrete_args, tuple):
        if len(arg_names) != len(concrete_args):
            raise RuntimeError(
                f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
            )
        concrete_args = dict(zip(arg_names, concrete_args))

    def proxy_placeholder(name):
        return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis)

    args.extend(proxy_placeholder(names) for names in arg_names)

    if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
        # TODO: type annotations for *args and **kwargs
        if co.co_flags & inspect.CO_VARARGS:
            args.append(proxy_placeholder("*" + next(names_iter)))
        if co.co_flags & inspect.CO_VARKEYWORDS:
            args.append(proxy_placeholder("**" + next(names_iter)))
        # Patch the function so every collected argument is passed positionally.
        root_fn = _patch_function(root_fn, len(args))

    flat_args, in_spec = pytree.tree_flatten(tuple(args))
    if not all(child.is_leaf() for child in in_spec.children_specs):
        # In the case that we have pytree-flattened inputs in
        # `concrete_args`, generate a flattening wrapper around the
        # original root function and return that.
        self.graph._codegen = _PyTreeCodeGen(
            _PyTreeInfo(orig_args[:total_args], in_spec, None)
        )

        def flatten_fn(*args):
            tree_args = pytree.tree_unflatten(list(args), in_spec)
            tree_out = root_fn(*tree_args)
            out_args, out_spec = pytree.tree_flatten(tree_out)
            assert isinstance(self.graph._codegen, _PyTreeCodeGen)
            # Record the output structure, which is only known after the call.
            self.graph._codegen.pytree_info = (
                self.graph._codegen.pytree_info._replace(out_spec=out_spec)
            )
            return out_args

        return flatten_fn, flat_args
    return root_fn, args
686
+
687
@compatibility(is_backward_compatible=True)
def trace(
    self,
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> Graph:
    """
    Symbolically trace ``root`` and return the resulting FX ``Graph``.

    ``root`` may be an ``nn.Module`` instance or a plain Python callable.
    When a free function is given, a fresh ``nn.Module`` is created and
    stored in ``self.root`` (so ``self.root`` may differ from the ``root``
    argument after this call); embedded constants are attached to it.

    Args:

        root (Union[Module, Callable]): Either a ``Module`` or a function to be
            traced through. Backwards-compatibility for this parameter is
            guaranteed.
        concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
            not be treated as Proxies. This parameter is experimental and
            its backwards-compatibility is *NOT* guaranteed.

    Returns:

        A ``Graph`` representing the semantics of the passed-in ``root``.
    """
    global _is_fx_tracing_flag
    # Remember the previous flag so nested traces restore it correctly.
    saved_tracing_flag = _is_fx_tracing_flag
    _is_fx_tracing_flag = True
    try:
        root_is_module = isinstance(root, torch.nn.Module)
        if root_is_module:
            # do real recompilation for _LazyGraphModule before retracing since the trace
            # method can not trace the _lazy_forward method. Got error:
            # https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259
            # without this.
            from torch.fx._lazy_graph_module import _LazyGraphModule
            _LazyGraphModule.force_recompile(root)

            self.root = root

            assert hasattr(
                type(root), self.traced_func_name
            ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"

            fn = getattr(type(root), self.traced_func_name)
            self.root_module_name = root._get_name()
            self.submodule_paths = {mod: name for name, mod in root.named_modules()}
        else:
            self.root = torch.nn.Module()
            fn = root

        tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None)
        self.graph = Graph(tracer_cls=tracer_cls)
        if hasattr(fn, "__code__"):
            fn_code = fn.__code__
            # Record where the traced function came from on the graph.
            self.graph._co_fields = {
                "co_name": fn_code.co_name,
                "co_filename": fn_code.co_filename,
                "co_firstlineno": fn_code.co_firstlineno,
            }

        # When we encounter a Tensor value that's not a parameter, we look if it
        # is some other attribute on the model. Construct a dict mapping Tensor
        # values to the qualified name here for efficiency. This is used downstream
        # in create_arg
        self.tensor_attrs: Dict[
            Union[torch.Tensor, ScriptObject, FakeScriptObject], str
        ] = {}

        def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]):
            # Direct attributes first, then recurse into child modules.
            for attr_name, attr_val in m.__dict__.items():
                if isinstance(attr_val, (torch.Tensor, ScriptObject, FakeScriptObject)):
                    self.tensor_attrs[attr_val] = ".".join(prefix_atoms + [attr_name])
            for child_name, child in m.named_children():
                collect_tensor_attrs(child, prefix_atoms + [child_name])

        collect_tensor_attrs(self.root, [])

        assert isinstance(fn, FunctionType)

        fn_globals = fn.__globals__  # run before it gets patched
        fn, args = self.create_args_for_root(fn, root_is_module, concrete_args)

        # Reduce number of get_attr calls
        parameter_proxy_cache: Dict[str, Proxy] = {}

        # Method dispatch on parameters is not recorded unless it's directly used.
        # Thus, we need to insert a proxy when __getattr__ requests a parameter.
        @functools.wraps(_orig_module_getattr)
        def module_getattr_wrapper(mod, attr):
            attr_val = _orig_module_getattr(mod, attr)
            return self.getattr(attr, attr_val, parameter_proxy_cache)

        @functools.wraps(_orig_module_call)
        def module_call_wrapper(mod, *args, **kwargs):
            def forward(*args, **kwargs):
                return _orig_module_call(mod, *args, **kwargs)

            # `patcher` is bound by the `with` statement below before any
            # module call can reach this wrapper.
            _autowrap_check(
                patcher,  # type: ignore[has-type]
                getattr(getattr(mod, "forward", mod), "__globals__", {}),
                self._autowrap_function_ids,
            )
            return self.call_module(mod, forward, args, kwargs)

        with _new_patcher() as patcher:
            # allow duplicate patches to support the case of nested calls
            patcher.patch_method(
                torch.nn.Module,
                "__getattr__",
                module_getattr_wrapper,
                deduplicate=False,
            )
            patcher.patch_method(
                torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False
            )
            _patch_wrapped_functions(patcher)
            _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
            for module in self._autowrap_search:
                _autowrap_check(
                    patcher, module.__dict__, self._autowrap_function_ids
                )
            # Running fn on the proxy arguments records the graph; its result
            # becomes the single "output" node.
            self.create_node(
                "output",
                "output",
                (self.create_arg(fn(*args)),),
                {},
                type_expr=fn.__annotations__.get("return", None),
            )

        self.submodule_paths = None
    finally:
        _is_fx_tracing_flag = saved_tracing_flag
    return self.graph
832
+
833
def __deepcopy__(self, memo):
    """
    Deep-copy this tracer, preserving its concrete (sub)class.

    Every instance attribute is deep-copied except ``_autowrap_search``,
    which holds module objects that cannot be deep-copied and is therefore
    only shallow-copied.

    Args:
        memo: The memo dictionary supplied by ``copy.deepcopy``.

    Returns:
        A new tracer of the same type as ``self``.
    """
    # Instantiate type(self) rather than the base Tracer so that copying a
    # subclass does not silently downgrade it to a plain Tracer (which would
    # drop every method override while keeping the subclass's attributes).
    cls = type(self)
    new_tracer = cls.__new__(cls)

    for k, v in self.__dict__.items():
        # _autowrap_search contains modules, which cannot be deepcopied.
        if k == "_autowrap_search":
            new_obj = copy.copy(v)
        else:
            new_obj = copy.deepcopy(v, memo)

        new_tracer.__dict__[k] = new_obj

    return new_tracer
846
+
847
def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis):
    """
    Create the placeholder (or specialized value) for the argument ``name``.

    If ``name`` has no entry in ``concrete_args``, a single ``placeholder``
    proxy is created. Otherwise the concrete value is mapped leaf-by-leaf:
    ``PHBase`` leaves become placeholder proxies, all other leaves are kept
    as constants, with a runtime assertion emitted where one can be built.
    """
    if concrete_args is None or name not in concrete_args:
        # Regular argument: one placeholder carrying the signature default.
        if name[0] == "*":
            default = ()
        else:
            param = sig.parameters[name]
            default = () if param.default is inspect.Parameter.empty else (param.default,)  # type: ignore[assignment]
        return self.create_proxy(
            "placeholder",
            name,
            default,
            {},
            type_expr=fn_for_analysis.__annotations__.get(name, None),
        )

    cnt = 0

    def replace_ph(x):
        nonlocal cnt
        cnt += 1
        param = sig.parameters[name]
        if param.default is inspect.Parameter.empty:
            default = ()
        else:
            default = (param.default,)
        # A placeholder is created for every leaf so positions stay aligned,
        # even when the leaf ends up specialized to a constant.
        out = self.create_proxy(
            "placeholder", f"{name}_{str(cnt)}", default, {}
        )
        if isinstance(x, PHBase):
            if x != PH:
                # Transfer attrs in the case where you're using a placeholder other
                # than the singleton PH (PH has no attributes to transfer).
                # Proxies were created out of the placeholders.
                # Transfer any metadata (put on the placeholders in the form of
                # attributes set by the user) from the placeholder to the
                # underlying nodes (the proxy is unwrapped by the user, but
                # the metadata should hold).
                _transfer_attrs(fr=x, to=out.node)

            return out

        leaf_type = type(x)
        # Union[int, bool] == bool in Python <= 3.6
        if leaf_type == bool or (
            leaf_type in base_types and leaf_type != torch.Tensor
        ):
            torch._assert(
                out == x,
                f"{name} has been specialized to have value {x} but got another value",
            )
        elif x is None:
            args = (
                out,
                f"{name} has been specialized to have value None but got another value",
            )
            self.create_proxy("call_function", _assert_is_none, args, {})
        else:
            warnings.warn(
                f"Was not able to add assertion to guarantee correct input {name} to "
                f"specialized function. It is up to the user to make sure that your inputs match the "
                f"inputs you specialized the function with."
            )

        return x

    return pytree.tree_map(replace_ph, concrete_args[name])
913
+
914
+
915
# Registry for the wrap() API: maps (id(globals dict), function name) to the
# globals dict that should be patched. Keying on both the globals id and the
# function name ensures a given function is only wrapped once.
_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {}

# Methods on classes to wrap, stored as (class type, method name).
# This currently only works for Tensor.* methods that aren't traced properly.
_wrapped_methods_to_patch: List[Tuple[type, str]] = []

# Opt-in via environment variable.
# This change is needed to trace models like PositionalEmbedding from BERT:
# https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py
# but causes issues in quantization documented here:
# https://github.com/pytorch/pytorch/issues/50710
# once that is fixed we can make this the default behavior.
if os.environ.get("FX_PATCH_GETITEM") == "1":
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
932
+
933
+
934
def _find_proxy(*objects_to_search):
    """
    Walk the given (possibly nested) arguments and return a ``Proxy`` found
    in them, or ``None`` when there is none. If several proxies are present,
    the last one visited is returned.
    """
    seen = []

    def _record(x):
        if isinstance(x, Proxy):
            seen.append(x)

    map_aggregate(objects_to_search, _record)
    return seen[-1] if seen else None
948
+
949
+
950
def _create_wrapped_func(orig_fn):
    """
    Build a tracing-aware stand-in for the leaf function ``orig_fn``.
    """

    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Look for a Proxy among the args and kwargs. Without one the call is
        not being traced, so ``orig_fn`` simply runs. With one, a
        ``call_function`` node targeting ``orig_fn`` is emitted so the call
        is preserved in the graph instead of being traced through.
        """
        proxy = _find_proxy(args, kwargs)
        if proxy is None:
            return orig_fn(*args, **kwargs)

        recorded = proxy.tracer.create_proxy(
            "call_function", orig_fn, args, kwargs
        )
        # Tag the node so later consumers can tell it came from a wrapped function.
        recorded.node.meta["is_wrapped"] = True
        return recorded

    return wrapped
969
+
970
+
971
def _create_wrapped_method(cls, name):
    """
    Build a tracing-aware stand-in for the method ``cls.<name>``.
    """
    orig_fn = getattr(cls, name)

    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Look for a Proxy among the args and kwargs. Without one the call is
        not being traced, so the original method simply runs. With one, a
        ``call_method`` node named ``name`` is emitted so the call is
        preserved in the graph.
        """
        proxy = _find_proxy(args, kwargs)
        if proxy is None:
            return orig_fn(*args, **kwargs)
        return proxy.tracer.create_proxy("call_method", name, args, kwargs)

    return wrapped
988
+
989
+
990
class _PatchedFn(NamedTuple):
    """
    Record of one patch: where it lives, its name, and the old/new callables.

    Subclasses decide how the patch is applied and undone.
    """

    # Container being patched (a dict for the dict-based subclasses, or an
    # object/class for the setattr-based one).
    frame_dict: Any
    # Key or attribute name under which the function is stored.
    fn_name: str
    # Value that was in place before patching.
    orig_fn: Any
    # Replacement value.
    new_fn: Any

    def revert(self):
        # Abstract: subclasses restore the original state.
        raise NotImplementedError

    def patch(self):
        # Abstract: subclasses install ``new_fn``.
        raise NotImplementedError
1001
+
1002
+
1003
+ class _PatchedFnSetItem(_PatchedFn):
1004
+ def revert(self):
1005
+ self.frame_dict[self.fn_name] = self.orig_fn
1006
+
1007
+ def patch(self):
1008
+ self.frame_dict[self.fn_name] = self.new_fn
1009
+
1010
+ class _PatchedFnDel(_PatchedFn):
1011
+ def revert(self):
1012
+ del self.frame_dict[self.fn_name]
1013
+
1014
+ def patch(self):
1015
+ self.frame_dict[self.fn_name] = self.new_fn
1016
+
1017
+
1018
+ class _PatchedFnSetAttr(_PatchedFn):
1019
+ def revert(self):
1020
+ setattr(self.frame_dict, self.fn_name, self.orig_fn)
1021
+
1022
+ def patch(self):
1023
+ setattr(self.frame_dict, self.fn_name, self.new_fn)
1024
+
1025
+ class _Patcher:
1026
+ def __init__(self) -> None:
1027
+ super().__init__()
1028
+ self.patches_made: List[_PatchedFn] = []
1029
+ self.visited: Set[int] = set()
1030
+
1031
+ def patch(
1032
+ self,
1033
+ frame_dict: Dict[str, Any],
1034
+ name: str,
1035
+ new_fn: Callable,
1036
+ deduplicate: bool = True,
1037
+ ):
1038
+ """
1039
+ Replace frame_dict[name] with new_fn until we exit the context manager.
1040
+ """
1041
+ new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined]
1042
+ if name not in frame_dict and hasattr(builtins, name):
1043
+ self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn))
1044
+ self.patches_made[-1].patch()
1045
+ elif getattr(frame_dict[name], "__fx_already_patched", False):
1046
+ return # already patched, no need to do it again
1047
+ else:
1048
+ self.patches_made.append(
1049
+ _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn)
1050
+ )
1051
+ self.patches_made[-1].patch()
1052
+
1053
+ def patch_method(
1054
+ self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True
1055
+ ):
1056
+ """
1057
+ Replace object_or_dict.name with new_fn until we exit the context manager.
1058
+ """
1059
+ new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined]
1060
+ orig_fn = getattr(cls, name)
1061
+ if getattr(orig_fn, "__fx_already_patched", False):
1062
+ return # already patched, no need to do it again
1063
+ self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn))
1064
+ self.patches_made[-1].patch()
1065
+
1066
+ def visit_once(self, thing: Any):
1067
+ """Return True on the first call to with thing, otherwise false"""
1068
+ idx = id(thing)
1069
+ if idx in self.visited:
1070
+ return False
1071
+ self.visited.add(idx)
1072
+ return True
1073
+
1074
+ def revert_all_patches(self):
1075
+ """
1076
+ Remove all the stored patcheds. It doesn't modify patches_made.
1077
+ """
1078
+ for patch in self.patches_made:
1079
+ patch.revert()
1080
+ return self.patches_made
1081
+
1082
+ def reapply_all_patches(self):
1083
+ """
1084
+ Patch all the stored patcheds. It doesn't modify patches_made.
1085
+ """
1086
+ for patch in self.patches_made:
1087
+ patch.patch()
1088
+ return self.patches_made
1089
+
1090
+ def __enter__(self):
1091
+ return self
1092
+
1093
+ def __exit__(self, exc_type, exc_val, exc_tb):
1094
+ """
1095
+ Undo all the changes made via self.patch() and self.patch_method()
1096
+ """
1097
+ while self.patches_made:
1098
+ # unpatch in reverse order to handle duplicates correctly
1099
+ self.patches_made.pop().revert()
1100
+ self.visited.clear()
1101
+
1102
+
1103
+ CURRENT_PATCHER: Optional[_Patcher] = None
1104
+
1105
+ @contextlib.contextmanager
1106
+ def _new_patcher():
1107
+ global CURRENT_PATCHER
1108
+ prior_patcher = CURRENT_PATCHER
1109
+ try:
1110
+ CURRENT_PATCHER = _Patcher()
1111
+ yield CURRENT_PATCHER
1112
+ finally:
1113
+ # Clear all the patches made by when using current patcher.
1114
+ assert CURRENT_PATCHER is not None
1115
+ CURRENT_PATCHER.revert_all_patches()
1116
+ CURRENT_PATCHER = prior_patcher
1117
+
1118
+
1119
@contextlib.contextmanager
def _maybe_revert_all_patches():
    """
    Temporarily undo the patches of the patcher that is current on entry.

    If there is no current patcher this is a no-op. Otherwise its patches are
    reverted for the duration of the ``with`` block and re-applied afterwards,
    even if the block raises.
    """
    patcher = CURRENT_PATCHER
    if patcher is None:
        yield
        return

    removed = None
    try:
        removed = patcher.revert_all_patches()
        yield
    finally:
        restored = patcher.reapply_all_patches()
        assert restored == removed, "CURRENT_PATCHER was changed during a revert_all_patches"
1132
+
1133
def _patch_wrapped_functions(patcher: _Patcher):
    """
    Go through ``_wrapped_fns_to_patch`` and, for each registered globals
    dict, wrap the listed global function in the `_create_wrapped_func`
    wrapper. Then do the same for every ``(class, method name)`` pair in
    ``_wrapped_methods_to_patch`` using `_create_wrapped_method`.
    """
    # Iterate over a snapshot of the registry.
    for (_, name), frame_dict in list(_wrapped_fns_to_patch.items()):
        # A name that is absent from the globals but exists in builtins
        # refers to the builtin of that name.
        refers_to_builtin = name not in frame_dict and hasattr(builtins, name)
        orig_fn = getattr(builtins, name) if refers_to_builtin else frame_dict[name]
        patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))

    for cls, name in _wrapped_methods_to_patch:
        patcher.patch_method(cls, name, _create_wrapped_method(cls, name))
1147
+
1148
+
1149
def _autowrap_check(
    patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int]
):
    """
    Some functions, like `math.sqrt`, are common enough that we want to wrap
    them automatically as we see them. Scan ``frame_dict`` (at most once per
    patcher) and patch every public callable whose ``id`` is in
    ``function_ids``.
    """
    if not patcher.visit_once(frame_dict):
        return  # this scope was already scanned by this patcher

    for name, value in frame_dict.items():
        if name.startswith("_"):
            continue
        if callable(value) and id(value) in function_ids:
            patcher.patch(frame_dict, name, _create_wrapped_func(value))
1164
+
1165
+
1166
+ @compatibility(is_backward_compatible=True)
1167
+ def wrap(fn_or_name: Union[str, Callable]):
1168
+ """
1169
+ This function can be called at module-level scope to register fn_or_name as a "leaf function".
1170
+ A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being
1171
+ traced through::
1172
+
1173
+ # foo/bar/baz.py
1174
+ def my_custom_function(x, y):
1175
+ return x * x + y * y
1176
+
1177
+ torch.fx.wrap('my_custom_function')
1178
+
1179
+ def fn_to_be_traced(x, y):
1180
+ # When symbolic tracing, the below call to my_custom_function will be inserted into
1181
+ # the graph rather than tracing it.
1182
+ return my_custom_function(x, y)
1183
+
1184
+ This function can also equivalently be used as a decorator::
1185
+
1186
+ # foo/bar/baz.py
1187
+ @torch.fx.wrap
1188
+ def my_custom_function(x, y):
1189
+ return x * x + y * y
1190
+
1191
+ A wrapped function can be thought of a "leaf function", analogous to the concept of
1192
+ "leaf modules", that is, they are functions that are left as calls in the FX trace
1193
+ rather than traced through.
1194
+
1195
+ Args:
1196
+
1197
+ fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
1198
+ graph when it's called
1199
+ """
1200
+ if not callable(fn_or_name) and not isinstance(fn_or_name, str):
1201
+ raise RuntimeError(
1202
+ "Unsupported type for global function! Must be either a callable or "
1203
+ "string name"
1204
+ )
1205
+
1206
+ if callable(fn_or_name):
1207
+ assert not isinstance(fn_or_name, str) # to make mypy happy
1208
+ fn_name = fn_or_name.__name__
1209
+ else:
1210
+ assert isinstance(
1211
+ fn_or_name, str
1212
+ ), "fn_or_name must be a global function or string name"
1213
+ fn_name = fn_or_name
1214
+
1215
+ currentframe = inspect.currentframe()
1216
+ assert currentframe is not None
1217
+ f = currentframe.f_back
1218
+ assert f is not None
1219
+ if f.f_code.co_name != "<module>":
1220
+ raise NotImplementedError("wrap must be called at the top level of a module")
1221
+
1222
+ # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search
1223
+ # semantics would be slightly different, but would add support `from x import wrapped_function`
1224
+ _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals
1225
+ return fn_or_name
1226
+
1227
+
1228
+ @compatibility(is_backward_compatible=True)
1229
+ def symbolic_trace(
1230
+ root: Union[torch.nn.Module, Callable[..., Any]],
1231
+ concrete_args: Optional[Dict[str, Any]] = None,
1232
+ ) -> GraphModule:
1233
+ """
1234
+ Symbolic tracing API
1235
+
1236
+ Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
1237
+ constructed by recording operations seen while tracing through ``root``.
1238
+
1239
+ ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures.
1240
+
1241
+ For example::
1242
+
1243
+ def f(a, b):
1244
+ if b == True:
1245
+ return a
1246
+ else:
1247
+ return a*2
1248
+
1249
+ FX can typically not trace through this due to the presence of control
1250
+ flow. However, we can use `concrete_args` to specialize on the value of
1251
+ `b` to trace through this::
1252
+
1253
+ f = fx.symbolic_trace(f, concrete_args={'b': False})
1254
+ assert f(3, False) == 6
1255
+
1256
+ Note that although you can still pass in different values of `b`, they will be ignored.
1257
+
1258
+ We can also use `concrete_args` to eliminate data-structure handling from
1259
+ our function. This will use pytrees to flatten your input. To avoid
1260
+ overspecializing, pass in `fx.PH` for values that shouldn't be
1261
+ specialized. For example::
1262
+
1263
+ def f(x):
1264
+ out = 0
1265
+ for v in x.values():
1266
+ out += v
1267
+ return out
1268
+ f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
1269
+ assert f({'a': 1, 'b': 2, 'c': 4}) == 7
1270
+
1271
+
1272
+ Args:
1273
+ root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
1274
+ into a Graph representation.
1275
+ concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized
1276
+
1277
+ Returns:
1278
+ GraphModule: a Module created from the recorded operations from ``root``.
1279
+ """
1280
+ tracer = Tracer()
1281
+ graph = tracer.trace(root, concrete_args)
1282
+ name = (
1283
+ root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
1284
+ )
1285
+ return _make_graph_module(tracer.root, graph, name)
1286
+
1287
+
1288
@wrap
def _assert_is_none(value, msg):
    """Assert that ``value`` is ``None``, failing with ``msg`` otherwise."""
    # Registered as a leaf function via @wrap so that, during tracing, the
    # check is recorded as a node in the graph rather than evaluated eagerly.
    assert value is None, msg
.venv/lib/python3.11/site-packages/torch/fx/_utils.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import sys
3
+ from typing import Dict, Optional
4
+
5
+ import torch
6
+ from torch._logging import LazyString
7
+
8
+
9
+ def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
10
+ """
11
+ Returns a LazyString that formats the graph code.
12
+ """
13
+
14
+ def format_name():
15
+ if maybe_id is not None:
16
+ return f"{name} {maybe_id}"
17
+ else:
18
+ return name
19
+
20
+ if "print_output" not in kwargs:
21
+ kwargs["print_output"] = False
22
+
23
+ if "colored" in kwargs and not sys.stdout.isatty():
24
+ kwargs["colored"] = False
25
+
26
+ return LazyString(
27
+ lambda: _format_graph_code(
28
+ f"===== {format_name()} =====\n",
29
+ gm.forward.__code__.co_filename,
30
+ gm.print_readable(**kwargs),
31
+ )
32
+ )
33
+
34
+
35
def _format_graph_code(name, filename, graph_str):
    """
    Assemble the "TRACED GRAPH" banner, followed by the name, filename and
    graph text separated by single spaces, into one string.
    """
    return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"
40
+
41
+
42
def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
    """
    Return the ``nn_module_stack`` metadata of the first ``call_function``
    node that has one, or ``None`` if no such node exists.
    """
    candidates = (
        node.meta["nn_module_stack"]
        for node in graph.nodes
        if node.op == "call_function" and "nn_module_stack" in node.meta
    )
    return next(candidates, None)
50
+
51
+
52
def get_node_context(node, num_nodes=2) -> str:
    """
    Format ``node`` and up to ``num_nodes - 1`` of its predecessors (followed
    through ``prev``), returning them oldest-first, one per line. The walk
    stops early when a node with op ``"root"`` is reached.
    """
    newest_first = []
    current = node
    for _ in range(num_nodes):
        newest_first.append(current.format_node())
        if current.op == "root":
            break
        current = current.prev
    return "\n".join(reversed(newest_first))
.venv/lib/python3.11/site-packages/torch/fx/annotate.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from torch.fx.proxy import Proxy
3
+ from ._compatibility import compatibility
4
+
5
@compatibility(is_backward_compatible=False)
def annotate(val, type):
    """
    Annotates a Proxy object with a given type.

    When ``val`` is a torch.fx.Proxy, ``type`` is stored on its underlying
    node; any other value is passed through untouched.

    Args:
        val (object): An object to be annotated if its type is torch.fx.Proxy.
        type (object): A type to be assigned to a given proxy object as val.
    Returns:
        The given val.
    Raises:
        RuntimeError: If a val already has a type in its node.
    """
    if not isinstance(val, Proxy):
        return val

    node = val.node
    if node.type:
        raise RuntimeError(f"Tried to annotate a value that already had a type on it!"
                           f" Existing type is {val.node.type} "
                           f"and new type is {type}. "
                           f"This could happen if you tried to annotate a function parameter "
                           f"value (in which case you should use the type slot "
                           f"on the function signature) or you called "
                           f"annotate on the same value twice")
    node.type = type
    return val
.venv/lib/python3.11/site-packages/torch/fx/config.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Whether to disable showing progress on compilation passes
2
+ # A separate config is needed here; otherwise we will get a circular import if the dynamo config is imported here
3
+ disable_progress = True
4
+
5
+ # If True this also shows the node names in each pass, for small models this is great but larger models it's quite noisy
6
+ verbose_progress = False
.venv/lib/python3.11/site-packages/torch/fx/experimental/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-311.pyc ADDED
Binary file (1.47 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/_config.cpython-311.pyc ADDED
Binary file (1.95 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-311.pyc ADDED
Binary file (47.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-311.pyc ADDED
Binary file (13 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/debug.cpython-311.pyc ADDED
Binary file (1.69 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-311.pyc ADDED
Binary file (49.1 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-311.pyc ADDED
Binary file (7.27 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-311.pyc ADDED
Binary file (16.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-311.pyc ADDED
Binary file (8.32 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-311.pyc ADDED
Binary file (26.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-311.pyc ADDED
Binary file (13.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/recording.cpython-311.pyc ADDED
Binary file (18.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-311.pyc ADDED
Binary file (1.28 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-311.pyc ADDED
Binary file (8.31 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-311.pyc ADDED
Binary file (6.99 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-311.pyc ADDED
Binary file (61.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-311.pyc ADDED
Binary file (5.06 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/validator.cpython-311.pyc ADDED
Binary file (41.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/_backward_state.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.fx
2
+
3
+
4
class BackwardState:
    """
    Carrier used to hand Python hooks from the forwards pass to the backwards
    pass in Dynamo+Compiled Autograd.

    TorchDynamo creates it and treats it specially: an empty BackwardState is
    passed to the forwards, and its members are populated (via setattr) only
    once the forwards graph is finished. Later, CompiledAutograd inlines it
    and adds the needed guards on the BackwardState.

    AOTAutograd also identifies BackwardState and handles it specially.
    During AOTAutograd:
      1) BackwardState is an input to the forwards graph
      2) It must only be used in the backwards
      3) It will be empty in the forwards
      4) In the forwards we add a wrapper to save it
      5) In the backwards it becomes an input
      6) There can only be one per graph

    BackwardState requires CompiledAutograd.
    """

    # Declared only as an annotation; no class-level default is provided.
    proxy: torch.fx.Proxy
.venv/lib/python3.11/site-packages/torch/fx/experimental/_config.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from typing import Optional
4
+
5
+
6
+ # [@compile_ignored: debug] Uses z3 to validate the guard optimization transformations.
7
+ translation_validation = (
8
+ os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1"
9
+ )
10
+ # Timeout (in milliseconds) for z3 finding a solution.
11
+ # [@compile_ignored: debug]
12
+ translation_validation_timeout = int(
13
+ os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000")
14
+ )
15
+ # Disables bisection for translation validation.
16
+ #
17
+ # Translation validation bisection is enabled by default, if translation validation
18
+ # is also enabled. This should help find guard simplification issues. However,
19
+ # since validation uses Z3 for bisecting, it might take a lot of time.
20
+ #
21
+ # Set this configuration option so as to avoid bisecting.
22
+ # [@compile_ignored: debug]
23
+ translation_validation_no_bisect = (
24
+ os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1"
25
+ )
26
+ # Checks whether replaying ShapeEnv events on a freshly constructed one yields
27
+ # a ShapeEnv with the same state. This should be used only in testing.
28
+ check_shape_env_recorded_events = False
29
+
30
+ # TODO: Perhaps consider allowing unions for the configs below (so you can hit
31
+ # multiple reps at the same time)
32
+
33
+ # Give extended debug information if the string representation of a guard
34
+ # matches this. For example, set this to "Ne(s0, 10)" and whenever we issue
35
+ # this guard, we will generate full Python and C++ backtrace
36
+ # [@compile_ignored: debug]
37
+ extended_debug_guard_added = os.environ.get(
38
+ "TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None
39
+ )
40
+
41
+ # Give extended debug information when a particular symbol is allocated. For
42
+ # example, set this to "u2" and whenever we create this symbol, we will
43
+ # generate full Python and C++ backtrace
44
+ # [@compile_ignored: debug]
45
+ extended_debug_create_symbol = os.environ.get(
46
+ "TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None
47
+ )
48
+
49
+ # Give extended debug information (C++ backtrace) for all extended debug
50
+ # settings as well as errors. The C++ backtrace is slow and very spammy so we
51
+ # don't include it by default even when you're requesting extended debug.
52
+ # [@compile_ignored: debug]
53
+ extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != ""
54
+
55
+ # Give extended debug information (line of code) when a torch function
56
+ # is called during export. This is useful for showing progress and detecting
57
+ # where export might be stuck. Currently only works for strict=False.
58
+ # [@compile_ignored: debug]
59
+ extended_debug_current_loc = (
60
+ os.environ.get("TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC", "0") == "1"
61
+ )
62
+
63
+ # [@compile_ignored: debug] Show a warning for every specialization
64
+ print_specializations = False
65
+
66
+ # wraps (un)equalities with 'Not' class after recording the correct expression
67
+ # in the FX graph. This should incorrectly construct the divisible and replacement
68
+ # lists, and incorrectly issue guards.
69
+ inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False
70
+
71
+ # [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly
72
+ validate_shape_env_version_key = False
73
+
74
+ # If we produce more than this many guards on a symbol, force the symbol to
75
+ # get specialized and bail out if this many guards mention this particular
76
+ # symbol. This may be slightly more aggressive than the true number of guards
77
+ # issued (as we test if we've hit the limit on-the-fly, whereas we may
78
+ # do further simplifications at final guard issuance time that make guards
79
+ # irrelevant.)
80
+ symbol_guard_limit_before_specialize: Optional[int] = None
81
+
82
+ # This flag changes whether we should use the same symbolic variable to represent input sizes that are the same.
83
+ use_duck_shape = True
84
+
85
+ from torch.utils._config_module import install_config_module
86
+
87
+
88
+ install_config_module(sys.modules[__name__])