koichi12 commited on
Commit
5dbc224
·
verified ·
1 Parent(s): 9a8eae1

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 +3 -0
  3. .venv/lib/python3.11/site-packages/torch/_export/converter.py +1584 -0
  4. .venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py +523 -0
  5. .venv/lib/python3.11/site-packages/torch/_export/pass_base.py +441 -0
  6. .venv/lib/python3.11/site-packages/torch/_export/tools.py +146 -0
  7. .venv/lib/python3.11/site-packages/torch/_export/verifier.py +456 -0
  8. .venv/lib/python3.11/site-packages/torch/_export/wrappers.py +121 -0
  9. .venv/lib/python3.11/site-packages/torch/_lazy/__init__.py +55 -0
  10. .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/torch/_lazy/computation.py +27 -0
  13. .venv/lib/python3.11/site-packages/torch/_lazy/config.py +17 -0
  14. .venv/lib/python3.11/site-packages/torch/_lazy/debug.py +22 -0
  15. .venv/lib/python3.11/site-packages/torch/_lazy/device_context.py +26 -0
  16. .venv/lib/python3.11/site-packages/torch/_lazy/extract_compiled_graph.py +225 -0
  17. .venv/lib/python3.11/site-packages/torch/_lazy/metrics.py +22 -0
  18. .venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py +7 -0
  19. .venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py +35 -0
  20. .venv/lib/python3.11/site-packages/torch/multiprocessing/pool.py +52 -0
  21. .venv/lib/python3.11/site-packages/torch/multiprocessing/queue.py +43 -0
  22. .venv/lib/python3.11/site-packages/torch/multiprocessing/reductions.py +647 -0
  23. .venv/lib/python3.11/site-packages/torch/multiprocessing/spawn.py +328 -0
  24. .venv/lib/python3.11/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__init__.py +9 -0
  26. .venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  31. .venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc +0 -0
  32. .venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/torch/nn/utils/__init__.py +39 -0
  34. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/fusion.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/init.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/prune.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/rnn.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/stateless.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/torch/nn/utils/_deprecation_utils.py +54 -0
.gitattributes CHANGED
@@ -123,3 +123,5 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_
123
  .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
124
  .venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text
125
  .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text
 
 
 
123
  .venv/lib/python3.11/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12 filter=lfs diff=lfs merge=lfs -text
124
  .venv/lib/python3.11/site-packages/nvidia/cublas/lib/libnvblas.so.12 filter=lfs diff=lfs merge=lfs -text
125
  .venv/lib/python3.11/site-packages/opencv_python_headless.libs/libopenblas-r0-f650aae0.3.3.so filter=lfs diff=lfs merge=lfs -text
126
+ .venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 filter=lfs diff=lfs merge=lfs -text
127
+ .venv/lib/python3.11/site-packages/vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/libcudnn_heuristic.so.9 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94fab98c15040558c3c80f2c1a2f5fda9baa72afc39a88bdcc82185f49d241c3
3
+ size 86326864
.venv/lib/python3.11/site-packages/torch/_export/converter.py ADDED
@@ -0,0 +1,1584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import builtins
3
+ import logging
4
+ import operator
5
+ import typing
6
+ import warnings
7
+ from contextlib import contextmanager
8
+ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
9
+
10
+ import torch
11
+ import torch.export._trace
12
+ from torch import _C
13
+ from torch._export.passes.replace_quantized_ops_with_standard_ops_pass import (
14
+ replace_quantized_ops_with_standard_ops,
15
+ )
16
+ from torch.export.exported_program import ExportedProgram
17
+ from torch.export.graph_signature import (
18
+ ConstantArgument,
19
+ CustomObjArgument,
20
+ InputKind,
21
+ InputSpec,
22
+ OutputKind,
23
+ OutputSpec,
24
+ TensorArgument,
25
+ )
26
+ from torch.fx import subgraph_rewriter
27
+
28
+
29
+ log = logging.getLogger(__name__)
30
+
31
+
32
+ def _get_param_count_list(method_graph, args_params):
33
+ param_count_list = []
34
+ for input_, arg_params_ in zip(method_graph.inputs(), args_params):
35
+ if "PackedParams" in str(input_.type()):
36
+ in_vars, _ = torch.jit._flatten(arg_params_)
37
+ param_count_list.append(len(in_vars))
38
+ else:
39
+ param_count_list.append(arg_params_ is not None)
40
+
41
+ return param_count_list
42
+
43
+
44
+ def _trace_and_get_graph_from_model(model, args):
45
+ # A basic sanity check: make sure the state_dict keys are the same
46
+ # before and after running the model. Fail fast!
47
+ orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()
48
+
49
+ # Disable Autocast cache because it replaces kernel's weight and bias
50
+ # by (undesired) constants.
51
+ # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665
52
+ prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
53
+ torch.set_autocast_cache_enabled(False)
54
+ trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
55
+ model,
56
+ args,
57
+ strict=False,
58
+ _force_outplace=False,
59
+ _return_inputs_states=True,
60
+ )
61
+ torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
62
+
63
+ if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
64
+ raise RuntimeError(
65
+ "state_dict changed after running the tracer; "
66
+ "something weird is happening in your model!"
67
+ )
68
+
69
+ return trace_graph, torch_out
70
+
71
+
72
def _create_jit_graph(
    model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any]
) -> Tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]:
    """Produce a TorchScript graph for ``model`` with input shapes propagated.

    Returns ``(graph, params, torch_out, module)``; ``module`` is only
    populated on the ScriptModule path, ``torch_out`` only on the tracing
    path for plain ``nn.Module``s.
    """
    if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
        flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
        torch_out = None

        if isinstance(model, torch.jit.ScriptModule):
            try:
                graph = model.forward.graph  # type: ignore[attr-defined]
            except AttributeError as e:
                raise RuntimeError("'forward' method must be a script method") from e
            _C._jit_pass_onnx_function_substitution(graph)
            # Freeze to fold submodule state, then re-list parameters so they
            # can be appended to the user args for shape propagation.
            freezed_module = _C._freeze_module(
                typing.cast(_C.ScriptModule, model._c), preserveParameters=True
            )
            module, params = _C._jit_onnx_list_model_parameters(freezed_module)
            method_graph = module._get_method("forward").graph
            args_params = tuple(args) + tuple(params)
            param_count_list = _get_param_count_list(method_graph, args_params)
            in_vars, _ = torch.jit._flatten(args_params)
            graph = _C._propagate_and_assign_input_shapes(
                method_graph, tuple(in_vars), param_count_list, False, False
            )
            return graph, params, torch_out, module

        # torch.jit.ScriptFunction
        params = []
        graph = model.graph
        _C._jit_pass_onnx_function_substitution(graph)
        param_count_list = _get_param_count_list(graph, args)
        graph = _C._propagate_and_assign_input_shapes(
            graph, flattened_args, param_count_list, False, False
        )
        return graph, params, torch_out, None

    # Plain nn.Module: trace it, then rename the trailing graph inputs to the
    # state_dict keys they correspond to (tracer appends params after user inputs).
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
    _C._jit_pass_onnx_lint(graph)
    state_dict = torch.jit._unique_state_dict(model)
    params = list(state_dict.values())
    graph_inputs = list(graph.inputs())
    user_input_num = len(graph_inputs) - len(state_dict)
    param_names = list(state_dict.keys())
    for i, inp in enumerate(graph_inputs):
        if i >= user_input_num:
            inp.setDebugName(param_names[i - user_input_num])
    _C._jit_pass_onnx_function_substitution(graph)
    return graph, params, torch_out, None
120
+
121
+
122
def list_add(a, b):
    """Return ``a + b`` (models TorchScript list concatenation)."""
    combined = a + b
    return combined
124
+
125
+
126
def list_append(container, element):
    """Return a new list equal to ``container`` with ``element`` appended.

    Out-of-place: the input container is never mutated.
    """
    appended = container + [element]
    return appended
128
+
129
+
130
def execute_subgraph_from_prim_loop(
    subgraph, iter_idx, len_loop_local_arguments, *args, **kwargs
):
    """Run one iteration of a converted prim::Loop body.

    Args:
        subgraph: GraphModule converted from the loop's sub-block.
        iter_idx: The index of the current iteration.
        len_loop_local_arguments: How many leading entries of ``args`` are
            loop-local (carried) values; the remainder are captured globals.
    """
    # Loop-local variables come first: the TS graph models them as block
    # inputs because they are rebound on every iteration.
    loop_locals = args[:len_loop_local_arguments]
    # Captured globals are used directly by the sub-block; they are usually
    # read-only, except when an operation mutates them in place.
    captured = args[len_loop_local_arguments:]
    return subgraph(*captured, iter_idx, *loop_locals, **kwargs)
148
+
149
+
150
def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
    """Rewrite ``Int(div(scalar_tensor(sym_size(im, dim)), scale, trunc))``
    patterns in ``gm`` (in place) into the direct symbolic ``sym_size // scale``.
    """

    def pattern(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
        scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int)
        div_scalar_mode = torch.ops.aten.div.Scalar_mode(
            scalar_tensor, scale, rounding_mode="trunc"
        )
        int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode)
        return int_tensor

    def replacement(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
        return sym_size_int // scale

    # The list of matches is not needed; the rewrite happens in place.
    subgraph_rewriter.replace_pattern(gm, pattern, replacement)
165
+
166
+
167
def is_valid_for_codegen(name):
    """Return True when ``name`` may appear as a generated identifier.

    Names starting with a digit are invalid; an empty name is an error.
    """
    if not name:
        raise RuntimeError("Empty argument name for codegen")
    return not name[0].isdigit()
173
+
174
+
175
def normalize_name(name: str, prefix: str = "rename") -> str:
    """Replace dots with underscores; prepend ``prefix`` when the result
    would not be valid for codegen (i.e. it starts with a digit)."""
    sanitized = name.replace(".", "_")
    if len(sanitized) == 0:
        raise RuntimeError("Empty argument name for codegen")
    if sanitized[0].isdigit():
        return f"{prefix}_{sanitized}"
    return sanitized
180
+
181
+
182
def ir_name_to_func_name(name: str) -> str:
    """Map an IR op kind to a handler name: prim::If -> convert_prim_If."""
    return "convert_" + name.replace("::", "_")
186
+
187
+
188
def get_node_as_placeholder_or_get_attr(fx_graph, name, is_top_level_graph):
    """Top-level graphs reference attributes via get_attr; sub-blocks
    receive them as placeholders (lifted inputs)."""
    factory = fx_graph.get_attr if is_top_level_graph else fx_graph.placeholder
    return factory(name)
192
+
193
+
194
# Mapping from torch dtype to its integer code as used by prim::dtype.
# NOTE(review): value 14 is skipped here — presumably a dtype this converter
# does not support; confirm against c10's ScalarType enum.
_TORCH_DTYPE_TO_ENUM = {
    torch.uint8: 0,
    torch.int8: 1,
    torch.int16: 2,
    torch.int32: 3,
    torch.int64: 4,
    torch.float16: 5,
    torch.float32: 6,
    torch.float64: 7,
    torch.complex32: 8,
    torch.complex64: 9,
    torch.complex128: 10,
    torch.bool: 11,
    torch.qint8: 12,
    torch.quint8: 13,
    torch.bfloat16: 15,
}

# Inverse mapping: integer dtype code -> torch dtype.
_TORCH_ENUM_TO_DTYPE = {value: key for key, value in _TORCH_DTYPE_TO_ENUM.items()}
213
+
214
+
215
def get_dtype_as_int(tensor):
    """Implement prim::dtype ("(Tensor a) -> int"): return the integer code
    (per the ScalarType enum) for ``tensor``'s dtype."""
    dtype = tensor.dtype
    code = _TORCH_DTYPE_TO_ENUM.get(dtype)
    if code is None:
        raise RuntimeError(f"Unsupported dtype {dtype}")
    return code
225
+
226
+
227
# Those operators will be automatically populated to an instance method
# of TS2FXGraphConverter with name convert_<namespace>_<opname>().
# Please check __init__ for the method population implementation.
kind_to_standard_operators = {
    "prim::max": builtins.max,
    "prim::min": builtins.min,
    "prim::TupleIndex": operator.getitem,
    "aten::__is__": operator.is_,
    "aten::__isnot__": operator.is_not,
    "aten::__not__": operator.not_,
    "aten::__contains__": operator.contains,
    "prim::dtype": get_dtype_as_int,
    "aten::len": len,
    # Mapping from specialized op to its symbolic counterpart.
    # They currently do not have any other overrides.
    "aten::numel": torch.ops.aten.sym_numel,
    "aten::size": torch.ops.aten.sym_size,
    "aten::storage_offset": torch.ops.aten.sym_storage_offset,
    "aten::stride": torch.ops.aten.sym_stride,
}
247
+
248
+
249
def get_ir_value_parent_name_and_attr_name(node):
    """Unpack a prim::GetAttr node into (output name, input name, attr name)."""
    irv_name = node.output().debugName()
    irv_parent_name = node.input().debugName()
    attr_name = node.s("name")
    return irv_name, irv_parent_name, attr_name
253
+
254
+
255
def construct_fqn(ir, ref_map, name_map):
    """Follow the GetAttr reference chain from ``ir`` up to the root and
    join the attribute names into a dotted fully-qualified name."""
    parts = []
    current = ir
    while current in ref_map:
        parts.insert(0, name_map[current])
        current = ref_map[current]
    return ".".join(parts)
261
+
262
+
263
def get_block_to_lifted_attrs(graph: torch._C.Graph) -> Dict[torch._C.Block, Set[str]]:
    """
    Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes.
    When a graph has control flow, the graph will be divided into multiple blocks. We want to convert
    each block to a graph which will be passed into torch.cond. A restriction for torch.cond is that model
    parameters/buffers are expected to be lifted as inputs to the subgraphs. Before converting the model,
    we will run this pass which will:
        1. Figure out which params/buffers are used within blocks through tracing the GetAttr calls.
        2. Process the graph bottom up to find the lifted attributes of each block by taking the union
        of the attributes used in the current block, and the lifted attributes of all its child blocks.

    Returns:
        A mapping of blocks to a set of FQNs of its lifted attributes.
    """

    # A map from a block to its expected to be lifted arguments.
    blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = {}

    # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a
    # GetAttr node. By traversing this reference map, we can figure out the
    # full IR aliasing pass and figure out the FQN of an attribute.
    # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1"
    node_to_parent_map: Dict[str, str] = {}

    # Used for reconstructing the FQN of an attribute based on the reference map.
    # In nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR
    # This name map stores which attribute name is called for a src IR --> dest IR action.
    # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear"
    node_to_attr_name: Dict[str, str] = {}

    def _dfs_get_attr_dependency(entry):
        """
        First DFS path to construct reference map and name map.
        """
        for node in entry.nodes():
            if node.kind() == "prim::GetAttr":
                (
                    irv_name,
                    irv_parent_name,
                    attr_name,
                ) = get_ir_value_parent_name_and_attr_name(node)
                node_to_parent_map[irv_name] = irv_parent_name
                node_to_attr_name[irv_name] = attr_name
            # Recurse into control-flow sub-blocks.
            for block in node.blocks():
                _dfs_get_attr_dependency(block)

    def _map_blocks_to_lifted_attrs(entry):
        """
        Walk the graph in a bottom-up fashion to build the expected to be
        lifted arguments for each block.
        """
        arguments: Set[str] = set()
        for node in entry.nodes():
            for block in node.blocks():
                # Recursively build.
                arguments = arguments.union(_map_blocks_to_lifted_attrs(block))
            if node.kind() == "prim::GetAttr":
                irv_name = node.output().debugName()
                # Skip for intermediate GetAttr, which will anyway not result a FQN.
                # E.g., node_to_parent_name: {"%3": "%2", "%2": "%1"}
                #       node_to_attr_name: {"%3": "weight", "%2": "linear", "%1": "self"}
                # There is only one FQN %3-->%2-->%1: self.linear.weight
                # %2-->%1 is not a FQN: self.linear
                if irv_name not in set(node_to_parent_map.values()):
                    arguments.add(
                        construct_fqn(irv_name, node_to_parent_map, node_to_attr_name)
                    )
        if not isinstance(entry, torch._C.Graph):  # Skip the top level.
            blocks_to_lifted_attrs[entry] = arguments
        return arguments

    _dfs_get_attr_dependency(graph)
    _map_blocks_to_lifted_attrs(graph)

    return blocks_to_lifted_attrs
338
+
339
+
340
def get_attribute_fqn_from_ts_node(
    name_to_attribute_fqn: Dict[str, str], node: torch._C.Node
) -> str:
    """Build the dotted FQN for a prim::GetAttr / prim::SetAttr node by
    joining the already-resolved root FQN with this node's attribute name."""
    kind = node.kind()
    if kind == "prim::SetAttr":
        input_name = next(node.inputs()).debugName()
    elif kind == "prim::GetAttr":
        input_name = node.input().debugName()
    else:
        raise RuntimeError(
            f"Unexpected node kind when getting attribute fqn. node: {node} "
        )

    attr_name = node.s("name")
    if input_name not in name_to_attribute_fqn:
        raise ValueError(f"Attribute {input_name} not found")
    root_attr_name = name_to_attribute_fqn[input_name]
    # An empty root means the attribute hangs directly off the module root.
    return f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
363
+
364
+
365
def get_op_overload(node: torch._C.Node):
    """Resolve a TS node's schema to the concrete ``torch.ops`` OpOverload."""
    schema_str = node.schema()
    assert schema_str != "(no schema)", f"got empty schema for {node}"
    schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str)
    namespace, op_name = str(schema.name).split("::")
    overload_name = schema.overload_name

    try:
        packet = getattr(getattr(torch.ops, namespace), op_name)
        # An empty overload name selects the default overload.
        op_overload = (
            getattr(packet, overload_name) if overload_name else packet.default
        )
    except Exception as e:
        raise RuntimeError(
            f"Unable to find operator {node.kind()} with schema {node.schema()}"
        ) from e

    return op_overload
385
+
386
+
387
+ class TS2FXGraphConverter:
388
    def __init__(
        self,
        ts_graph: Union[torch._C.Graph, torch._C.Block],
        name_to_param: Dict[str, torch.Tensor],
        name_to_buffer: Dict[str, torch.Tensor],
        blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
        name_to_non_tensor_attribute: Dict[str, Any],
        name_to_constant: Dict[str, Any],
    ):
        """Set up converter state for translating one TS graph/block to FX."""
        self.ts_graph = ts_graph
        # Parameters/buffers of the original module, keyed by FQN.
        self.name_to_param = name_to_param
        self.name_to_buffer = name_to_buffer

        self.fx_graph: torch.fx.Graph = torch.fx.Graph()
        self.input_specs: List[InputSpec] = []
        self.output_specs: List[OutputSpec] = []

        # Maps a TS value debug name to the FX node(s) produced for it.
        self.name_to_node: Dict[
            str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
        ] = {}
        self.name_to_constant: Dict[str, Any] = name_to_constant

        # Mapping from torchscript node output name to attribute fully qualified name
        self.name_to_attribute_fqn: Dict[str, str] = {}

        # Mapping from fully qualified name to real values or a fx graph node
        # During convert, this represents the current value of a non-tensor attribute
        # One use case is:
        # def forward(self, x):
        #     c1 = self.count
        #     self.count += 1
        #     c2 = self.count
        #     return x + c1 + c2
        self.name_to_non_tensor_attribute_node: Dict[str, Any] = {}

        # Mapping from fully qualified name to initial real values inputs
        # We separate it from self.name_to_non_tensor_attribute_node since
        # we need initial real value input when we construct fx.GraphModule
        self.name_to_non_tensor_attribute: Dict[str, Any] = name_to_non_tensor_attribute

        self.subgraphs: Dict[str, torch.fx.GraphModule] = {}

        self.blocks_to_lifted_attrs = blocks_to_lifted_attrs

        # Populate methods for the standard operators.
        for k in kind_to_standard_operators.keys():
            handler_func_name = ir_name_to_func_name(k)
            # Create an indirect function call:
            # convert_<namespace>_<opname> --> lambda node: _convert_standard_operator(node)
            # (all handlers share one dispatcher; the op is re-derived from the node)
            setattr(
                self,
                handler_func_name,
                lambda node: self._convert_standard_operators(node),
            )

        # This stores a list of return results that do not appear in the original TS
        # graph's outputs. The reason we maintain this is because some operations in the sub-block
        # might have inplace updates to the variable defined in the parent fx graph. After
        # the execution of that sub-block, the variable defined in the parent fx graph also
        # needs to be updated.
        self.name_update_from_subblock_to_parent: Set[str] = set()
449
+
450
    def _is_get_attr_node(self, fqn):
        """True when ``fqn`` names a param, buffer, or ScriptObject constant —
        i.e. something that becomes a get_attr node in the FX graph."""
        return (
            fqn in self.name_to_buffer
            or fqn in self.name_to_param
            or (
                fqn in self.name_to_constant
                and isinstance(self.name_to_constant[fqn], torch.ScriptObject)
            )
        )
459
+
460
    def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]):
        """Convert each sub-block of ``node`` into its own fx.GraphModule.

        Each block gets a fresh converter; every name in ``arguments`` becomes
        a placeholder of the sub-graph. Returns the get_attr nodes referencing
        the registered subgraphs and the converters used, in block order.
        """
        subgraph_nodes, subgraph_converters = [], []
        for block in node.blocks():
            subgraph_converter = TS2FXGraphConverter(
                block,
                self.name_to_param,
                self.name_to_buffer,
                self.blocks_to_lifted_attrs,
                {},
                self.name_to_constant,
            )
            # Share the attribute-FQN map so GetAttr chains resolve inside blocks.
            subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn

            for block_arg in arguments:
                normalized_block_arg_name = normalize_name(block_arg)
                placeholder_node = subgraph_converter.fx_graph.placeholder(
                    normalized_block_arg_name
                )
                subgraph_converter.name_to_node[block_arg] = placeholder_node

            subgraph = subgraph_converter.convert()
            subgraph_name = self.add_subgraph(subgraph)
            subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
            subgraph_converters.append(subgraph_converter)
        return subgraph_nodes, subgraph_converters
485
+
486
    def _identify_inputs_as_arguments(self, entry):
        """
        Identify inputs from the innermost sub-block. This is needed
        for nested sub-blocks when the input is hidden in the nested sub-block.
        E.g., example IR of input is hidden in the nested sub-block.
        Graph[x.1]
          %1 = ...
          Block[]
            Block[x.1]
              %2 = x.1 ...
        """
        arguments: Set[str] = set()
        for block in entry.blocks():
            for block_node in block.nodes():
                for block_node_in in block_node.inputs():
                    # Only values already materialized in the parent graph and
                    # not resolvable as attribute FQNs need to be passed in.
                    if (
                        block_node_in.debugName() in self.name_to_node
                        and block_node_in.debugName() not in self.name_to_attribute_fqn
                    ):
                        arguments.add(block_node_in.debugName())
                # Recurse for inputs hidden in deeper nested sub-blocks.
                arguments = arguments.union(
                    self._identify_inputs_as_arguments(block_node)
                )
        return arguments
510
+
511
    def is_top_level_graph(self):
        """Sub-blocks are torch._C.Block; only the root is a torch._C.Graph."""
        return isinstance(self.ts_graph, torch._C.Graph)
513
+
514
    def add_subgraph(self, subgraph) -> str:
        """Register ``subgraph`` under a fresh sequential name and return the name."""
        name = f"subgraph_{len(self.subgraphs)}"
        self.subgraphs[name] = subgraph
        return name
518
+
519
    def get_args_kwargs(self, node: torch._C.Node, schema):
        """Split a TS node's inputs into positional args and kwargs according
        to the operator ``schema`` (kwarg_only arguments become kwargs)."""
        args = []
        kwargs = {}
        for input, schema_arg in zip(node.inputs(), schema.arguments):
            if schema_arg.kwarg_only:
                kwargs[schema_arg.name] = self.get_fx_value_by_ir_value(input)
            else:
                args.append(self.get_fx_value_by_ir_value(input))

        return tuple(args), kwargs
529
+
530
    def get_fx_value_by_ir_value(self, value: torch._C.Value):
        """Map a TS IR value to its FX node or constant Python value.

        ScriptObject constants are materialized as get_attr nodes; other
        constants are returned as plain values.
        """
        value_name = value.debugName()

        if value_name in self.name_to_node:
            input_node = self.name_to_node[value_name]
            return input_node
        elif value_name in self.name_to_constant:
            if isinstance(self.name_to_constant[value_name], torch.ScriptObject):
                return self.fx_graph.get_attr(value_name)
            return self.name_to_constant[value_name]
        else:
            raise ValueError(f"Input {value_name} not found")
542
+
543
    def get_fx_value_by_fqn(self, name):
        """Look up a value by FQN across nodes, constants, and (current, then
        initial) non-tensor attributes — in that priority order."""
        if name in self.name_to_node:
            fx_node = self.name_to_node[name]
        elif name in self.name_to_constant:
            fx_node = self.name_to_constant[name]
        elif name in self.name_to_non_tensor_attribute_node:
            fx_node = self.name_to_non_tensor_attribute_node[name]
        elif name in self.name_to_non_tensor_attribute:
            fx_node = self.name_to_non_tensor_attribute[name]
        else:
            raise ValueError(f"Attribute {name} not found")
        return fx_node
555
+
556
    def convert(self) -> torch.fx.GraphModule:
        """Convert the whole TS graph/block into a linted fx.GraphModule."""
        self.convert_graph_inputs()

        for node in self.ts_graph.nodes():
            self.convert_node(node)

        self.convert_graph_outputs()

        # Pass parameter and buffer to the root for lookup.
        gm = torch.fx.GraphModule(
            {
                **self.subgraphs,
                **self.name_to_param,
                **self.name_to_buffer,
                **self.name_to_non_tensor_attribute,
                **self.name_to_constant,
            },
            self.fx_graph,
        )

        # Fold the scalar_tensor/div/Int chain back to symbolic floor-division.
        inplace_optimize_sym_size_div(gm)

        gm.graph.lint()

        return gm
581
+
582
    def convert_graph_inputs(self):
        """Create FX placeholder/get_attr nodes and InputSpecs for every TS
        graph input (parameters, buffers, ScriptObject constants, user inputs)."""
        for graph_input in self.ts_graph.inputs():
            name = graph_input.debugName()

            if name in self.name_to_param:
                normalized_name = normalize_name(name)
                self.input_specs.append(
                    InputSpec(
                        InputKind.PARAMETER,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                    )
                )
                fx_node = get_node_as_placeholder_or_get_attr(
                    self.fx_graph, name, self.is_top_level_graph()
                )
            elif name in self.name_to_buffer:
                normalized_name = normalize_name(name)
                self.input_specs.append(
                    InputSpec(
                        InputKind.BUFFER,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                        persistent=True,
                    )
                )
                fx_node = get_node_as_placeholder_or_get_attr(
                    self.fx_graph, name, self.is_top_level_graph()
                )
            elif name in self.name_to_constant:
                assert isinstance(
                    self.name_to_constant[name], torch.ScriptObject
                ), "Input conversion only handles ScriptObject"
                normalized_name = normalize_name(name)
                self.input_specs.append(
                    InputSpec(
                        InputKind.CUSTOM_OBJ,
                        arg=CustomObjArgument(
                            name=normalized_name, class_fqn=normalized_name
                        ),
                        target=name,
                        persistent=False,
                    )
                )
                fx_node = get_node_as_placeholder_or_get_attr(
                    self.fx_graph, name, self.is_top_level_graph()
                )
            elif isinstance(graph_input.type(), torch.ClassType):
                # Directly skip inputs that are ScriptObject but not used in the graph.
                continue
            else:
                normalized_name = normalize_name(name, prefix="input")
                self.input_specs.append(
                    InputSpec(
                        InputKind.USER_INPUT,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                    )
                )
                fx_node = self.fx_graph.placeholder(normalized_name)

            self.name_to_node[name] = fx_node
644
+
645
+ def convert_aten_Float(self, node: torch._C.Node):
646
+ def to_float_tensor(t):
647
+ return t.to(dtype=torch.float).item()
648
+
649
+ inp_list = [
650
+ self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
651
+ ] # noqa: C416
652
+ fx_node = self.fx_graph.call_function(
653
+ to_float_tensor,
654
+ tuple(inp_list),
655
+ )
656
+ self.name_to_node[node.output().debugName()] = fx_node
657
+
658
+ def convert_aten_tensor(self, node: torch._C.Node):
659
+ """aten::tensor creates a constant tensor ad-hoc --> GetAttr"""
660
+ args, kwargs = self.get_args_kwargs(node, torch.ops.aten.tensor.default._schema)
661
+
662
+ for k in kwargs:
663
+ if k == "requires_grad":
664
+ kwargs[k] = bool(kwargs[k]) # 0 -> False, 1 -> True
665
+
666
+ to_tensor = (
667
+ torch.tensor
668
+ if all(isinstance(a, int) for a in args)
669
+ else torch._refs.tensor
670
+ )
671
+
672
+ def target(*args, **kwargs):
673
+ if "dtype" in kwargs and kwargs["dtype"] is not None:
674
+ kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]]
675
+ return to_tensor(*args, **kwargs)
676
+
677
+ # def to_dynamic_tensor(*args, **kwargs):
678
+ # if "dtype" in kwargs and kwargs["dtype"] is not None:
679
+ # kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]]
680
+ # return torch._refs.tensor(*args, **kwargs)
681
+
682
+ output_name = node.output().debugName()
683
+ fx_node = self.fx_graph.call_function(target, args, kwargs)
684
+ self.name_to_node[output_name] = fx_node
685
+
686
+ def convert_aten_append(self, node: torch._C.Node):
687
+ # special handle python list append: "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)"
688
+
689
+ # inplace append to the list!! This is kinda crazy, as we are inplace mutating the list
690
+ # This makes the converter "non-functional", and the result depends on the order of the nodes being converter
691
+ # In a sense, the converter now becomes an stateful interpreter
692
+ warnings.warn(
693
+ "Converting aten::append.t, which is a inplace mutation of the list. "
694
+ "This makes the converter non-functional: the result depends on the order of the append nodes being converter!"
695
+ )
696
+
697
+ args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs())
698
+ fx_node = self.fx_graph.call_function(list_append, args)
699
+ self.name_to_node[node.output().debugName()] = fx_node
700
+
701
+ # inplace mutate arg[0], which is the python list
702
+ self.name_to_node[node.inputsAt(0).debugName()] = fx_node
703
+
704
+ # Variables that need to be updated to parent module.
705
+ if not self.is_top_level_graph() and args[0].op == "placeholder":
706
+ self.name_update_from_subblock_to_parent.add(node.inputsAt(0).debugName())
707
+
708
+ def convert_prim_Constant(self, node: torch._C.Node):
709
+ name = node.output().debugName()
710
+
711
+ value: Any = None
712
+ if node.hasAttribute("value"):
713
+ constant_kind = node.kindOf("value")
714
+ if constant_kind == "i":
715
+ value = node.i("value")
716
+ elif constant_kind == "f":
717
+ value = node.f("value")
718
+ elif constant_kind == "s":
719
+ value = node.s("value")
720
+ elif constant_kind == "t":
721
+ alias_name = (
722
+ f"lifted_tensor_{name}" # Follow naming convention from EP tracing.
723
+ )
724
+ fx_node = self.fx_graph.get_attr(alias_name)
725
+ self.name_to_node[name] = fx_node
726
+ name, value = alias_name, node.t("value")
727
+ elif constant_kind == "ival":
728
+ value = node.ival("value")
729
+ else:
730
+ raise ValueError(f"Unsupported constant type: {node.kindOf('value')}")
731
+ else:
732
+ value = None
733
+
734
+ self.name_to_constant[name] = value
735
+
736
+ def convert_prim_CallMethod(self, node: torch._C.Node):
737
+ inp_list = [
738
+ self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
739
+ ] # noqa: C416
740
+ fx_node = self.fx_graph.call_method(
741
+ node.s("name"),
742
+ tuple(inp_list),
743
+ )
744
+ self.name_to_node[node.output().debugName()] = fx_node
745
+
746
+ def convert_prim_device(self, node: torch._C.Node):
747
+ input_type = node.input().type()
748
+ if input_type.isSubtypeOf(torch._C.TensorType.get()):
749
+ device = input_type.device() # type: ignore[attr-defined]
750
+ output_name = node.output().debugName()
751
+ self.name_to_constant[output_name] = device
752
+ else:
753
+ raise ValueError(f"Unsupported JitType ({input_type}) when get device")
754
+
755
+ def convert_prim_GetAttr(self, node: torch._C.Node):
756
+ # Build fully qulified name
757
+ attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
758
+ output_name = node.output().debugName()
759
+ self.name_to_attribute_fqn[output_name] = attr_fqn
760
+
761
+ if self.is_top_level_graph():
762
+ if self._is_get_attr_node(attr_fqn):
763
+ # We insert a get_attr node due to two reasons.
764
+ # First, ts graph does not lift tensor constants as input nodes. So
765
+ # tensor constants may be ignored by in convert_graph_inputs().
766
+ # Second, attr_fqn may have been written to via SetAttr. Two
767
+ # GetAttr may give different values.
768
+ self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn)
769
+ else:
770
+ if attr_fqn not in self.name_to_non_tensor_attribute_node:
771
+ self.name_to_non_tensor_attribute_node[
772
+ attr_fqn
773
+ ] = self.name_to_non_tensor_attribute[attr_fqn]
774
+ self.name_to_node[output_name] = self.name_to_non_tensor_attribute_node[
775
+ attr_fqn
776
+ ]
777
+ else:
778
+ # Special support for if blocks which do not allow SetAttr TorchScript
779
+ # node and get_attr FX Graph Node.
780
+ if self._is_get_attr_node(attr_fqn):
781
+ self.name_to_node[output_name] = self.name_to_node[attr_fqn]
782
+
783
+ def convert_prim_SetAttr(self, node: torch._C.Node):
784
+ attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
785
+ attr_value = tuple(node.inputs())[1]
786
+ ts_graph_tensor_input = self.get_fx_value_by_ir_value(attr_value)
787
+ if self._is_get_attr_node(attr_fqn):
788
+ fx_attr_node = self.fx_graph.get_attr(attr_fqn)
789
+ self.fx_graph.call_function(
790
+ torch.Tensor.copy_, (fx_attr_node, ts_graph_tensor_input)
791
+ )
792
+ else:
793
+ self.name_to_non_tensor_attribute_node[attr_fqn] = ts_graph_tensor_input
794
+
795
+ def convert_call_function_op(self, node: torch._C.Node):
796
+ target = get_op_overload(node)
797
+
798
+ args, kwargs = self.get_args_kwargs(node, target._schema)
799
+
800
+ fx_node = self.fx_graph.call_function(target, args, kwargs)
801
+
802
+ # TODO: covnert sourceRange() into stack_trace
803
+ # fx_node.meta["stack_trace"] = node.sourceRange()
804
+
805
+ if node.outputsSize() == 1:
806
+ output_name = node.output().debugName()
807
+ self.name_to_node[output_name] = fx_node
808
+ else:
809
+ for i, outp in enumerate(node.outputs()):
810
+ output_name = outp.debugName()
811
+ next_fx_node = self.fx_graph.call_function(
812
+ operator.getitem, (fx_node, i)
813
+ )
814
+ self.name_to_node[output_name] = next_fx_node
815
+
816
+ def convert_prim_TupleConstruct(self, node: torch._C.Node):
817
+ self._convert_prim_iterator(node)
818
+
819
+ def convert_prim_ListConstruct(self, node: torch._C.Node):
820
+ self._convert_prim_iterator(node)
821
+
822
+ def _convert_prim_iterator(self, node: torch._C.Node):
823
+ output_list = []
824
+ for inp in node.inputs():
825
+ output_list.append(self.get_fx_value_by_ir_value(inp))
826
+
827
+ output_name = node.output().debugName()
828
+ self.name_to_node[output_name] = output_list
829
+
830
+ def convert_prim_DictConstruct(self, node: torch._C.Node):
831
+ output_dict = {}
832
+ k, v = None, None
833
+ for i, inp in enumerate(node.inputs()):
834
+ # We assume key value are stored in pair in the DictConstruct.
835
+ # The first element is the key and the following is the value.
836
+ if i % 2 == 0:
837
+ k = self.get_fx_value_by_ir_value(inp)
838
+ else:
839
+ v = self.get_fx_value_by_ir_value(inp)
840
+ assert (
841
+ k is not None and v is not None
842
+ ), "DictConstruct has an empty key value pair."
843
+ output_dict[k] = v
844
+ k, v = None, None
845
+
846
+ assert (
847
+ k is None and v is None
848
+ ), "DictConstruct has an odd number of elements (violating our assumption)."
849
+
850
+ output_name = node.output().debugName()
851
+ self.name_to_node[output_name] = output_dict
852
+
853
+ def convert_prim_ListUnpack(self, node: torch._C.Node):
854
+ self._convert_prim_unpack_iterator(node)
855
+
856
+ def convert_prim_TupleUnpack(self, node: torch._C.Node):
857
+ self._convert_prim_unpack_iterator(node)
858
+
859
+ def _convert_prim_unpack_iterator(self, node: torch._C.Node):
860
+ # Single input and multiple outputs for unpacking.
861
+ for i, outp in enumerate(node.outputs()):
862
+ outp_name = outp.debugName()
863
+ inp = self.get_fx_value_by_ir_value(node.input())
864
+ fx_node = self.fx_graph.call_function(operator.getitem, (inp, i))
865
+ self.name_to_node[outp_name] = fx_node
866
+
867
+ def convert_aten_Int(self, node: torch._C.Node):
868
+ # converts aten::Int as aten._to_copy + aten::_local_scalar_dense
869
+ target = torch.ops.aten._to_copy.default
870
+ args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
871
+ to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32})
872
+
873
+ fx_node = self.fx_graph.call_function(
874
+ torch.ops.aten._local_scalar_dense.default, (to_copy_node,)
875
+ )
876
+
877
+ # TODO: covnert sourceRange() into stack_trace
878
+ # fx_node.meta["stack_trace"] = node.sourceRange()
879
+
880
+ output_name = node.output().debugName()
881
+ self.name_to_node[output_name] = fx_node
882
+
883
+ def convert_prim_NumToTensor(self, node: torch._C.Node):
884
+ # Converts prim::NumToTensor as aten.scalar_tensor.
885
+ # prim::NumToTensor IRs are currently triggered by:
886
+ # .size() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L950
887
+ # .numel() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L971
888
+ # For both of those APIs, torch.jit.trace implicitly sets the output tensor type
889
+ # to be LongTensor.
890
+ target = torch.ops.aten.scalar_tensor
891
+ args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
892
+
893
+ fx_node = self.fx_graph.call_function(target, args, {"dtype": torch.long})
894
+ output_name = node.output().debugName()
895
+ self.name_to_node[output_name] = fx_node
896
+
897
+ def convert_prim_CreateObject(self, node: torch._C.Node):
898
+ output_name = node.output().debugName()
899
+ self.name_to_attribute_fqn[output_name] = ""
900
+
901
+ def convert_aten__convolution(self, node: torch._C.Node):
902
+ # converts aten::_convolution as aten.convolution, since aten::_convolution
903
+ # doesn't have a meta function
904
+ target = torch.ops.aten.convolution.default
905
+ args, kwargs = self.get_args_kwargs(node, target._schema)
906
+
907
+ fx_node = self.fx_graph.call_function(target, args, kwargs)
908
+
909
+ output_name = node.output().debugName()
910
+ self.name_to_node[output_name] = fx_node
911
+
912
+ def convert_aten_div(self, node: torch._C.Node):
913
+ target = get_op_overload(node)
914
+ schema = target._schema
915
+
916
+ args, kwargs = self.get_args_kwargs(node, schema)
917
+
918
+ # converts aten::div.Tensor_mode(x, tensor_constant)
919
+ # as aten.div.Scalar_mode(x, tensor_constant.item())
920
+ if schema.overload_name == "Tensor_mode":
921
+ arg1_name = args[1].name
922
+ if arg1_name in self.name_to_constant and isinstance(
923
+ self.name_to_constant[arg1_name], torch.Tensor
924
+ ):
925
+ tensor_constant = self.name_to_constant[arg1_name]
926
+ if tensor_constant.numel() == 1:
927
+ updated_args = list(args)
928
+ updated_args[1] = self.name_to_constant[arg1_name].item()
929
+
930
+ fx_node = self.fx_graph.call_function(
931
+ torch.ops.aten.div.Scalar_mode,
932
+ tuple(updated_args),
933
+ kwargs,
934
+ )
935
+
936
+ # TODO: covnert sourceRange() into stack_trace
937
+ # fx_node.meta["stack_trace"] = node.sourceRange()
938
+
939
+ output_name = node.output().debugName()
940
+ self.name_to_node[output_name] = fx_node
941
+ return
942
+
943
+ self.convert_call_function_op(node)
944
+
945
+ def convert_aten___getitem__(self, node: torch._C.Node):
946
+ input_container, index = tuple(
947
+ self.get_fx_value_by_ir_value(input) for input in node.inputs()
948
+ )
949
+ fx_node = self.fx_graph.call_function(
950
+ operator.getitem, (input_container, index)
951
+ )
952
+ output_name = node.output().debugName()
953
+ self.name_to_node[output_name] = fx_node
954
+
955
+ def convert_aten_to(self, node: torch._C.Node):
956
+ target = get_op_overload(node)
957
+ args, kwargs = self.get_args_kwargs(node, target._schema)
958
+
959
+ # special handle aten.to.dtype and aten.to.prim_dtype followed by inplace_mutation_op
960
+ # coz aten.to + inplace_mutation_op pattern would trigger
961
+ # "cannot mutate tensors with frozen storage" functionalization error.
962
+ # To work around the issue, we override the copy to be True, so that the output
963
+ # is for sure not an alias of input
964
+ if target == torch.ops.aten.to.dtype or target == torch.ops.aten.to.prim_dtype:
965
+ user_nodes = [use.user for use in node.output().uses()]
966
+ user_targets = [
967
+ get_op_overload(user_node)
968
+ for user_node in user_nodes
969
+ if user_node.schema() != "(no schema)"
970
+ ]
971
+ has_mutable_target = any(
972
+ target._schema.is_mutable for target in user_targets
973
+ )
974
+
975
+ if has_mutable_target:
976
+ assert len(args) >= 4
977
+ new_args = list(args)
978
+ new_args[3] = True # copy, override to True
979
+ fx_node = self.fx_graph.call_function(
980
+ torch.ops.aten.to.dtype, tuple(new_args)
981
+ )
982
+ # temp hack to work around the issue https://github.com/pytorch/pytorch/issues/131679
983
+ # When this issue is fixed, the clone node would be no longer needed
984
+ clone_node = self.fx_graph.call_function(
985
+ torch.ops.aten.clone.default, (fx_node,)
986
+ )
987
+ output_name = node.output().debugName()
988
+ self.name_to_node[output_name] = clone_node
989
+ return
990
+
991
+ self.convert_call_function_op(node)
992
+
993
+ def convert_aten_add(self, node: torch._C.Node):
994
+ if node.schema() == "(no schema)":
995
+ if isinstance(node.inputsAt(0).type(), torch.ListType) and isinstance(
996
+ node.inputsAt(1).type(), torch.ListType
997
+ ):
998
+ target = torch.ops.aten.add.t
999
+ else:
1000
+ raise RuntimeError(f"unable to determind the target for {node}")
1001
+ else:
1002
+ target = get_op_overload(node)
1003
+
1004
+ if target == torch.ops.aten.add.t:
1005
+ # special handle python list/tuple add: "aten::add.t(t[] a, t[] b) -> t[]" for
1006
+ # RuntimeError: aten::add() Expected a value of type 'List[t]' for argument 'a' but instead found type 'immutable_list'.
1007
+ args, kwargs = self.get_args_kwargs(node, target._schema)
1008
+ output_name = node.output().debugName()
1009
+ self.name_to_node[output_name] = self.fx_graph.call_function(list_add, args)
1010
+ else:
1011
+ self.convert_call_function_op(node)
1012
+
1013
+ def _check_prim_loop_support(self, node):
1014
+ inputs = list(node.inputs())
1015
+
1016
+ # TODO: (1/N) stage.
1017
+ if inputs[0].debugName() not in self.name_to_constant:
1018
+ raise RuntimeError(
1019
+ "prim::Loop currently cannot run with dynamic value of number of iterations."
1020
+ )
1021
+
1022
+ # Make sure the condition is not updated in the subblock.
1023
+ subblock = next(node.blocks())
1024
+ condition_output_name = next(subblock.outputs()).debugName()
1025
+ for node in subblock.nodes():
1026
+ if (
1027
+ node.outputsSize() == 1
1028
+ and node.output().debugName() == condition_output_name
1029
+ ):
1030
+ raise RuntimeError(
1031
+ "prim::Loop currently cannot run with dynamic value of condition."
1032
+ )
1033
+ if node.outputsSize() >= 2:
1034
+ for outp in node.outputs():
1035
+ if outp.debugName() == condition_output_name:
1036
+ raise RuntimeError(
1037
+ "prim::Loop currently cannot run with dynamic value of condition."
1038
+ )
1039
+
1040
+ def convert_prim_Loop(self, node: torch._C.Node):
1041
+ inputs = list(node.inputs())
1042
+ self._check_prim_loop_support(node)
1043
+
1044
+ num_iterations = self.get_fx_value_by_ir_value(inputs[0])
1045
+
1046
+ # Find inputs.
1047
+ loop_local_arguments = [inp.debugName() for inp in inputs[2:]]
1048
+
1049
+ global_arguments = self._identify_inputs_as_arguments(node)
1050
+
1051
+ # Lift parameters as inputs.
1052
+ for block in node.blocks():
1053
+ global_arguments = global_arguments.union(
1054
+ self.blocks_to_lifted_attrs[block]
1055
+ )
1056
+
1057
+ global_arguments = list(global_arguments)
1058
+
1059
+ subgraph_nodes, subgraph_converters = self._convert_block_to_subgraph(
1060
+ node, global_arguments
1061
+ )
1062
+
1063
+ assert len(subgraph_nodes) == 1
1064
+ subgraph_converter = subgraph_converters[0]
1065
+ if not self.is_top_level_graph():
1066
+ self.name_update_from_subblock_to_parent = (
1067
+ self.name_update_from_subblock_to_parent.union(
1068
+ subgraph_converter.name_update_from_subblock_to_parent
1069
+ )
1070
+ )
1071
+
1072
+ fx_block_args = [
1073
+ self.get_fx_value_by_fqn(name)
1074
+ for name in loop_local_arguments + global_arguments
1075
+ ]
1076
+ for iter_idx in range(num_iterations):
1077
+ loop_node = self.fx_graph.call_function(
1078
+ execute_subgraph_from_prim_loop,
1079
+ # Check execute_node function for the expected arguments order.
1080
+ (
1081
+ subgraph_nodes[0],
1082
+ iter_idx,
1083
+ len(loop_local_arguments),
1084
+ *fx_block_args,
1085
+ ),
1086
+ {},
1087
+ )
1088
+
1089
+ # Update the value of loop local variables.
1090
+ if node.outputsSize() >= 1:
1091
+ for i, outp in enumerate(node.outputs()):
1092
+ output_name = outp.debugName()
1093
+ self.name_to_node[output_name] = self.fx_graph.call_function(
1094
+ operator.getitem,
1095
+ (
1096
+ loop_node,
1097
+ i + 1,
1098
+ ), # + 1 because the 0th element is the condition.
1099
+ )
1100
+ fx_block_args[i] = self.name_to_node[output_name]
1101
+
1102
+ # Update the value of global variables, whose values are modified inplace.
1103
+ for i, name in enumerate(
1104
+ subgraph_converter.name_update_from_subblock_to_parent
1105
+ ):
1106
+ self.name_to_node[name] = self.fx_graph.call_function(
1107
+ operator.getitem,
1108
+ (
1109
+ loop_node,
1110
+ i + node.outputsSize() + 1,
1111
+ ), # + 1 because the 0th element is the condition.
1112
+ )
1113
+ global_argument_index = global_arguments.index(name)
1114
+ fx_block_args[
1115
+ i + node.outputsSize() + global_argument_index
1116
+ ] = self.name_to_node[name]
1117
+
1118
+ def _check_set_attr_in_if_block(self, if_node: torch._C.Node):
1119
+ for block in if_node.blocks():
1120
+ for node in block.nodes():
1121
+ if node.kind() == "prim::SetAttr":
1122
+ raise RuntimeError(
1123
+ "During converting prim::If to torch.cond, found prim::SetAttr op"
1124
+ " which is not supported yet. Please file an issue if you come "
1125
+ "across this error."
1126
+ )
1127
+
1128
+ def convert_prim_If(self, node: torch._C.Node):
1129
+ self._check_set_attr_in_if_block(node)
1130
+
1131
+ inputs = list(node.inputs())
1132
+ assert len(inputs) == 1
1133
+ predicate = self.get_fx_value_by_ir_value(inputs[0])
1134
+
1135
+ # Find inputs.
1136
+ arguments = self._identify_inputs_as_arguments(node)
1137
+
1138
+ # Lift parameters as inputs.
1139
+ for block in node.blocks():
1140
+ arguments = arguments.union(self.blocks_to_lifted_attrs[block])
1141
+
1142
+ arguments = list(arguments)
1143
+ subgraph_nodes, _ = self._convert_block_to_subgraph(node, arguments)
1144
+
1145
+ assert len(subgraph_nodes) == 2
1146
+
1147
+ fx_block_args = [self.get_fx_value_by_fqn(name) for name in arguments]
1148
+
1149
+ args = (
1150
+ predicate,
1151
+ subgraph_nodes[0],
1152
+ subgraph_nodes[1],
1153
+ tuple(fx_block_args),
1154
+ )
1155
+
1156
+ cond_node = self.fx_graph.call_function(torch.cond, args, {})
1157
+
1158
+ # prim::If can also have zero output.
1159
+ if node.outputsSize() == 1:
1160
+ output_name = node.output().debugName()
1161
+ self.name_to_node[output_name] = cond_node
1162
+ elif node.outputsSize() > 1:
1163
+ for i, output in enumerate(node.outputs()):
1164
+ output_name = output.debugName()
1165
+ getitem = self.fx_graph.call_function(operator.getitem, (cond_node, i))
1166
+ self.name_to_node[output_name] = getitem
1167
+
1168
+ def convert_aten_Bool(self, node: torch._C.Node):
1169
+ self._convert_as_noop(node)
1170
+
1171
+ def convert_prim_Enter(self, node: torch._C.Node):
1172
+ # export generally treats prim::Enter as noop
1173
+ # The only context manager export supports is aten::enable_grad.
1174
+ # Unfortunately, TorchScript does not support aten::enable_grad yet.
1175
+ # TODO: support aten::enable_grad in both TorchScript and Converter.
1176
+ return
1177
+
1178
+ def convert_prim_Exit(self, node: torch._C.Node):
1179
+ # export treats prim::Exit as noop
1180
+ return
1181
+
1182
+ def _convert_as_noop(self, node: torch._C.Node):
1183
+ # Converts the node as a no-op by mapping its output node as arg[0]
1184
+
1185
+ target = get_op_overload(node)
1186
+ schema = target._schema
1187
+
1188
+ args, kwargs = self.get_args_kwargs(node, schema)
1189
+
1190
+ output_name = node.output().debugName()
1191
+ self.name_to_node[output_name] = args[0]
1192
+
1193
+ def convert_profiler__record_function_exit(self, node: torch._C.Node):
1194
+ # _record_function_exit has side effect so we keep it in fx.graph
1195
+ # currently, _record_function_enter_new and _record_function_exit are
1196
+ # discarded during `retrace_as_exported_program`.
1197
+ target = torch.ops.profiler._record_function_exit
1198
+ args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
1199
+ self.fx_graph.call_function(target, args)
1200
+
1201
+ def convert_prim_tolist(self, node: torch._C.Node):
1202
+ # prim::tolist cannot be supported by `_convert_standard_operators`
1203
+ # since it requires call_method instead of call_function.
1204
+ target = "tolist"
1205
+ args = (self.get_fx_value_by_ir_value(next(node.inputs())),)
1206
+ fx_node = self.fx_graph.call_method(target, args)
1207
+ output_name = node.output().debugName()
1208
+ self.name_to_node[output_name] = fx_node
1209
+
1210
+ def convert_prim_Uninitialized(self, node: torch._C.Node):
1211
+ # `prim::Uninitialized` is inserted by the compiler when it can prove
1212
+ # the value will never be used. It can be introduced by exceptions,
1213
+ # breaks, continues, and returns.
1214
+ # So we add a dummy constant to the graph.
1215
+ output_name = node.output().debugName()
1216
+ self.name_to_constant[output_name] = torch.Tensor()
1217
+
1218
+ def _convert_standard_operators(self, node: torch._C.Node):
1219
+ target = kind_to_standard_operators[node.kind()]
1220
+ args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
1221
+ fx_node = self.fx_graph.call_function(target, args)
1222
+ output_name = node.output().debugName()
1223
+ self.name_to_node[output_name] = fx_node
1224
+
1225
+ def convert_node(self, node: torch._C.Node):
1226
+ node_kind = node.kind()
1227
+
1228
+ # Get handler based on namespace and operator name.
1229
+ # Provide a default node handler as well in case we don't find
1230
+ # matching converter for that.
1231
+ handler_func_name = ir_name_to_func_name(node_kind)
1232
+ handler_func = getattr(self, handler_func_name, self.convert_call_function_op)
1233
+
1234
+ # str calls print function implemented in CPP. To avoid repeating
1235
+ # the entire logic here, we simply keep first line from node string (getting rid
1236
+ # of sub-blocks IR prints).
1237
+ node_str = "".join(str(node).split("\n")[:1])
1238
+ log.debug("[%s] converts [%s]", handler_func.__name__, node_str)
1239
+ try:
1240
+ handler_func(node)
1241
+ except Exception as e:
1242
+ raise RuntimeError(f"TS2EPConverter failed for node {node_kind}") from e
1243
+
1244
+ def convert_graph_outputs(self):
1245
+ args = []
1246
+ outp_name_list = [outp.debugName() for outp in self.ts_graph.outputs()] + list(
1247
+ self.name_update_from_subblock_to_parent
1248
+ )
1249
+ for output_name in outp_name_list:
1250
+ if output_name in self.name_to_node:
1251
+ fx_node = self.name_to_node[output_name]
1252
+ # TODO: Revisit this later after HigherOrderOp design changes.
1253
+ # Currently, we cannot directly return input as output.
1254
+ if (
1255
+ not self.is_top_level_graph()
1256
+ and isinstance(fx_node, torch.fx.Node)
1257
+ and fx_node.op == "placeholder"
1258
+ ):
1259
+ fx_node = self.fx_graph.call_function(torch.clone, (fx_node,))
1260
+ args.append(fx_node)
1261
+ self.output_specs.append(
1262
+ OutputSpec(
1263
+ OutputKind.USER_OUTPUT,
1264
+ arg=TensorArgument(name=output_name),
1265
+ target=output_name,
1266
+ )
1267
+ )
1268
+ elif output_name in self.name_to_constant:
1269
+ args.append(self.name_to_constant[output_name])
1270
+ self.output_specs.append(
1271
+ OutputSpec(
1272
+ OutputKind.USER_OUTPUT,
1273
+ arg=ConstantArgument(
1274
+ name=output_name, value=self.name_to_constant[output_name]
1275
+ ),
1276
+ target=output_name,
1277
+ )
1278
+ )
1279
+ else:
1280
+ raise ValueError(f"Output {output_name} not found")
1281
+
1282
+ if len(args) == 0:
1283
+ # Sub-block of prim::If can have zero output.
1284
+ self.fx_graph.output([])
1285
+ elif len(args) == 1:
1286
+ self.fx_graph.output(
1287
+ args[0]
1288
+ ) # Get rid of an extra list wrapped around final output.
1289
+ elif len(args) > 1:
1290
+ self.fx_graph.output(
1291
+ args
1292
+ ) # For prim::Loop and prim::If with multiple outputs.
1293
+ else:
1294
+ # Sub-block of prim::Loop can have multiple outputs.
1295
+ self.fx_graph.output(args)
1296
+
1297
+
1298
+ class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
1299
+ """
1300
+ Run TS2FXGraphConverter in an explain mode. It collects all failed operators conversions
1301
+ and provide that information to users. In order to collect all failed conversions, it
1302
+ also mocks some internal attributes (e.g., name_to_node).
1303
+ """
1304
+
1305
+ class _DictMock(dict):
1306
+ def __init__(self, dict_data, mock_value):
1307
+ super().__init__(dict_data)
1308
+ self.mock_value = mock_value
1309
+
1310
+ def __getitem__(self, key):
1311
+ # If the original dictionary has the key, return its value.
1312
+ # Otherwise, return the mock value.
1313
+ if not super().__contains__(key):
1314
+ return self.mock_value
1315
+ return super().__getitem__(key)
1316
+
1317
+ def __contains__(self, key):
1318
+ return True
1319
+
1320
+ def __init__(
1321
+ self,
1322
+ ts_graph: Union[torch._C.Graph, torch._C.Block],
1323
+ name_to_param: Dict[str, torch.Tensor],
1324
+ name_to_buffer: Dict[str, torch.Tensor],
1325
+ blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
1326
+ name_to_non_tensor_attribute: Dict[str, Any],
1327
+ name_to_constant: Dict[str, Any],
1328
+ ):
1329
+ super().__init__(
1330
+ ts_graph,
1331
+ name_to_param,
1332
+ name_to_buffer,
1333
+ blocks_to_lifted_attrs,
1334
+ name_to_non_tensor_attribute,
1335
+ name_to_constant,
1336
+ )
1337
+
1338
+ # Data to keep track of unsupported nodes.
1339
+ self.unsupported_node_list: List[torch._C.Node] = []
1340
+
1341
+ # Add mock to needed attributes.
1342
+ self.name_to_node = ExplainTS2FXGraphConverter._DictMock(
1343
+ self.name_to_node,
1344
+ # Dummy node.
1345
+ torch.fx.Node(
1346
+ None, # type: ignore[arg-type]
1347
+ "mock",
1348
+ "call_function",
1349
+ lambda: None,
1350
+ (),
1351
+ {},
1352
+ ),
1353
+ )
1354
+
1355
+ def explain(self):
1356
+ self.convert_graph_inputs()
1357
+ for node in self.ts_graph.nodes():
1358
+ self.convert_node(node)
1359
+ self.convert_graph_outputs()
1360
+
1361
+ def convert_node(self, node):
1362
+ try:
1363
+ super().convert_node(node)
1364
+ except Exception as e:
1365
+ self.unsupported_node_list.append(node)
1366
+
1367
+
1368
+ @contextmanager
1369
+ def disable_logging(log):
1370
+ disabled = log.disabled
1371
+ log.disabled = True
1372
+ try:
1373
+ yield
1374
+ finally:
1375
+ log.disabled = disabled
1376
+
1377
+
1378
+ class TS2EPConverter:
1379
+ # TorchScript model to ExportedProgram converter
1380
+ def __init__(
1381
+ self,
1382
+ ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction],
1383
+ sample_args: Tuple[Any, ...],
1384
+ sample_kwargs: Optional[Dict[str, Any]] = None,
1385
+ ):
1386
+ self.ts_model = ts_model
1387
+ self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)
1388
+
1389
+ self.sample_args = sample_args
1390
+ self.sample_kwargs = sample_kwargs
1391
+
1392
+ self.name_to_param: Dict[str, torch.Tensor] = {}
1393
+ self.name_to_buffer: Dict[str, torch.Tensor] = {}
1394
+ param_list = (
1395
+ list(self.ts_model.parameters())
1396
+ if not isinstance(self.ts_model, torch._C.ScriptFunction)
1397
+ else []
1398
+ )
1399
+ if not isinstance(self.ts_model, torch._C.ScriptFunction):
1400
+ for k, tensor in self.ts_model.state_dict().items(): # type: ignore[union-attr]
1401
+ # Check if tensor belongs to any parameter.
1402
+ if any(
1403
+ (tensor == param).all()
1404
+ for param in param_list
1405
+ if tensor.shape == param.shape
1406
+ ):
1407
+ self.name_to_param[k] = tensor
1408
+ else:
1409
+ self.name_to_buffer[k] = tensor
1410
+
1411
+ self.name_to_non_tensor_attributes: Dict[str, Any] = {}
1412
+ self.name_to_constant: Dict[str, Any] = {}
1413
+
1414
+ self.lift_get_attr()
1415
+
1416
+ def convert(self) -> ExportedProgram:
1417
+ log.info(
1418
+ """
1419
+ TS2EPConverter logging starts from here.
1420
+
1421
+ INFO: (TORCH_LOGS="export" <cmd>)
1422
+ * Log TorchScript IR.
1423
+
1424
+ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
1425
+ * Log conversion IR by IR in a format of [<conversion handler name>] converts [<IR>].
1426
+ """
1427
+ )
1428
+ log.info("TorchScript graph\n\n%s\n", self.ts_graph)
1429
+
1430
+ blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph)
1431
+
1432
+ graph_converter = TS2FXGraphConverter(
1433
+ self.ts_graph,
1434
+ self.name_to_param,
1435
+ self.name_to_buffer,
1436
+ blocks_to_lifted_attrs,
1437
+ self.name_to_non_tensor_attributes,
1438
+ self.name_to_constant,
1439
+ )
1440
+ gm = graph_converter.convert()
1441
+
1442
+ # Post-proccessing step to deal with quantized operators.
1443
+ replace_quantized_ops_with_standard_ops(gm)
1444
+ log.info("GraphModule: %s", gm.print_readable(print_output=False))
1445
+
1446
+ ep = self.retrace_as_exported_program(
1447
+ gm,
1448
+ graph_converter.name_to_constant,
1449
+ )
1450
+ log.info("%s", ep)
1451
+
1452
+ # Post-processing step to ensure ExportedProgram has the same state_dict as
1453
+ # the original TorchScript model. Throw warnings for additionally populated
1454
+ # state_dict entries.
1455
+ if not isinstance(self.ts_model, torch._C.ScriptFunction):
1456
+ for k, tensor in self.ts_model.state_dict().items(): # type: ignore[union-attr]
1457
+ if k not in ep.state_dict:
1458
+ warnings.warn(
1459
+ f"Manually populate {k} into state_dict ExportedProgram, but it is never used by the ExportedProgram."
1460
+ )
1461
+ ep.state_dict[k] = tensor
1462
+
1463
+ return ep
1464
+
1465
+ @disable_logging(log)
1466
+ def explain(self, print_output=True):
1467
+ blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph)
1468
+
1469
+ graph_converter = ExplainTS2FXGraphConverter(
1470
+ self.ts_graph,
1471
+ self.name_to_param,
1472
+ self.name_to_buffer,
1473
+ blocks_to_lifted_attrs,
1474
+ self.name_to_non_tensor_attributes,
1475
+ self.name_to_constant,
1476
+ )
1477
+ graph_converter.explain()
1478
+ if len(graph_converter.unsupported_node_list) > 0:
1479
+ explain_str = "Unsupported nodes are found in the following list:"
1480
+ for i, n in enumerate(graph_converter.unsupported_node_list):
1481
+ node_str = "".join(str(n).split("\n")[:1])
1482
+ explain_str += f"\n\n {i}. {n.kind()} [{node_str}]"
1483
+ else:
1484
+ explain_str = "Success!"
1485
+ if print_output:
1486
+ print(explain_str)
1487
+ return explain_str
1488
+
1489
    def retrace_as_exported_program(
        self,
        gm: torch.fx.GraphModule,
        name_to_constant: Dict[str, Any],
    ):
        """Retrace ``gm`` with ``torch.export`` and repair constant handling.

        During conversion tensor constants were modeled as GetAttr nodes, so
        retracing mistakes them for buffers. This removes them from the state
        dict, registers them as constants, and rewrites their input specs to
        ``CONSTANT_TENSOR`` before re-verifying the program.
        """
        # TODO: adjust input orders to match GraphSignature convention
        ep = torch.export._trace._export(
            gm,
            self.sample_args,
            strict=False,
            pre_dispatch=True,
        )

        # Post-processing to make sure the ExportedProgram states are correct.
        # Because during conversion, we set tensor constants as GetAttr,
        # retracing cannot recognize them as tensor constants but instead
        # treat them as buffers. We need to set them again here.
        ep._constants.update(
            {
                k: v
                for k, v in name_to_constant.items()
                if isinstance(v, (torch.Tensor, torch.ScriptObject))
            }
        )
        # Drop the erroneously-created buffer entries from the state dict.
        for k in name_to_constant:
            ep.state_dict.pop(k, None)

        for i, spec in enumerate(ep.graph_signature.input_specs):
            # Mark as constant tensors for erroneously traced buffers.
            if spec.kind == InputKind.BUFFER and spec.target in name_to_constant:
                assert isinstance(
                    name_to_constant[spec.target], torch.Tensor
                ), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer"
                spec.kind = InputKind.CONSTANT_TENSOR
        # The signature was edited in place above, so re-check validity.
        ep.verifier().check(ep)

        return ep
1526
+
1527
    def lift_get_attr(self):
        """Walk the TorchScript graph and lift ``prim::GetAttr`` targets.

        Populates ``self.name_to_buffer``, ``self.name_to_constant`` and
        ``self.name_to_non_tensor_attributes`` from attributes reachable via
        GetAttr nodes (including nested sub-blocks).
        """
        # This function lifts multiple data types.

        # 1. Tensor constants attributes (e.g., self.data = torch.tensor([2,3]))
        # to buffers. Currently, when there are tensor constants, export
        # would error and ask users to register tensor constants as buffers.
        # Since it is hard to manually do so for TorchScript models
        # (e.g., source code is missing), this function automatically
        # lifts tensor constants to be buffers.

        # 2. ScriptObject to constant. It will then be converted to getattr
        # in the fx graph.
        #
        # This function should happen in TS2EPConverter instead of
        # TS2FXGraphConverter since it gets attributes from self.ts_model
        # which is not accessible in TS2FXGraphConverter. It is similar to where
        # we collect self.name_to_param and self.name_to_buffer.

        # Maps a graph value's debug name to the fully-qualified attribute
        # name it refers to, so chained GetAttrs resolve correctly.
        name_to_attribute_fqn: Dict[str, str] = {}

        def get_attr(fqn: str):
            # Resolve a dotted fqn against the actual ScriptModule.
            name = fqn.split(".")
            v = self.ts_model
            for n in name:
                v = getattr(v, n)
            return v

        def get_fqn(node: torch._C.Node):
            # Build the fqn of this GetAttr from its input's resolved prefix.
            attr_name = node.s("name")
            input_name = node.input().debugName()
            root_attr_name = name_to_attribute_fqn[input_name]
            attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
            return attr_fqn

        def _dfs_get_attr(block):
            for node in block.nodes():
                if node.kind() == "prim::CreateObject":
                    # A freshly created object roots an empty attribute path.
                    output_name = node.output().debugName()
                    name_to_attribute_fqn[output_name] = ""

                if node.kind() == "prim::GetAttr":
                    attr_fqn = get_fqn(node)
                    value = get_attr(attr_fqn)
                    output_name = node.output().debugName()
                    name_to_attribute_fqn[output_name] = attr_fqn
                    if isinstance(value, torch.Tensor):
                        if attr_fqn not in self.name_to_buffer:
                            # Lift tensor constants to be a buffer
                            self.name_to_buffer[attr_fqn] = value
                    elif isinstance(value, torch.ScriptObject):
                        if attr_fqn not in self.name_to_constant:
                            self.name_to_constant[attr_fqn] = value
                    else:
                        self.name_to_non_tensor_attributes[attr_fqn] = value

                # Recurse into control-flow sub-blocks.
                for subblock in node.blocks():
                    _dfs_get_attr(subblock)

        _dfs_get_attr(self.ts_graph)
.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ import inspect
4
+ import logging
5
+ from collections import defaultdict
6
+ from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING, Union
7
+
8
+ import torch
9
+ import torch.utils._pytree as pytree
10
+ from torch._dynamo.source import (
11
+ AttrSource,
12
+ GetItemSource,
13
+ LocalSource,
14
+ TensorProperty,
15
+ TensorPropertySource,
16
+ )
17
+ from torch._dynamo.variables.builder import TrackedFake
18
+ from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim
19
+ from torch._export.passes.lift_constants_pass import ConstantAttrMap
20
+ from torch._guards import Source
21
+ from torch._library.fake_class_registry import FakeScriptObject
22
+ from torch._subclasses.fake_tensor import FakeTensorMode
23
+ from torch.export import Constraint
24
+ from torch.export.dynamic_shapes import (
25
+ _check_dynamic_shapes,
26
+ _combine_args,
27
+ _DimHint,
28
+ _process_dynamic_shapes,
29
+ _transform_shapes_for_default_dynamic,
30
+ _tree_map_with_path,
31
+ )
32
+ from torch.export.graph_signature import CustomObjArgument
33
+ from torch.fx.experimental import _config as config
34
+ from torch.fx.experimental.symbolic_shapes import (
35
+ _find_user_code_frame,
36
+ _suggest_fixes_for_data_dependent_error_non_strict,
37
+ ConstraintViolationError,
38
+ DimDynamic,
39
+ EqualityConstraint,
40
+ GuardOnDataDependentSymNode,
41
+ ShapeEnv,
42
+ StatelessSymbolicContext,
43
+ ValueRanges,
44
+ )
45
+ from torch.utils._pytree import (
46
+ GetAttrKey,
47
+ KeyPath,
48
+ MappingKey,
49
+ SequenceKey,
50
+ tree_map_with_path,
51
+ )
52
+
53
+
54
+ if TYPE_CHECKING:
55
+ from sympy import Symbol
56
+
57
+
58
+ log = logging.getLogger(__name__)
59
+
60
+
61
+ def key_path_to_source(kp: KeyPath) -> Source:
62
+ """
63
+ Given a key path, return the source for the key path.
64
+ """
65
+ source: Source = LocalSource("args")
66
+ for k in kp:
67
+ if isinstance(k, SequenceKey):
68
+ source = GetItemSource(source, k.idx)
69
+ elif isinstance(k, MappingKey):
70
+ source = GetItemSource(source, k.key)
71
+ elif isinstance(k, GetAttrKey):
72
+ source = AttrSource(source, k.name)
73
+ else:
74
+ raise ValueError(f"Unknown KeyEntry {k}")
75
+
76
+ return source
77
+
78
+
79
+ def _is_constant_argument(t):
80
+ return t is None or isinstance(t, (int, float, bool, str))
81
+
82
+
83
def fakify(
    mode: FakeTensorMode,
    kp: KeyPath,
    t: Any,
    t_constraints: Dict[int, Dict[int, Constraint]],
    sources: Dict[Tuple[int, int], List[Source]],
):
    """Fakify one user input for non-strict export tracing.

    Constants and ScriptObjects pass through unchanged. Tensors become
    FakeTensors with every dimension dynamic; user constraints for this
    tensor (looked up by ``id(t)``) are attached per dimension. ``sources``
    is updated in place with the Source of each constrained dimension,
    keyed by ``(tensor id, dim)``.

    Raises:
        ValueError: for input types other than constants/ScriptObjects/tensors.
    """
    source = key_path_to_source(kp)
    if _is_constant_argument(t) or isinstance(t, torch.ScriptObject):
        return t

    if not isinstance(t, torch.Tensor):
        raise ValueError(f"Unsupported input type {type(t)}")
    n_dims = len(t.shape)
    # Start fully dynamic; per-dim constraints are filled in below.
    symbolic_context = StatelessSymbolicContext(
        dynamic_sizes=[DimDynamic.DYNAMIC] * n_dims,
        constraint_sizes=[None] * n_dims,
    )
    t_id = id(t)
    assert mode.shape_env is not None
    if t_id in t_constraints:
        for i, constraint in t_constraints[t_id].items():
            symbolic_context.constraint_sizes[i] = constraint.constraint_range
            src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i)
            sources[(t_id, i)].append(src)
            # Map the internal source name to the user-facing Dim name so
            # error messages refer to the user's names.
            mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name  # type: ignore[assignment]
    fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
    # Record the fake so guards can be produced later
    # (see produce_guards_and_solve_constraints).
    mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context))  # type: ignore[union-attr]
    return fake
112
+
113
+
114
def make_fake_inputs(
    nn_module,
    args,
    kwargs,
    dynamic_shapes,
    _is_torch_jit_trace=False,
    allow_complex_guards_as_runtime_asserts=False,
):
    """
    Given an nn module, example inputs, and constraints, return a new fake mode,
    fake inputs created in that mode whose dynamic shape dimensions are constrained
    by the given ranges, and sources for pairs of dynamic shape dimensions that are
    constrained to be equal.

    Returns a 6-tuple: (fake_mode, fake_args, fake_kwargs, equalities_inputs,
    original_signature, transformed_dynamic_shapes).
    """
    # TODO(avik): refactor Dynamo to avoid duplication of the following code
    # between non-strict and strict.
    # Specifically, here (non-strict) we do the following pre-tracing steps:
    #   - Fakify inputs.
    #   - Process input shape equalities.
    # In strict, these steps are spread across multiple files:
    #   - output_graph.py fakifies inputs.
    #   - [post-tracing] guards.py processes input shape equalities.

    combined_args = _combine_args(nn_module, args, kwargs)
    _check_dynamic_shapes(combined_args, dynamic_shapes)
    transformed_dynamic_shapes = _transform_shapes_for_default_dynamic(
        combined_args, dynamic_shapes
    )
    constraints = _process_dynamic_shapes(combined_args, transformed_dynamic_shapes)
    # Index constraints by (tensor id, dim) for O(1) lookup in fakify().
    t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict)
    for constraint in constraints:
        t_constraints[constraint.t_id][constraint.dim] = constraint

    context = torch._guards.TracingContext.try_get()
    if context is not None:
        # This occurs when we are exporting within dynamo. There already exists
        # a toplevel TracingContext with a fake mode, so we do not want to
        # create another fake mode.
        fake_mode = context.fake_mode
    elif not _is_torch_jit_trace:
        # Carry the forward()'s code location into the ShapeEnv for diagnostics.
        code = nn_module.forward.__code__
        co_fields = {
            "co_name": code.co_name,
            "co_filename": code.co_filename,
            "co_firstlineno": code.co_firstlineno,
        }
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(
                tracked_fakes=[],
                co_fields=co_fields,
                prefer_deferred_runtime_asserts_over_guards=True,
                allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
            ),
            allow_non_fake_inputs=True,
            export=True,
        )
    else:
        # jit-traced modules have no usable __code__; skip co_fields.
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(
                tracked_fakes=[],
                prefer_deferred_runtime_asserts_over_guards=True,
                allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
            ),
            allow_non_fake_inputs=True,
        )
    if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None:
        raise ValueError(
            "Detected fake_mode does not have a shape_env with tracked fakes. "
            "If you constructed the module under a FakeTensorMode, "
            "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))"
        )

    with fake_mode:
        # FIXME(ycao) ScriptMethod doesn't have signature, I am using an empty one to unblock
        if not _is_torch_jit_trace:
            original_signature = inspect.signature(nn_module.forward)
        else:
            original_signature = None
        sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list)
        fake_args, fake_kwargs = tree_map_with_path(
            lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
            (args, kwargs),
        )

        names: Dict[str, Tuple[int, int]] = {}
        source_pairs: List[Tuple[Source, Source]] = []
        derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = []
        phantom_symbols: Dict[str, Symbol] = {}
        # Turn each user constraint into equality relations between sources;
        # the dicts/lists above are filled in place.
        for constraint in constraints:
            torch.export.dynamic_shapes._process_equalities(
                constraint,
                lambda t_id, dim: sources[(t_id, dim)],
                fake_mode.shape_env,
                names,
                source_pairs,
                derived_equalities,
                phantom_symbols,
            )

        equalities_inputs = EqualityConstraint(
            source_pairs=source_pairs,
            derived_equalities=derived_equalities,
            phantom_symbols=list(phantom_symbols.values()),
            warn_only=False,
        )
        return (
            fake_mode,
            fake_args,
            fake_kwargs,
            equalities_inputs,
            original_signature,
            transformed_dynamic_shapes,
        )
227
+
228
+
229
+ def _flatten_dynamic_shapes(
230
+ combined_args: Dict[str, Any],
231
+ dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
232
+ ) -> List[Any]:
233
+ flat_shapes = []
234
+
235
+ def _tree_map_helper(path, t, shape):
236
+ nonlocal flat_shapes
237
+ flat_shapes.append(shape)
238
+
239
+ _tree_map_with_path(_tree_map_helper, combined_args, dynamic_shapes)
240
+ return flat_shapes
241
+
242
+
243
def produce_guards_and_solve_constraints(
    fake_mode: FakeTensorMode,
    gm: torch.fx.GraphModule,
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
    equalities_inputs: EqualityConstraint,
    original_signature: inspect.Signature,
    _is_torch_jit_trace=False,
):
    """
    Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions,
    and a graph module, produce guards on the fake mode's shape env (raising constraint
    violations if any), solve (to suggest simplifications or fixes).
    Dynamo already performs this, so this is for non-strict mode.

    Additional inputs:
    equalities_inputs: the equality constraints to use for guards
    original_signature: the signature of the forward method

    Raises:
        ConstraintViolationError: when guards conflict with user constraints
            or specializations were forced.
    """
    shape_env = fake_mode.shape_env
    assert shape_env is not None
    assert shape_env.tracked_fakes is not None

    placeholders = [tf.fake for tf in shape_env.tracked_fakes]
    sources = [tf.source for tf in shape_env.tracked_fakes]
    input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes]
    # Defer raising: we still want to solve constraints so the eventual error
    # message can include suggested fixes.
    constraint_violation_error = None
    try:
        shape_env.produce_guards(
            placeholders,
            sources,
            input_contexts=input_contexts,
            equalities_inputs=equalities_inputs,
            ignore_static=False,
        )
    except ConstraintViolationError as e:
        constraint_violation_error = e

    # No further symbols/guards may be added after this point.
    shape_env.frozen = True
    dim_constraints = shape_env.dim_constraints
    if dim_constraints is None:
        # Expected when shape_env.produce_guards throws an early constraint violation error.
        # There is nothing to solve for in this case.
        # TODO(avik): Maybe record the constraint violation error instead and replay later?
        assert constraint_violation_error
        raise constraint_violation_error
    dim_constraints.solve()
    forced_specializations = dim_constraints.forced_specializations()
    if not _is_torch_jit_trace:
        msg = dim_constraints.prettify_results(
            original_signature,
            dynamic_shapes,
            constraint_violation_error,
            forced_specializations,
        )
    else:
        # FIXME(ycao): This is a hack to get around missing signature from ScriptMethod
        msg = "dummy constraint violation message"
    if constraint_violation_error:
        # Append the prettified suggestions to the original error text.
        constraint_violation_error.args = (constraint_violation_error.args[0] + msg,)
    elif forced_specializations:
        # Forced specializations alone are also treated as a violation.
        constraint_violation_error = ConstraintViolationError(msg)
    if constraint_violation_error:
        raise constraint_violation_error
306
+
307
+
308
def make_constraints(
    fake_mode: FakeTensorMode,
    gm: torch.fx.GraphModule,
    combined_args: Dict[str, Any],
    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
    num_lifted_inputs: int,
):
    """
    Given a fake mode's shape env and user-specified dynamic shapes,
    return the resulting range constraints and equality constraints.

    Additional args:
    num_lifted_inputs: the number of non-user-input placeholder nodes in the graph
    (used only to enumerate the user-input nodes)
    """

    shape_env = fake_mode.shape_env
    assert shape_env is not None
    # Start from constraints the tracer recorded inline (e.g. torch._check calls).
    inline_constraints = gm.meta.get("inline_constraints", [])
    range_constraints = {
        symbol: inline_constraints[symbol] for symbol in inline_constraints
    }
    if not dynamic_shapes:
        return range_constraints

    # get individual dynamic shapes spec for each input
    if not isinstance(dynamic_shapes, dict):
        assert isinstance(dynamic_shapes, (tuple, list))
        # Align container type with dynamic_shapes so the pytree structures match.
        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]
    flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes)

    # check number of shapes vs. number of inputs
    num_placeholders = [node.op == "placeholder" for node in gm.graph.nodes].count(True)
    assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs

    input_dims = defaultdict(list)
    free_symbols = set()
    for input_index, node in enumerate(gm.graph.nodes):
        # Skip lifted (non-user) placeholders and non-placeholder nodes.
        if input_index < num_lifted_inputs or node.op != "placeholder":
            continue
        if _is_constant_argument(node.meta["val"]) or isinstance(
            node.meta["val"], CustomObjArgument
        ):
            continue
        # User inputs follow the lifted ones, hence the index offset.
        shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs]
        for i, d in enumerate(node.meta["val"].shape):
            if isinstance(d, torch.SymInt) and not d.node.expr.is_number:
                # Look up the range constraint for the symbol corresponding to this shape dimension
                # and store it indexed by the symbolic expression corresponding to it.
                # NOTE(avik): Use node._expr instead of node.expr for the lookup here because
                # we want the symbol, not its replacement, which could be an expression. Maybe
                # there's a better way to do this, e.g., by (re)computing value ranges for expressions?
                dim = shape_spec[i] if shape_spec else None
                if dim is None or isinstance(dim, _DimHint):
                    range_constraints[d.node.expr] = shape_env.var_to_range[
                        d.node._expr
                    ]
                else:
                    # An explicit Dim carries its own min/max bounds.
                    range_constraints[d.node.expr] = ValueRanges(
                        lower=dim.min, upper=dim.max
                    )
                input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i))
                free_symbols.update(d.node.expr.free_symbols)

    for symbol in free_symbols:
        if symbol not in range_constraints:
            # Placeholders can have symbolic shapes that are derived expressions.
            # The above code will record direct range constraints for them
            # so that we can do runtime assertions. In addition, for serde checks
            # we want to record range constraints for their root symbols.
            range_constraints[symbol] = shape_env.var_to_range[symbol]

    return range_constraints
381
+
382
+
383
+ def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
384
+ """Search the module hierarchy, gathering up all tensor and ScriptObject constants.
385
+
386
+ Returns a dictionary mapping hash(value) to the name of the constant. We
387
+ have to abuse `hash` here unfortunately, see: [ScriptObject hash].
388
+ """
389
+ constants = ConstantAttrMap()
390
+ buffers_parameters = set(m.buffers())
391
+ buffers_parameters.update(m.parameters())
392
+
393
+ def inner(m: torch.nn.Module, prefix_atoms: List[str], constants):
394
+ for k, v in m.__dict__.items():
395
+ if isinstance(
396
+ v,
397
+ (
398
+ torch.Tensor,
399
+ torch.ScriptObject,
400
+ FakeScriptObject,
401
+ ),
402
+ ):
403
+ if v in buffers_parameters:
404
+ # filter out buffers and parameters, leaving only constants
405
+ continue
406
+
407
+ fqn = ".".join(prefix_atoms + [k])
408
+ constants.add(v, fqn)
409
+ for k, v in m.named_children():
410
+ inner(v, prefix_atoms + [k], constants)
411
+
412
+ inner(m, [], constants)
413
+ return constants
414
+
415
+
416
@contextlib.contextmanager
def _fakify_script_objects(
    mod: torch.nn.Module,
    args: Tuple[Any],
    kwargs: Dict[Any, Any],
    fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
):
    """Temporarily replace ScriptObjects in ``mod``/``args``/``kwargs`` with fakes.

    On exit, all patched module attributes are restored to the real objects.
    """
    # This context manager is used to fakify script objects into FakeScriptObject.
    # Inputs:
    #   mod: the module to be exported, it (and its recursive submodules)'s script object attrs haven't been fakified.
    #   args, kwargs: the args and kwargs inputs for mod, script object inputs haven't been fakified.
    #   fake_mode: the fake mode to be used for fakifying script objects. It's the same mode that fakify input tensors.
    #
    # Returns:
    #   mod: the patched module, its (and its recursive submodules) script object attrs have been fakified.
    #   fake_args, fake_kwargs: new fakified args and kwargs.
    #       Script object inputs have been fakified. Don't touch the tensors.
    #   fake_constant_attrs: a new map from FakeScriptObject to the fqn of the original script object.
    #   fake_to_real: a mapping between FakeScriptObject and the original script object in order to un-do the patching.

    constant_attrs: ConstantAttrMap = _gather_constant_attrs(mod)
    assert not any(
        isinstance(obj, FakeScriptObject) for obj in constant_attrs.values()
    ), "Mod shouldn't contain any FakeScriptObject."
    assert not pytree.tree_any(
        lambda obj: isinstance(obj, FakeScriptObject), (args, kwargs)
    ), "args and kwargs shouldn't contain any FakeScriptObject."

    # fqn -> original object, used to undo the patching in the finally block.
    patched_attr = {}
    fake_constant_attrs = ConstantAttrMap()
    fake_to_real = {}

    def _maybe_fakify_obj(obj):
        # Fakify one ScriptObject and remember the reverse mapping.
        fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, obj)
        fake_to_real[fake_obj] = obj
        return fake_obj

    def _leaf_mod_and_attr(
        mod: torch.nn.Module, attr_fqn: str
    ) -> Tuple[torch.nn.Module, str]:
        # Resolve a dotted fqn to (owning submodule, final attribute name).
        *prefix_attr, last_attr = attr_fqn.split(".")
        cur_mod = mod
        for attr in prefix_attr:
            cur_mod = getattr(cur_mod, attr)
        return cur_mod, last_attr

    try:
        for obj, fqns in constant_attrs.items():
            if isinstance(obj, torch.ScriptObject):
                fake_script_obj = _maybe_fakify_obj(obj)
                for fqn in fqns:
                    cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
                    assert obj is getattr(cur_mod, attr)
                    setattr(cur_mod, attr, fake_script_obj)
                    fake_constant_attrs.add(fake_script_obj, fqn)
                    patched_attr[fqn] = obj
            else:
                # Non-ScriptObject constants (tensors etc.) are passed through.
                for fqn in fqns:
                    fake_constant_attrs.add(obj, fqn)

        # Fakify ScriptObjects appearing directly in the inputs as well.
        fake_args, fake_kwargs = pytree.tree_map_only(
            torch.ScriptObject, _maybe_fakify_obj, (args, kwargs)
        )
        yield (mod, fake_args, fake_kwargs, fake_constant_attrs, fake_to_real)
    finally:
        # Restore the original objects even if the body raised.
        for fqn, orig_obj in patched_attr.items():
            cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
            setattr(cur_mod, attr, orig_obj)
484
+
485
+
486
+ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
487
+ """
488
+ 1. Handles data-dependent errors raised by torch function calls in non-strict.
489
+
490
+ Any data-dependent error is due to some condition on unbacked symints
491
+ that cannot be resolved. A mechanical way of fixing the error is to use
492
+ a torch._check() call to assert either that condition or its negation.
493
+ The handler suggests these options as code and points to the location
494
+ of the torch function call that raised the error as part of the error
495
+ message shown to the user, who can then simply select and copy-paste
496
+ a suggested fix at that location.
497
+
498
+ NOTE: Not all data-dependent errors are raised by torch function calls.
499
+ In particular, conditions on unbacked symints can appear outside such
500
+ calls, and as such are not handled here.
501
+
502
+ 2. Handles line-of-code logging for each torch function call in non-strict.
503
+
504
+ Usage: TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC=1 TORCH_LOGS="+export" ...
505
+ """
506
+
507
+ def __torch_function__(self, func, types, args=(), kwargs=None):
508
+ kwargs = kwargs or {}
509
+ if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
510
+ frame = _find_user_code_frame()
511
+ if frame is not None:
512
+ log.debug(
513
+ "%s called at %s:%s in %s",
514
+ func.__qualname__,
515
+ frame.f_code.co_filename,
516
+ frame.f_lineno,
517
+ frame.f_code.co_name,
518
+ )
519
+ try:
520
+ return func(*args, **kwargs)
521
+ except GuardOnDataDependentSymNode as e:
522
+ _suggest_fixes_for_data_dependent_error_non_strict(e)
523
+ raise
.venv/lib/python3.11/site-packages/torch/_export/pass_base.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import operator
3
+ import traceback
4
+ import typing
5
+ from contextlib import nullcontext
6
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
7
+
8
+ import torch
9
+ from functorch.experimental.control_flow import _unstack_pytree
10
+ from torch import fx
11
+ from torch._dispatch.python import enable_python_dispatcher
12
+ from torch._export.pass_infra.node_metadata import NodeMetadata
13
+ from torch._export.pass_infra.proxy_value import ProxyValue
14
+ from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
15
+ from torch._subclasses.fake_tensor import FakeTensorMode
16
+ from torch.fx import traceback as fx_traceback
17
+ from torch.fx.experimental.proxy_tensor import PythonKeyTracer
18
+ from torch.fx.graph import CodeGen
19
+ from torch.fx.passes.infra.pass_base import PassBase, PassResult
20
+ from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
21
+ from torch.utils import _pytree as pytree
22
+ from torch.fx.experimental.symbolic_shapes import PropagateUnbackedSymInts, compute_unbacked_bindings
23
+
24
+
25
+ __all__ = ["_ExportPassBaseDeprecatedDoNotUse"]
26
+
27
+
28
# Type aliases used by the pass base plumbing below.
Argument = Any  # any value that can appear as an FX node argument
Value = Any  # any runtime value flowing through the interpreter
Fn = Callable[..., Any]
# A pass: takes a GraphModule, optionally returns a PassResult.
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]


# torch.sym_* helpers that are dispatched via call_sym() rather than
# call_operator() during interpretation.
_TORCH_SYM_OPS: Set[Callable] = {
    torch.sym_int,
    torch.sym_float,
    torch.sym_ite,
    torch.sym_max,
    torch.sym_min,
    torch.sym_not,
    torch.sym_sqrt,
}
43
+
44
+
45
class ExportPassBaseError(RuntimeError):
    """Raised when the export pass infrastructure is misused or encounters
    an unsupported construct."""

    pass
47
+
48
+
49
+ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
50
+ """
51
+ Interpreter-based pass class to help users maintain the IR spec while writing
52
+ transformations.
53
+ """
54
+
55
    @staticmethod
    def _create_dummy_node_metadata():
        # Fabricate minimal NodeMetadata (just a stack trace) for nodes the
        # pass creates from scratch, so consumers always find metadata.
        return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
58
+
59
+
60
    class ExportTracer(PythonKeyTracer):
        """Tracer that records the callback's emitted nodes into a fresh graph.

        Unlike a normal tracer, it is driven externally (trace() is disabled)
        and is responsible for propagating fake-tensor / tensor metadata onto
        every node it creates.
        """

        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None:
            super().__init__()
            self.callback = callback
            self.root = torch.nn.Module()
            self.graph = torch.fx.Graph()
            self.graph.set_codegen(codegen)
            self.tensor_attrs: Dict[str, torch.Tensor] = {}  # type: ignore[assignment]
            # Set by the caller before interpretation; used to fakify values
            # when propagating metadata.
            self.fake_tensor_mode: Optional[FakeTensorMode] = None
            # Maps submodule object -> generated attribute name on self.root.
            self.submodules: Dict[torch.nn.Module, str] = {}

        def trace(self) -> None:  # type: ignore[override]
            # This tracer is driven by the interpreter, not by tracing a callable.
            raise ExportPassBaseError("ExportTracer doesn't support trace().")

        def create_arg(self, a: Argument) -> torch.fx.Node:
            if isinstance(a, torch.nn.Module):
                # Register unseen submodules on the new root with a fresh name.
                if a not in self.submodules:
                    name_submodule = f"submodule_{len(self.submodules)}"
                    self.root.add_module(name_submodule, a)
                    self.submodules[a] = name_submodule
            elif isinstance(a, FakeTensor):
                # Only fake tensors backed by a real constant can be embedded.
                if not hasattr(a, "constant") or a.constant is None:
                    raise ExportPassBaseError(f"Cannot add {a} to graph.")
                a = a.constant
            node = super().create_arg(a)
            if (
                isinstance(a, torch.Tensor)
                and isinstance(node, torch.fx.Node)
                and node.op == "get_attr"
            ):
                # Newly lifted tensor attribute: attach metadata and notify
                # the callback so it can track the attribute.
                self.set_metadata(node, a)
                self.callback.on_attr(ProxyValue(a, node))
            return node

        def set_metadata(
            self, node: torch.fx.Node, value: Argument,
        ) -> None:
            # propagate the fake tensor or sym nodes
            def make_val(
                x: Argument,
            ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]:
                if isinstance(x, FakeTensor):
                    return x
                elif isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        # TODO we should allocate static shapes
                        # for param/buffer values
                        if isinstance(x, torch.nn.Parameter):
                            fake_tensor = self.fake_tensor_mode.from_tensor(
                                x, static_shapes=True
                            )
                        else:
                            fake_tensor = self.fake_tensor_mode.from_tensor(x)
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        print(
                            "Fakeifying a Tensor subclass is not supported \
                            right now. Instead a TensorMetadata is used."
                        )
                        fake_tensor = None
                    return fake_tensor
                elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)):
                    return x
                else:
                    return None

            node.meta["val"] = pytree.tree_map(make_val, value)

            # Set the tensor_metadata for values that do not have a corresponding FakeTensor
            def make_tensor_meta(x: Argument) -> Optional[TensorMetadata]:
                if not isinstance(x, FakeTensor) and isinstance(x, torch.Tensor):
                    if x.is_quantized:
                        # TODO (tmanlaibaatar) properly support Quantized FakeTensor
                        x = torch.dequantize(x)

                    try:
                        assert self.fake_tensor_mode is not None
                        _ = self.fake_tensor_mode.from_tensor(x)
                        # Fakification succeeded, so node.meta["val"] already
                        # carries the FakeTensor; no extra metadata needed.
                        tensor_meta = None
                    except UnsupportedFakeTensorException:
                        # TODO: This is just a workaround to get over the
                        # x.as_subclass error
                        tensor_meta = _extract_tensor_metadata(x)
                    return tensor_meta
                else:
                    return None

            node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)
154
+
155
+ class ExportInterpreter(fx.Interpreter):
156
+ def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None:
157
+ super().__init__(gm)
158
+ self.callback = callback
159
+ self.node: torch.fx.Node = next(iter(gm.graph.nodes))
160
+
161
+ def placeholder(
162
+ self,
163
+ target: str, # type: ignore[override]
164
+ args: Tuple[Argument, ...],
165
+ kwargs: Dict[str, Argument],
166
+ ) -> ProxyValue:
167
+ arg = super().placeholder(target, args, kwargs)
168
+ return self.callback.placeholder(target, arg, NodeMetadata(self.node.meta))
169
+
170
+ def output(
171
+ self,
172
+ target: torch.fx.node.Target,
173
+ args: Tuple[Argument, ...],
174
+ kwargs: Dict[str, Argument],
175
+ ) -> ProxyValue:
176
+ return self.callback.output(args[0], NodeMetadata(self.node.meta)).data
177
+
178
+ def call_function(
179
+ self,
180
+ target: torch.fx.node.Target,
181
+ args: Tuple[Argument, ...],
182
+ kwargs: Dict[str, Argument],
183
+ ) -> ProxyValue:
184
+ meta = NodeMetadata(self.node.meta)
185
+
186
+ if target == operator.getitem:
187
+ value, key = args
188
+ return self.callback.call_getitem(value, key, meta)
189
+ elif getattr(target, "__module__", None) in {"_operator", "math"}:
190
+ assert callable(target)
191
+ return self.callback.call_sym(target, args, meta)
192
+ elif target in _TORCH_SYM_OPS:
193
+ assert callable(target)
194
+ return self.callback.call_sym(target, args, meta)
195
+ elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
196
+ return self.callback.call_operator(
197
+ target,
198
+ args,
199
+ kwargs,
200
+ meta,
201
+ )
202
+ elif target == torch.ops.higher_order.cond:
203
+ pred, true_fn, false_fn, inputs = args
204
+ return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
205
+ elif target == torch.ops.higher_order.map_impl:
206
+ f, mapped_args, operands = args # type: ignore[assignment]
207
+ return self.callback.call_map(f, mapped_args, operands, meta)
208
+ # For other unregistered HigherOrderOps, just interpret them blindly
209
+ elif isinstance(target, torch._ops.HigherOrderOperator):
210
+ return self.callback._fx(
211
+ "call_function",
212
+ target,
213
+ args,
214
+ kwargs,
215
+ meta,
216
+ )
217
+ else:
218
+ raise ExportPassBaseError(f"Unsupported target type: {target}")
219
+
220
+ def get_attr(
221
+ self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
222
+ ) -> Argument:
223
+ return super().get_attr(target, args, kwargs)
224
+
225
+ def call_module(
226
+ self,
227
+ target: torch.fx.node.Target,
228
+ args: Tuple[Argument, ...],
229
+ kwargs: Dict[str, Argument],
230
+ ) -> None:
231
+ raise ExportPassBaseError("call_module is not supported.")
232
+
233
+ def call_method(
234
+ self, target: str, args: Tuple[Argument, ...], kwargs: Dict[str, Argument] # type: ignore[override]
235
+ ) -> None:
236
+ raise ExportPassBaseError("call_method is not supported.")
237
+
238
+ def run_node(self, n: torch.fx.Node) -> Argument:
239
+ self.node = n
240
+ self.callback.node_debug_str = n.format_node()
241
+ return super().run_node(n)
242
+
243
+ def __init__(self) -> None:
244
+ self.interpreter = PropagateUnbackedSymInts(
245
+ torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
246
+ )
247
+ self.tracer = self.ExportTracer(self, CodeGen())
248
+ self.fake_tensor_mode: Optional[FakeTensorMode] = None
249
+ self._initialized = True
250
+ self.node_debug_str: typing.Optional[str] = None
251
+
252
+ def _fx(
253
+ self,
254
+ kind: str,
255
+ target: torch.fx.node.Target,
256
+ args: Tuple[Argument, ...],
257
+ kwargs: Dict[str, Argument],
258
+ meta: NodeMetadata,
259
+ ) -> ProxyValue:
260
+ args_data, kwargs_data = pytree.tree_map_only(
261
+ ProxyValue, lambda x: x.data, (args, kwargs)
262
+ )
263
+ res_data = getattr(self.interpreter, kind)(target, args_data, kwargs_data)
264
+ args_proxy, kwargs_proxy = pytree.tree_map_only(
265
+ ProxyValue, lambda x: x.proxy, (args, kwargs)
266
+ )
267
+
268
+ name = None
269
+ if isinstance(target, torch._ops.OpOverload):
270
+ name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)
271
+
272
+ res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
273
+ res_proxy.node.meta.update(meta.data)
274
+ if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
275
+ if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
276
+ res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
277
+ self.tracer.set_metadata(res_proxy.node, res_data)
278
+ return ProxyValue(res_data, res_proxy)
279
+
280
+ def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]:
281
+ # TODO(angelayi): Update this with what we decide to do for metadata in
282
+ # the exported graph module
283
+ if (args := graph_module.meta.get("args", None)) is not None:
284
+ return list(args)
285
+
286
+ def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
287
+ if "val" in node.meta:
288
+ fake = node.meta["val"]
289
+ if hasattr(fake, "constant") and fake.constant is not None:
290
+ return fake.constant
291
+ return fake
292
+ elif tensor_meta := node.meta.get("tensor_meta"):
293
+ assert self.fake_tensor_mode is not None
294
+ return FakeTensor(
295
+ self.fake_tensor_mode,
296
+ torch.empty(
297
+ tensor_meta.shape,
298
+ dtype=tensor_meta.dtype,
299
+ device="meta",
300
+ requires_grad=tensor_meta.requires_grad,
301
+ memory_format=tensor_meta.memory_format,
302
+ ),
303
+ torch.device("cpu"),
304
+ )
305
+ elif len(node.users) == 0:
306
+ return None
307
+ raise ExportPassBaseError(
308
+ f"Cannot construct an input for graph module: {graph_module}.",
309
+ )
310
+
311
+ return [
312
+ extract_input(node)
313
+ for node in graph_module.graph.nodes
314
+ if node.op == "placeholder"
315
+ ]
316
+
317
+ def on_attr(self, attr: ProxyValue) -> None:
318
+ pass
319
+
320
+ def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
321
+ arg_proxy = self.tracer.create_proxy("placeholder", name, (), {})
322
+ arg_proxy.node.meta = meta.data
323
+ self.tracer.set_metadata(arg_proxy.node, arg)
324
+ return ProxyValue(arg, arg_proxy)
325
+
326
+ def call_operator(
327
+ self,
328
+ op,
329
+ args: Tuple[Argument, ...],
330
+ kwargs: Dict[str, Argument],
331
+ meta: NodeMetadata,
332
+ ) -> ProxyValue:
333
+ return self._fx("call_function", op, args, kwargs, meta)
334
+
335
+ def call_sym(
336
+ self,
337
+ target: Fn,
338
+ args: Tuple[Argument, ...],
339
+ meta: NodeMetadata,
340
+ ) -> ProxyValue:
341
+ return self._fx("call_function", target, args, {}, meta)
342
+
343
+ def call_cond(
344
+ self,
345
+ pred: ProxyValue,
346
+ true_fn: torch.fx.GraphModule,
347
+ false_fn: torch.fx.GraphModule,
348
+ inputs: List[Argument],
349
+ meta: NodeMetadata,
350
+ ) -> ProxyValue:
351
+ true_branch = self.call_submodule(true_fn, tuple(inputs))
352
+ false_branch = self.call_submodule(false_fn, tuple(inputs))
353
+ assert true_branch is not None
354
+ assert false_branch is not None
355
+ return self._fx(
356
+ "call_function",
357
+ torch.ops.higher_order.cond,
358
+ (pred, true_branch.graph_module, false_branch.graph_module, list(inputs)),
359
+ {},
360
+ meta,
361
+ )
362
+
363
+ def call_map(
364
+ self,
365
+ f: torch.fx.GraphModule,
366
+ mapped_args: List[ProxyValue],
367
+ operands: List[ProxyValue],
368
+ meta: NodeMetadata,
369
+ ) -> ProxyValue:
370
+ xs = _unstack_pytree([arg.data for arg in mapped_args])[0]
371
+ f_branch = self.call_submodule(f, tuple(xs + [arg.data for arg in operands]))
372
+ assert f_branch is not None
373
+ return self._fx(
374
+ "call_function",
375
+ torch.ops.higher_order.map_impl,
376
+ (f_branch.graph_module, mapped_args, operands),
377
+ {},
378
+ meta,
379
+ )
380
+
381
+ def call_getitem(
382
+ self, value: ProxyValue, key: int, meta: NodeMetadata
383
+ ) -> ProxyValue:
384
+ return self._fx("call_function", operator.getitem, (value, key), {}, meta)
385
+
386
+ def output(self, results: List[Argument], meta: NodeMetadata) -> ProxyValue:
387
+ return self._fx("output", "output", (results,), {}, meta)
388
+
389
+ def call_submodule(
390
+ self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...]
391
+ ) -> PassResult:
392
+ prev_tracer, self.tracer = self.tracer, self.ExportTracer(
393
+ self, graph_module.graph._codegen
394
+ )
395
+ self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
396
+ interpreter = self.ExportInterpreter(self, graph_module)
397
+ prev_interpreter, self.interpreter = self.interpreter, torch.fx.Interpreter( # type: ignore[assignment]
398
+ torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
399
+ )
400
+ inputs_data = pytree.tree_map_only(ProxyValue, lambda x: x.data, inputs)
401
+ with fx_traceback.preserve_node_meta():
402
+ interpreter.run(*inputs_data)
403
+
404
+ new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)
405
+
406
+ self.tracer = prev_tracer
407
+ self.interpreter = prev_interpreter
408
+ return PassResult(
409
+ new_graph_module,
410
+ True,
411
+ )
412
+
413
+ def call(self, graph_module: fx.GraphModule) -> PassResult:
414
+ if not getattr(self, "_initialized", False):
415
+ raise ExportPassBaseError(
416
+ "ExportPass is not initialized with __init__().",
417
+ )
418
+
419
+ inputs = self.inputs(graph_module)
420
+
421
+ fake_tensor_mode = None
422
+ for i in inputs:
423
+ if isinstance(i, FakeTensor):
424
+ assert (
425
+ fake_tensor_mode is None or fake_tensor_mode is i.fake_mode
426
+ ), "Multiple fake tensor mode detected."
427
+ fake_tensor_mode = i.fake_mode
428
+ if fake_tensor_mode is None:
429
+ self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
430
+ fake_tensor_mode = nullcontext() # type: ignore[assignment]
431
+ dispatcher_mode = nullcontext() # type: ignore[assignment]
432
+ else:
433
+ fake_tensor_mode.allow_non_fake_inputs = True
434
+ self.tracer.fake_tensor_mode = fake_tensor_mode
435
+ dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment]
436
+ self.fake_tensor_mode = self.tracer.fake_tensor_mode
437
+
438
+ with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr]
439
+ result = self.call_submodule(graph_module, tuple(inputs))
440
+
441
+ return result
.venv/lib/python3.11/site-packages/torch/_export/tools.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ import warnings
4
+ from typing import Any, Dict, Iterable, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.export
8
+ import torch.export._trace
9
+ from torch._utils_internal import log_export_usage
10
+
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+ __all__ = ["report_exportability"]
15
+
16
+
17
+ def _generate_inputs_for_submodules(
18
+ model: torch.nn.Module,
19
+ target_submodules: Iterable[str],
20
+ args: Tuple[Any, ...],
21
+ kwargs: Optional[Dict[str, Any]] = None,
22
+ ) -> Dict[str, Tuple[Any, Any]]:
23
+ """
24
+ Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
25
+ function doesn't work.
26
+
27
+ Args:
28
+ model: root model.
29
+ inputs: inputs to the root model.
30
+ target_submodules: submodules that we want to generate inputs for.
31
+
32
+ Returns:
33
+ A dict that maps from submodule name to its inputs.
34
+ """
35
+ kwargs = kwargs or {}
36
+
37
+ handles = []
38
+ results = {}
39
+ submodule_to_names = {mod: name for name, mod in model.named_modules()}
40
+
41
+ def pre_forward(module, module_args, module_kwargs):
42
+ results[submodule_to_names[module]] = (module_args, module_kwargs)
43
+
44
+ try:
45
+ for name, mod in model.named_modules():
46
+ if name in target_submodules:
47
+ handles.append(
48
+ mod.register_forward_pre_hook(pre_forward, with_kwargs=True)
49
+ )
50
+ model(*args, **kwargs)
51
+ except Exception as e:
52
+ warnings.warn(
53
+ f"Failed to generate submodule inputs because of the following error:\n{e}"
54
+ )
55
+ finally:
56
+ for h in handles:
57
+ h.remove()
58
+ return results
59
+
60
+
61
+ def report_exportability(
62
+ mod: torch.nn.Module,
63
+ args: Tuple[Any, ...],
64
+ kwargs: Optional[Dict[str, Any]] = None,
65
+ *,
66
+ strict: bool = True,
67
+ pre_dispatch: bool = False,
68
+ ) -> Dict[str, Optional[Exception]]:
69
+ """
70
+ Report exportability issues for a module in one-shot.
71
+
72
+ Args:
73
+ mod: root module.
74
+ args: args to the root module.
75
+ kwargs: kwargs to the root module.
76
+ Returns:
77
+ A dict that maps from submodule name to the exception that was raised when trying to export it.
78
+ `None` means the module is exportable without issue.
79
+ Sample output:
80
+ {
81
+ '': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
82
+ 'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
83
+ 'submod_2': None
84
+ }
85
+ """
86
+
87
+ log_export_usage(event="export.report_exportability")
88
+
89
+ kwargs = kwargs or {}
90
+
91
+ all_submod_names = [name for name, _ in mod.named_modules() if name != ""]
92
+ submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)
93
+
94
+ tried_module_types = set()
95
+ report: Dict[str, Optional[Exception]] = {}
96
+
97
+ def try_export(module, module_name, args, kwargs):
98
+ nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types
99
+
100
+ if type(module) in tried_module_types:
101
+ return
102
+ tried_module_types.add(type(module))
103
+
104
+ if args is not None or kwargs is not None:
105
+ try:
106
+ torch.export._trace._export(
107
+ module,
108
+ args,
109
+ kwargs,
110
+ strict=strict,
111
+ pre_dispatch=pre_dispatch,
112
+ )
113
+ report[module_name] = None
114
+ log.info("Successfully exported `%s`", module_name)
115
+ return
116
+ except Exception as e:
117
+ short_msg = repr(e).split("\n")[0]
118
+ log.warning(
119
+ "Failed exporting `%s` with exception: %s", module_name, short_msg
120
+ )
121
+ report[module_name] = e
122
+
123
+ for name, submod in module.named_children():
124
+ sub_module_name = name if module_name == "" else f"{module_name}.{name}"
125
+
126
+ submod_args, submod_kwargs = submod_inputs.get(
127
+ sub_module_name, (None, None)
128
+ )
129
+
130
+ try_export(submod, sub_module_name, submod_args, submod_kwargs)
131
+
132
+ return
133
+
134
+ try_export(mod, "", args, kwargs)
135
+
136
+ unique_issues = set()
137
+ for exception in report.values():
138
+ if exception is not None:
139
+ key = repr(exception).split("\\n")[0]
140
+ unique_issues.add(key)
141
+
142
+ log.warning("Found %d export issues:", len(unique_issues))
143
+ for issue in unique_issues:
144
+ log.warning(issue)
145
+
146
+ return report
.venv/lib/python3.11/site-packages/torch/_export/verifier.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import inspect
3
+ import math
4
+ import operator
5
+ from collections.abc import Iterable
6
+ from typing import Any, Dict, final, List, Tuple, Type, TYPE_CHECKING
7
+
8
+ import torch
9
+ from torch._ops import HigherOrderOperator, OpOverload
10
+ from torch._subclasses.fake_tensor import FakeTensor
11
+ from torch.export.graph_signature import (
12
+ CustomObjArgument,
13
+ InputKind,
14
+ SymIntArgument,
15
+ TensorArgument,
16
+ TokenArgument,
17
+ )
18
+ from torch.fx import GraphModule
19
+
20
+ if TYPE_CHECKING:
21
+ from torch.export.exported_program import ExportedProgram
22
+
23
+ class SpecViolationError(Exception):
24
+ pass
25
+
26
+
27
+ def is_functional(op: OpOverload) -> bool:
28
+ return not op._schema.is_mutable
29
+
30
+
31
+ def _check_has_fake_tensor(node: torch.fx.Node) -> None:
32
+ # TODO(angelayi): remove this in favor of _check_val
33
+ return _check_val(node)
34
+
35
+
36
+ def _check_val(node: torch.fx.Node) -> None:
37
+ from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt
38
+
39
+ def _check_correct_val(val):
40
+ if val is None:
41
+ return True
42
+ elif isinstance(val, (int, bool, str, float)):
43
+ return True
44
+ elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)):
45
+ return True
46
+ elif isinstance(val, (FakeTensor, torch.Tensor)): # TODO(zhxchen17) Remove Tensor.
47
+ return True
48
+ elif isinstance(val, (SymInt, SymFloat, SymBool)):
49
+ return True
50
+ elif isinstance(val, CustomObjArgument):
51
+ return True
52
+ elif isinstance(val, Iterable):
53
+ return all(_check_correct_val(x) for x in val)
54
+ return False
55
+
56
+ def _no_returns(op):
57
+ if not isinstance(op, OpOverload):
58
+ return False
59
+ return len(op._schema.returns) == 0
60
+
61
+ if "val" not in node.meta:
62
+ if node.op == "call_function" and _no_returns(node.target):
63
+ return
64
+ raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
65
+
66
+ val = node.meta["val"]
67
+ if not _check_correct_val(val):
68
+ raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}")
69
+
70
+
71
+ def _check_torch_fn(node: torch.fx.Node) -> None:
72
+ torch_fn = node.meta.get("torch_fn")
73
+ if torch_fn is None:
74
+ raise SpecViolationError(f"Unable to find torch_fn metadata for node {node.name}")
75
+ if (
76
+ not isinstance(torch_fn, tuple) and
77
+ isinstance(torch_fn[0], str) and
78
+ isinstance(torch_fn[1], str)
79
+ ):
80
+ raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}")
81
+
82
+ class _VerifierMeta(type):
83
+ _registry: Dict[str, Type['Verifier']] = {}
84
+
85
+ def __new__(metacls, name, bases, attrs):
86
+ if bases:
87
+ if "check" in attrs or "_check_graph_module" in attrs:
88
+ raise SyntaxError("Overriding method check is not allowed.")
89
+ assert "dialect" in attrs and attrs["dialect"] != "ATEN"
90
+ else:
91
+ assert "check" in attrs
92
+ assert "_check_graph_module" in attrs
93
+ assert attrs["dialect"] == "ATEN"
94
+
95
+ assert isinstance(attrs["dialect"], str)
96
+ ret = type.__new__(metacls, name, bases, attrs)
97
+ metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment]
98
+ return ret
99
+
100
+ def getattr_recursive(obj: Any, target: str) -> Any:
101
+ target_atoms = target.split('.')
102
+ attr_itr = obj
103
+ for i, atom in enumerate(target_atoms):
104
+ if not hasattr(attr_itr, atom):
105
+ raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
106
+ attr_itr = getattr(attr_itr, atom)
107
+ return attr_itr
108
+
109
+
110
+ class Verifier(metaclass=_VerifierMeta):
111
+ dialect = "ATEN"
112
+
113
+ def allowed_builtin_ops(self) -> List:
114
+ return [
115
+ operator.getitem,
116
+ operator.add,
117
+ operator.mul,
118
+ operator.sub,
119
+ operator.truediv,
120
+ operator.ge,
121
+ operator.le,
122
+ operator.gt,
123
+ operator.lt,
124
+ operator.eq,
125
+ operator.ne,
126
+ operator.floordiv,
127
+ operator.mod,
128
+ operator.and_,
129
+ operator.or_,
130
+ operator.not_,
131
+ operator.pow,
132
+ operator.neg,
133
+ operator.abs,
134
+ math.ceil,
135
+ math.floor,
136
+ math.trunc,
137
+ ]
138
+
139
+ def allowed_op_types(self) -> Tuple[Type[Any], ...]:
140
+ return (OpOverload, HigherOrderOperator)
141
+
142
+ def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
143
+ return (torch.fx.GraphModule,)
144
+
145
+ def check_valid_op(self, op):
146
+ pass
147
+
148
+ def check_additional(self, gm: GraphModule) -> None:
149
+ """
150
+ Additional checks that are specific to some dialects.
151
+ """
152
+
153
+ @final
154
+ def check(self, ep: "ExportedProgram") -> None:
155
+ self._check_graph_module(ep.graph_module)
156
+ _verify_exported_program_module_call_graph(ep)
157
+ _verify_exported_program_signature(ep)
158
+
159
+ @final
160
+ def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
161
+ def _allowed_getattr_types() -> Tuple[Type[Any], ...]:
162
+ ret = self.allowed_getattr_types()
163
+ assert not any(t is object for t in ret)
164
+ return ret
165
+
166
+ def _check_valid_op(op) -> None:
167
+ def _allowed_builtin_ops() -> List:
168
+ ret = self.allowed_builtin_ops()
169
+ assert all(inspect.isbuiltin(op) for op in ret)
170
+ return ret
171
+
172
+ def _allowed_op_types() -> Tuple[Type[Any], ...]:
173
+ ret = self.allowed_op_types()
174
+ assert not any(t is object for t in ret)
175
+ return ret
176
+
177
+ # TODO Remove this allowlist.
178
+ _allowed_torch_functions = (
179
+ torch.autograd.grad_mode.set_grad_enabled,
180
+ torch.sym_int,
181
+ torch.sym_float,
182
+ torch.sym_ite,
183
+ torch.sym_max,
184
+ torch.sym_min,
185
+ torch.sym_not,
186
+ torch.sym_sqrt,
187
+ # TODO (tmanlaibaatar)
188
+ # Predispatch export is able to contain autograd ops.
189
+ # These will be modeled as HOO later
190
+ torch._C._set_grad_enabled,
191
+ )
192
+
193
+ if not isinstance(op, _allowed_op_types()):
194
+ if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
195
+ raise SpecViolationError(
196
+ f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"
197
+ f"Valid builtin ops: {_allowed_builtin_ops()}"
198
+ f"Valid torch functions: {_allowed_torch_functions}"
199
+ )
200
+
201
+ if isinstance(op, OpOverload):
202
+ # All ops functional
203
+ # TODO (tmanlaibaatar) more proper way is needed here
204
+ if self.dialect != "TRAINING" and not is_functional(op):
205
+ raise SpecViolationError(
206
+ f"operator '{op}' is not functional"
207
+ )
208
+ self.check_valid_op(op)
209
+
210
+ for mod in gm.modules():
211
+ if not isinstance(mod, torch.fx.GraphModule):
212
+ continue
213
+
214
+ mod.graph.lint()
215
+ for node in mod.graph.nodes:
216
+ # TODO(T140410192): should have fake tensor for all dialects
217
+ if node.op in {"call_module", "call_method"}:
218
+ raise SpecViolationError(
219
+ f"call_module is not valid: got a class '{node.target}' ",
220
+ )
221
+
222
+ elif node.op == "call_function":
223
+ _check_val(node)
224
+
225
+ _check_valid_op(node.target)
226
+
227
+ elif node.op == "get_attr":
228
+ if not isinstance(node.target, str):
229
+ raise SpecViolationError(
230
+ f"Expected get_attr target to be string, but got {type(node.target)}"
231
+ )
232
+
233
+ attr = getattr_recursive(mod, node.target)
234
+ if isinstance(attr, torch.nn.Module):
235
+ def _is_type(name, ty):
236
+ return isinstance(getattr(attr, name, None), ty)
237
+ if type(attr).__name__ == "LoweredBackendModule":
238
+ if _is_type("backend_id", str) \
239
+ and _is_type("processed_bytes", bytes) \
240
+ and _is_type("compile_specs", list) \
241
+ and hasattr(attr, "original_module"):
242
+ continue
243
+ else:
244
+ backend_id = getattr(attr, "backend_id", None)
245
+ processed_bytes = getattr(attr, "processed_bytes", None)
246
+ compile_specs = getattr(attr, "compile_specs", None)
247
+ raise SpecViolationError(
248
+ f"Invalid get_attr type {type(attr)}. \n"
249
+ f"LoweredBackendModule fields: "
250
+ f"backend_id(str) : {type(backend_id)}, "
251
+ f"processed_bytes(bytes) : {type(processed_bytes)}, "
252
+ f"compile_specs(list) : {type(compile_specs)}"
253
+ )
254
+
255
+ if not isinstance(attr, _allowed_getattr_types()):
256
+ raise SpecViolationError(
257
+ f"Invalid get_attr type {type(attr)}. \n"
258
+ f"Valid get_attr types: {_allowed_getattr_types()}"
259
+ )
260
+
261
+
262
+ elif node.op == "placeholder":
263
+ _check_val(node)
264
+ # TODO(zhxchen17)
265
+ # elif node.op == "output":
266
+ # _check_flattened_outputs()
267
+
268
+ self.check_additional(gm)
269
+
270
+
271
+ class TrainingIRVerifier(Verifier):
272
+ dialect = "TRAINING"
273
+
274
+
275
+ def _verify_exported_program_module_call_graph(exported_program) -> None:
276
+ module_call_graph = exported_program.module_call_graph
277
+ nodes = {
278
+ node.name for node in exported_program.graph.nodes
279
+ }
280
+ for entry in module_call_graph:
281
+ if entry.signature is not None:
282
+ for arg in entry.signature.inputs:
283
+ if arg.name and arg.name not in nodes:
284
+ raise SpecViolationError(
285
+ f"Input {arg.name} does not exist in the graph."
286
+ )
287
+ for arg in entry.signature.outputs:
288
+ if arg.name and arg.name not in nodes:
289
+ raise SpecViolationError(
290
+ f"Output {arg.name} does not exist in the graph."
291
+ )
292
+
293
+
294
+ def _verify_exported_program_signature(exported_program) -> None:
295
+ # Check ExportedProgram signature matches
296
+ gs = exported_program.graph_signature
297
+
298
+ # Check every node in the signature exists in the graph
299
+ input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"]
300
+
301
+ if len(input_node_names) != len(gs.input_specs):
302
+ raise SpecViolationError(
303
+ f"Number of graph inputs ({len(input_node_names)}) "
304
+ f"does not match number of inputs in the graph signature ({len(gs.input_specs)})"
305
+ )
306
+
307
+ for input_spec, node in zip(gs.input_specs, input_node_names):
308
+ if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)):
309
+ if input_spec.arg.name != node:
310
+ raise SpecViolationError(
311
+ f"Input spec name {input_spec.arg.name} does not match node name {node}"
312
+ )
313
+
314
+ if input_spec.kind == InputKind.USER_INPUT:
315
+ continue
316
+
317
+ elif input_spec.kind == InputKind.PARAMETER:
318
+ if not isinstance(input_spec.arg, TensorArgument):
319
+ raise SpecViolationError(
320
+ f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
321
+ )
322
+ if input_spec.target is None:
323
+ raise SpecViolationError(
324
+ f"InputSpec for {input_spec.name} has no target."
325
+ )
326
+
327
+ param = input_spec.target
328
+ if param not in exported_program.state_dict:
329
+ raise SpecViolationError(
330
+ f"Parameter {param} is not in the state dict."
331
+ )
332
+
333
+ if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
334
+ raise SpecViolationError(
335
+ f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter."
336
+ )
337
+
338
+ elif input_spec.kind == InputKind.BUFFER:
339
+ if not isinstance(input_spec.arg, TensorArgument):
340
+ raise SpecViolationError(
341
+ f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
342
+ )
343
+ if input_spec.target is None:
344
+ raise SpecViolationError(
345
+ f"InputSpec for {input_spec.name} has no target."
346
+ )
347
+
348
+ buffer = input_spec.target
349
+ if input_spec.persistent is None:
350
+ raise SpecViolationError(
351
+ f"Buffer {buffer} is missing a persistence flag"
352
+ )
353
+
354
+ if input_spec.persistent is True and buffer not in exported_program.state_dict:
355
+ raise SpecViolationError(
356
+ f"Buffer {buffer} is not in the state dict."
357
+ )
358
+
359
+ if input_spec.persistent is False and buffer in exported_program.state_dict:
360
+ raise SpecViolationError(
361
+ f"Non-persistent buffer {buffer} is in the state dict, it should not be."
362
+ )
363
+ elif input_spec.kind == InputKind.CONSTANT_TENSOR:
364
+ if not isinstance(input_spec.arg, TensorArgument):
365
+ raise SpecViolationError(
366
+ f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
367
+ )
368
+ if input_spec.target is None:
369
+ raise SpecViolationError(
370
+ f"InputSpec for {input_spec.name} has no target."
371
+ )
372
+
373
+ tensor_const = input_spec.target
374
+ if tensor_const not in exported_program.constants:
375
+ raise SpecViolationError(
376
+ f"Constant tensor {tensor_const} is not in the constants dictionary."
377
+ )
378
+ elif input_spec.kind == InputKind.CUSTOM_OBJ:
379
+ if not isinstance(input_spec.arg, CustomObjArgument):
380
+ raise SpecViolationError(
381
+ f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead."
382
+ )
383
+ if input_spec.target is None:
384
+ raise SpecViolationError(
385
+ f"InputSpec for {input_spec.name} has no target."
386
+ )
387
+
388
+ custom_obj = input_spec.target
389
+ if custom_obj not in exported_program.constants:
390
+ raise SpecViolationError(
391
+ f"Custom object {custom_obj} is not in the constants dictionary."
392
+ )
393
+ elif input_spec.kind == InputKind.TOKEN:
394
+ if not isinstance(input_spec.arg, TokenArgument):
395
+ raise SpecViolationError(
396
+ f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
397
+ )
398
+ else:
399
+ raise SpecViolationError(
400
+ f"Unknown InputKind {input_spec.kind}."
401
+ )
402
+
403
+ # Check outputs
404
+ output_node = list(exported_program.graph.nodes)[-1]
405
+ assert output_node.op == "output"
406
+ output_nodes = [
407
+ arg.name if isinstance(arg, torch.fx.Node) else arg
408
+ for arg in output_node.args[0]
409
+ ]
410
+
411
+ if len(output_nodes) != len(gs.output_specs):
412
+ raise SpecViolationError(
413
+ f"Number of output nodes {len(output_nodes)} is different "
414
+ "Than the number of outputs specified by the graph signature: \n"
415
+ f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n"
416
+ f"Number of user outputs: {len(gs.user_outputs)}. \n"
417
+ )
418
+
419
+ num_tokens = len(gs.output_tokens)
420
+ end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
421
+ mutate_nodes: List[str] = output_nodes[num_tokens:end]
422
+ user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]
423
+
424
+ for mutation_node in mutate_nodes:
425
+ if mutation_node in gs.buffers_to_mutate:
426
+ if gs.buffers_to_mutate[mutation_node] not in gs.buffers:
427
+ raise SpecViolationError(
428
+ f"Buffer output {mutation_node} does not point to a buffer that exists. \n"
429
+ f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
430
+ f"Buffer nodes available: {gs.buffers} \n"
431
+ )
432
+ elif mutation_node in gs.user_inputs_to_mutate:
433
+ if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
434
+ raise SpecViolationError(
435
+ f"User input output {mutation_node} does not point to a user input that exists. \n"
436
+ f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
437
+ f"User input nodes available: {gs.user_inputs} \n")
438
+ else:
439
+ raise SpecViolationError(
440
+ f"Mutation node {mutation_node} is neither a buffer nor a user input. "
441
+ f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}"
442
+ )
443
+
444
+ for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
445
+ if user_output_node != user_output_name:
446
+ raise SpecViolationError(
447
+ f"User output {user_output_node} is not in the correct "
448
+ "order or is not found in the "
449
+ f"exported program's user_output list: {gs.user_outputs}. "
450
+ )
451
+
452
+
453
+ def load_verifier(dialect: str) -> Type[Verifier]:
454
+ if dialect == "ATEN" or dialect == "":
455
+ return _VerifierMeta._registry.get(dialect, Verifier)
456
+ return _VerifierMeta._registry[dialect]
.venv/lib/python3.11/site-packages/torch/_export/wrappers.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from contextlib import contextmanager
3
+
4
+ import torch
5
+ import torch._custom_ops
6
+ from torch._C import DispatchKey
7
+ from torch._higher_order_ops.strict_mode import strict_mode
8
+ from torch._higher_order_ops.utils import autograd_not_implemented
9
+ from torch._ops import HigherOrderOperator
10
+ from torch._subclasses.fake_tensor import FakeTensorMode
11
+ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
12
+ from torch.utils import _pytree as pytree
13
+
14
+
15
class ExportTracepoint(HigherOrderOperator):
    """Marker higher-order op used by export to tag module-call boundaries."""

    def __init__(self):
        super().__init__("_export_tracepoint")

    def __call__(self, *args, **kwargs):
        # Pure pass-through; behavior comes from the registered py_impls.
        return super().__call__(*args, **kwargs)


_export_tracepoint = ExportTracepoint()
24
+
25
+
26
@_export_tracepoint.py_impl(ProxyTorchDispatchMode)
def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
    """Record the tracepoint into the proxy graph; inputs pass through."""
    proxy_args, proxy_kwargs = pytree.tree_map(
        mode.tracer.unwrap_proxy, (args, kwargs)
    )
    node_proxy = mode.tracer.create_proxy(
        "call_function", _export_tracepoint, proxy_args, proxy_kwargs
    )
    return track_tensor_tree(args, node_proxy, constant=None, tracer=mode.tracer)
33
+
34
+
35
@_export_tracepoint.py_impl(FakeTensorMode)
def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
    """Under fake tensors the tracepoint is an identity on its inputs."""
    with mode:
        return args
39
+
40
+
41
@_export_tracepoint.py_functionalize_impl
def export_tracepoint_functional(ctx, *args, **kwargs):
    """Functionalization rule: unwrap tensors, re-dispatch, re-wrap results."""
    raw_args = ctx.unwrap_tensors(args)
    raw_kwargs = ctx.unwrap_tensors(kwargs)

    with ctx.redispatch_to_next():
        result = _export_tracepoint(*raw_args, **raw_kwargs)
        return ctx.wrap_tensors(result)
49
+
50
+
51
# No autograd formula exists for the tracepoint; defer the error until a
# gradient is actually required.
_export_tracepoint.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_export_tracepoint, deferred_error=True)
)


@_export_tracepoint.py_impl(DispatchKey.CPU)
def export_tracepoint_cpu(*args, **kwargs):
    """Eager CPU kernel: the tracepoint is a no-op identity."""
    return args
59
+
60
+
61
def _wrap_submodule(mod, path, module_call_specs):
    """Install pre/post forward hooks on the submodule at dotted *path*.

    The hooks flatten each call's inputs/outputs, route them through
    ``_export_tracepoint`` so export can observe module-call boundaries, and
    record the pytree in/out specs into *module_call_specs*.

    Returns the pair of hook handles so the caller can remove them later.
    """
    assert isinstance(mod, torch.nn.Module)
    assert path != ""

    target = mod
    for part in path.split("."):
        if not hasattr(target, part):
            raise RuntimeError(f"Couldn't find submodule at path {path}")
        target = getattr(target, part)

    def record_specs(path, in_spec, out_spec):
        # Repeated calls through the same path must agree with what was
        # recorded before.
        if path in module_call_specs:
            assert module_call_specs[path]["in_spec"] == in_spec
            assert module_call_specs[path]["out_spec"] == out_spec
        module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}

    def ensure_supported(flat_values):
        for value in flat_values:
            if value is None or isinstance(
                value, (torch.Tensor, str, int, float, bool)
            ):
                continue
            raise AssertionError(
                f"Only Tensors or scalars are supported as pytree flattened inputs, got: {value}"
            )

    def pre_hook(module, args, kwargs):
        flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
        ensure_supported(flat_inputs)
        flat_inputs = _export_tracepoint(
            *flat_inputs, kind="module_call_inputs", path=path
        )
        return pytree.tree_unflatten(flat_inputs, in_spec)

    def post_hook(module, args, kwargs, res):
        _, in_spec = pytree.tree_flatten((args, kwargs))
        flat_outputs, out_spec = pytree.tree_flatten(res)
        ensure_supported(flat_outputs)
        flat_outputs = _export_tracepoint(
            *flat_outputs, kind="module_call_outputs", path=path
        )
        record_specs(path, in_spec, out_spec)
        return pytree.tree_unflatten(flat_outputs, out_spec)

    pre_handle = target.register_forward_pre_hook(pre_hook, with_kwargs=True)
    post_handle = target.register_forward_hook(post_hook, with_kwargs=True)
    return pre_handle, post_handle
101
+
102
+
103
@contextmanager
def _wrap_submodules(f, preserve_signature, module_call_signatures):
    """Hook every path in *preserve_signature* on module *f* for the body.

    All installed hook handles are removed on exit, even if the body raises.
    """
    installed = []

    try:
        for submodule_path in preserve_signature:
            installed.extend(
                _wrap_submodule(f, submodule_path, module_call_signatures)
            )
        yield
    finally:
        for h in installed:
            h.remove()
114
+
115
+
116
def _mark_strict_experimental(cls):
    """Class decorator that routes ``__call__`` through the strict_mode HOP."""

    def _strict_call(self, *args):
        return strict_mode(self, args)

    cls.__call__ = _strict_call
    return cls
.venv/lib/python3.11/site-packages/torch/_lazy/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+
3
+ import torch._C._lazy
4
+ from torch.utils._pytree import tree_flatten, tree_unflatten
5
+
6
+ from .closure import add_step_closure, run_step_closures
7
+
8
+
9
def mark_step(device: str = "", wait=False):
    """Trigger a mark step.

    This amounts to:
    - collecting a group of 'live' lazy tensors to index into the
      compilation cache (lowering/compiling their IR graphs if not cached)
    - kicking off execution of the compiled function
    - (optionally, ``wait=True``) waiting for cpu-side execution to complete
      (does not sync the accelerator)
    """
    # TODO(whc) expand this to include backend hooks and align with XLA backend needs
    torch._C._lazy._mark_step(device, [], wait=wait)
    run_step_closures()


def wait_device_ops(devices=None):
    """Wait for all the async operations on the given devices to complete.

    Args:
        devices (string..., optional): The devices whose async ops need to be
            waited for. If empty, all the local devices will be waited for.
    """
    torch._C._lazy._wait_device_ops(devices=devices if devices is not None else [])


def sync_multi(tensors, devices):
    """Sync the list of lazy tensors so their IR gets lowered for the active
    backend and the compiled computation graph gets cached."""
    torch._C._lazy._sync_multi(tensors, devices)


def get_tensor_id(tensor):
    """Return a unique id of the lazy tensor maintained by LTC."""
    return torch._C._lazy._get_tensor_id(tensor)


def to_cpu(tensors, devices=None):
    """Sync *tensors*, then return an equal-structured pytree of CPU copies."""
    target_devices = devices or ["lazy"]

    flat, spec = tree_flatten(tensors)
    sync_multi(flat, target_devices)
    return tree_unflatten([t.to("cpu") for t in flat], spec)


def save(tensors, *args, **kwargs):
    """``torch.save`` after moving the (possibly lazy) tensors to CPU."""
    torch.save(to_cpu(tensors), *args, **kwargs)
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-311.pyc ADDED
Binary file (859 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-311.pyc ADDED
Binary file (522 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/_lazy/computation.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch._C._lazy
3
+ import torch._C._lazy_ts_backend
4
+
5
+
6
def get_tensors_ts_device_data_node(tensors):
    """Return tensor ids and eager tensors for DeviceData nodes in the IR
    for the passed in lazy tensors.

    TODO: This API is currently ts backend specific. We are working on
    generalizing it to all backends including XLA.
    """
    return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors)


def get_graph_hash(tensors):
    """Return the graph hash for the passed in lazy tensors."""
    return torch._C._lazy._get_graph_hash(tensors)


def run_cached_graph(hash_str, graph_inputs):
    """Run the cached computation graph identified by *hash_str* with the
    given inputs.

    TODO: This API is currently ts backend specific. We are working on
    generalizing it to all backends including XLA.
    """
    return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs)
.venv/lib/python3.11/site-packages/torch/_lazy/config.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch._C._lazy
3
+
4
+
5
def get_force_fallback():
    """Return the config used to force LTC fallback."""
    return torch._C._lazy._get_force_fallback()


def set_force_fallback(configval):
    """Set the config used to force LTC fallback."""
    torch._C._lazy._set_force_fallback(configval)


def set_reuse_ir(val: bool):
    """Set the config to reuse IR nodes for faster tracing."""
    torch._C._lazy._set_reuse_ir(val)
.venv/lib/python3.11/site-packages/torch/_lazy/debug.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch._C._lazy
3
+
4
+
5
def render_ir_graph(tensors):
    """Return a text dump of the LTC IR graph in dot format for the tensors.

    The text can be processed by tools like dot to be rendered in pdf, png,
    etc.
    """
    return torch._C._lazy._get_tensors_dot(tensors)


def dump_ir(tensors, ir_format):
    """Return a dump of *tensors* in the specified format.

    Valid formats:
      - "text":    LTC IR
      - "backend": the active backend's IR
    """
    if ir_format == "text":
        return torch._C._lazy._get_tensors_text(tensors)
    if ir_format == "backend":
        return torch._C._lazy._get_tensors_backend(tensors)
    raise RuntimeError(f"Unrecognized IR format: {ir_format}")
.venv/lib/python3.11/site-packages/torch/_lazy/device_context.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import threading
3
+ from typing import Any, Dict
4
+
5
+ import torch._C._lazy
6
+
7
+
8
class DeviceContext:
    # Process-wide cache of per-device contexts, guarded by a lock.
    _CONTEXTS: Dict[str, Any] = {}
    _CONTEXTS_LOCK = threading.Lock()

    def __init__(self, device):
        self.device = device


def get_device_context(device=None):
    """Return (creating on first use) the cached DeviceContext for *device*.

    When *device* is None the default lazy device type is used.
    """
    key = (
        torch._C._lazy._get_default_device_type()
        if device is None
        else str(device)
    )
    with DeviceContext._CONTEXTS_LOCK:
        ctx = DeviceContext._CONTEXTS.get(key)
        if ctx is None:
            ctx = DeviceContext(key)
            DeviceContext._CONTEXTS[key] = ctx
        return ctx
.venv/lib/python3.11/site-packages/torch/_lazy/extract_compiled_graph.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import copy
3
+ import dataclasses
4
+ import itertools
5
+ import os
6
+ from typing import Any, Callable, Dict, List
7
+
8
+ import torch
9
+ import torch._lazy as lazy
10
+ import torch._lazy.metrics as metrics
11
+ from torch import fx
12
+ from torch._lazy import computation, debug as lazy_debug
13
+ from torch._lazy.tensor_factory_functions import tensor_factory_functions
14
+
15
+
16
+ debug = os.environ.get("debug_extract_compiled_graph") is not None
17
+
18
+
19
@dataclasses.dataclass
class GraphInputMatcher:
    """Build the real graph inputs for calls made after lazy tracing.

    Graph inputs whose tensor id appears in ``tensor_id_to_arg_idx`` are
    method parameters and must be replaced with the current call's
    arguments; the remaining ones are most likely const tensors captured at
    trace time, whose values are replayed from ``graph_input_ivalues``.

    ``graph_input_tensor_ids`` and ``graph_input_ivalues`` list the
    tensor_id and ivalue for each of the TS/XLA graph inputs.
    """

    tensor_id_to_arg_idx: Dict[int, int]
    graph_input_tensor_ids: List[int]
    graph_input_ivalues: List[Any]

    def __call__(self, args):
        # Produce the real graph input tensors for this call.
        lookup = self.tensor_id_to_arg_idx
        return [
            traced if (idx := lookup.get(tid)) is None else args[idx]
            for tid, traced in zip(
                self.graph_input_tensor_ids, self.graph_input_ivalues
            )
        ]
53
+
54
+
55
class ReturnValueHandler:
    r"""Re-duplicate graph outputs that LTC deduplicated.

    When ``_ltc_sync_multi`` is called on multiple tensors, the compiled
    graph contains an output only for *unique* tensors — if a tensor appears
    several times in the input, only its first occurrence matters.  Python
    callers, however, still expect the duplicates back, e.g. for:

        def forward(self, a):
            return a, a

    the TS graph captured by LTC returns a single tensor while the Python
    method expects two.  This class deduplicates the lazy tensors first,
    recording for each unique tensor the positions it occupied, so the eager
    tensors can be fanned back out later.
    """

    def __init__(self, lazy_out_list):
        self.index: List[List[int]] = []
        self.total_count = len(lazy_out_list)

        seen: Dict[int, int] = {}
        for position, tensor in enumerate(lazy_out_list):
            slot = seen.get(id(tensor))
            if slot is None:
                seen[id(tensor)] = len(self.index)
                self.index.append([position])
            else:
                self.index[slot].append(position)

    def duplicate_eager_tensors(self, eager_tensor_list):
        """Fan the deduped eager results back out to their original slots."""
        assert len(eager_tensor_list) == len(self.index)
        out = [None] * self.total_count
        for tensor, positions in zip(eager_tensor_list, self.index):
            for position in positions:
                out[position] = tensor
        return out
95
+
96
+
97
def force_lazy_device(model: fx.GraphModule):
    """Rewrite eager device references in *model* to the lazy device.

    Factory methods in an Fx graph may create tensors for a specific eager
    device.  If we take no action, those eager tensors will be mixed with
    lazy tensors and cause a crash, so every explicit device argument is
    overwritten and, for known factory functions without a device argument,
    a lazy one is injected.
    """

    def to_lazy(value):
        if isinstance(value, torch.device):
            return torch.device("lazy", index=value.index)
        return value

    def has_device_arg(args, kwargs):
        return any(
            isinstance(a, torch.device)
            for a in itertools.chain(args, kwargs.values())
        )

    for node in model.graph.nodes:
        node.args = tuple(to_lazy(a) for a in node.args)
        node.kwargs = {k: to_lazy(v) for k, v in node.kwargs.items()}

        # For torchbench models like yolov3 and hf_Bart, dynamo generates Fx
        # graphs that return eager tensors on the default device
        # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolove,
        # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart).
        # Since there is no explicit device argument to override, for the
        # covered tensor factory methods we add a lazy device argument
        # explicitly.
        #
        # TODO: This solution is not ideal since we may miss some factory
        # methods. In the future, when we support lazy mode, this method can
        # be replaced by that.
        if node.target in tensor_factory_functions and not has_device_arg(
            node.args, node.kwargs
        ):
            new_kwargs = dict(node.kwargs)  # node.kwargs is immutable; copy.
            new_kwargs["device"] = torch.device("lazy")
            node.kwargs = new_kwargs

    model.recompile()
138
+
139
+
140
def get_fallback_ops():
    """Return "name=count" strings for each aten:: op whose fallback counter
    is positive."""
    return [
        f"{name}={count}"
        for name in metrics.counter_names()
        if "aten::" in name and (count := int(metrics.counter_value(name))) > 0
    ]
150
+
151
+
152
def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable:
    """Optimize an eager model with LTC and return a wrapper that executes
    the compiled graph directly without retracing.

    The wrapper depends on other mechanisms, like TorchDynamo guards, to
    guarantee it is only called when that is safe.

    Raises:
        RuntimeError: if any op fell back to eager during lazy tracing,
            since the cached graph would then not cover the computation.
    """
    lazy_args = [arg.to(device="lazy") for arg in example_inputs]
    args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args]
    tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)}
    lazy_model = copy.deepcopy(model).to(device=torch.device("lazy"))
    force_lazy_device(lazy_model)

    # This call executes lazy tracing and enables extracting the compiled
    # graph later.  Counters are reset around it so fallbacks during this
    # trace are detected reliably.
    metrics.reset()
    lazy_out = lazy_model(*lazy_args)
    fallback_ops = get_fallback_ops()
    metrics.reset()

    if len(fallback_ops) > 0:
        # BUGFIX: message previously read "Fail to extact".
        raise RuntimeError(
            f"Failed to extract the compiled graph because of fallback: {','.join(fallback_ops)}"
        )

    if not isinstance(lazy_out, (tuple, list)):
        lazy_out = (lazy_out,)

    # Inputs are included so in-place argument updates show up as outputs.
    args_and_out = tuple(lazy_args) + tuple(lazy_out)
    return_value_handler = ReturnValueHandler(args_and_out)
    if debug:
        print("Fx code:\n", model.code)
        print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text"))

    # TODO: this part is TS backend specific for now and will be generalized
    # to support XLA.
    (
        graph_input_tensor_ids,
        graph_input_ivalues,
    ) = computation.get_tensors_ts_device_data_node(args_and_out)
    assert len(graph_input_tensor_ids) == len(graph_input_ivalues)
    graph_input_matcher = GraphInputMatcher(
        tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues
    )

    graph_hash = computation.get_graph_hash(args_and_out)

    if debug:
        print("graph_hash", graph_hash)
        print(f"args_tensor_ids {args_tensor_ids}")
        print("tensor ids from device data:", graph_input_tensor_ids)

    # Sync the list of output tensors so the computation graph for these
    # tensors will be cached. Those computation graphs can be retrieved by
    # graph hash later.
    lazy.sync_multi(args_and_out, [])

    def optimized_mod(*args):
        if len(args_and_out) == 0:
            return ()
        graph_input = graph_input_matcher(args)
        res = return_value_handler.duplicate_eager_tensors(
            computation.run_cached_graph(graph_hash, graph_input)
        )

        assert len(res) == len(args_and_out)
        for i, arg in enumerate(args):
            # only copy those tensors that get inplace updated
            if arg is not res[i]:
                arg.copy_(res[i])

        # skip the args
        return res[len(args) :]

    return optimized_mod
.venv/lib/python3.11/site-packages/torch/_lazy/metrics.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch._C._lazy
3
+
4
+
5
def reset():
    """Reset all metric counters."""
    torch._C._lazy._reset_metrics()


def counter_names():
    """Retrieve all the currently active counter names."""
    return torch._C._lazy._counter_names()


def counter_value(name: str):
    """Return the value of the counter with the specified name."""
    return torch._C._lazy._counter_value(name)


def metrics_report():
    """Return the combined (lazy core and backend) metric report."""
    return torch._C._lazy._metrics_report()
.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import torch._C._lazy_ts_backend
3
+
4
+
5
def init():
    """Initialize the lazy TorchScript backend."""
    torch._C._lazy_ts_backend._init()
.venv/lib/python3.11/site-packages/torch/multiprocessing/_atfork.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import sys
3
+
4
+
5
+ __all__ = ["register_after_fork"]
6
+
7
if sys.platform == "win32":
    import multiprocessing.util as _util

    def _register(func):
        # multiprocessing's after-fork hook passes the registering object as
        # an argument; drop it before invoking the user callback.
        def wrapper(arg):
            func()

        _util.register_after_fork(_register, wrapper)

else:
    import os

    def _register(func):
        os.register_at_fork(after_in_child=func)


def register_after_fork(func):
    """Register a callable to be executed in the child process after a fork.

    Note:
        In python < 3.7 this will only work with processes created using the
        ``multiprocessing`` module. In python >= 3.7 it also works with
        ``os.fork()``.

    Args:
        func (function): Function taking no arguments to be called in the child after fork

    """
    _register(func)
.venv/lib/python3.11/site-packages/torch/multiprocessing/pool.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing.pool
2
+ import multiprocessing.util as util
3
+
4
+ from .queue import SimpleQueue
5
+
6
+
7
def clean_worker(*args, **kwargs):
    """Run the stock multiprocessing pool worker, then force a GC pass.

    Regular multiprocessing workers don't fully clean up after themselves,
    so garbage collection is triggered explicitly to make sure all
    destructors are called.
    """
    import gc

    multiprocessing.pool.worker(*args, **kwargs)
    gc.collect()


class Pool(multiprocessing.pool.Pool):
    """Pool implementation which uses our version of SimpleQueue.

    This lets us pass tensors in shared memory across processes instead of
    serializing the underlying data.
    """

    def _setup_queues(self):
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def _repopulate_pool(self):
        """Bring the number of pool processes up to the configured count.

        Used after reaping workers which have exited.
        """
        for _ in range(self._processes - len(self._pool)):
            # changed worker -> clean_worker
            worker_args = (
                self._inqueue,
                self._outqueue,
                self._initializer,
                self._initargs,
                self._maxtasksperchild,
            )
            if hasattr(self, "_wrap_exception"):
                worker_args += (self._wrap_exception,)
            proc = self.Process(target=clean_worker, args=worker_args)
            self._pool.append(proc)
            proc.name = proc.name.replace("Process", "PoolWorker")
            proc.daemon = True
            proc.start()
            util.debug("added worker")
.venv/lib/python3.11/site-packages/torch/multiprocessing/queue.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import io
3
+ import multiprocessing.queues
4
+ import pickle
5
+ from multiprocessing.reduction import ForkingPickler
6
+
7
+
8
class ConnectionWrapper:
    """Proxy class for _multiprocessing.Connection which uses ForkingPickler for object serialization."""

    def __init__(self, conn):
        self.conn = conn

    def send(self, obj):
        # Serialize with ForkingPickler so tensor storages are shared rather
        # than copied.
        buffer = io.BytesIO()
        ForkingPickler(buffer, pickle.HIGHEST_PROTOCOL).dump(obj)
        self.send_bytes(buffer.getvalue())

    def recv(self):
        return pickle.loads(self.recv_bytes())

    def __getattr__(self, name):
        # Everything else falls through to the wrapped connection.
        if "conn" in self.__dict__:
            return getattr(self.conn, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'")


class Queue(multiprocessing.queues.Queue):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
        self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
        self._send = self._writer.send
        self._recv = self._reader.recv


class SimpleQueue(multiprocessing.queues.SimpleQueue):
    def _make_methods(self):
        if not isinstance(self._reader, ConnectionWrapper):
            self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
            self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
        super()._make_methods()  # type: ignore[misc]
.venv/lib/python3.11/site-packages/torch/multiprocessing/reductions.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import multiprocessing
3
+ import os
4
+ import threading
5
+ from multiprocessing.reduction import ForkingPickler
6
+ from multiprocessing.util import register_after_fork
7
+ from typing import Union
8
+
9
+ import torch
10
+ from torch._namedtensor_internals import check_serializing_named_tensor
11
+
12
+
13
+ try:
14
+ # Early load resource_sharer to prevent a partially initialized instance
15
+ # from being inherited in a forked child process. The reduce_storage method
16
+ # requires this module indirectly through DupFd(). The built-in mp.Queue
17
+ # class pickles arguments in a background thread which may overlap with the
18
+ # fork.
19
+ import multiprocessing.resource_sharer
20
+ except ImportError:
21
+ pass
22
+
23
+
24
class StorageWeakRef:
    r"""A weak reference to a Storage.

    The cdata member is a Python number containing the integer representation
    of the Storage pointer.
    """

    __slots__ = ["cdata", "_free_weak_ref"]

    def __init__(self, storage):
        self.cdata = storage._weak_ref()
        # Save a direct reference to _free_weak_ref because the `torch` module
        # might be cleared during Python shutdown before this module is cleared.
        self._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]

    @classmethod
    def from_weakref(cls, cdata):
        # Alternate constructor from an already-obtained weak-ref handle.
        ref = cls.__new__(cls)
        ref.cdata = cdata
        ref._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]
        return ref

    def expired(self):
        return torch.Storage._expired(self.cdata)  # type: ignore[attr-defined]

    def __del__(self):
        self._free_weak_ref(self.cdata)

    def __hash__(self):
        return self.cdata

    def __eq__(self, other):
        # Identity fast-path, otherwise compare the underlying pointers.
        return id(self) == id(other) or self.cdata == other.cdata
59
+
60
+
61
class SharedCache(dict):
    """Dictionary from multiprocessing handles to StorageWeakRef."""

    def __init__(self) -> None:
        # free_dead_references() runs when len exceeds the current limit;
        # the limit rescales with the number of remaining live objects.
        self.limit = 128
        # `fork` inherits lock state, so in case we fork when the lock is
        # held, we register a function to reset the lock to a new object to
        # avoid possible deadlocks, following python multiprocessing library
        # design.
        self._after_fork()
        register_after_fork(self, SharedCache._after_fork)

    def _after_fork(self):
        self.lock = threading.Lock()

    def get(self, key):
        with self.lock:
            return dict.get(self, key)

    def __setitem__(self, key, storage_ref):
        with self.lock:
            dict.__setitem__(self, key, storage_ref)
            if len(self) > self.limit:
                self.free_dead_references()

    def free_dead_references(self):
        # Drop entries whose storages are gone, then rescale the limit to
        # twice the number of still-live entries (never below 128).
        alive = 0
        for key, ref in list(self.items()):
            if ref.expired():
                del self[key]
            else:
                alive += 1
        self.limit = max(128, alive * 2)


# mapping from handles to StorageWeakRef objects
shared_cache = SharedCache()
99
+
100
+
101
def rebuild_event(device, handle):
    """Reconstruct a CUDA event from its IPC handle."""
    return torch.cuda.Event.from_ipc_handle(device, handle)


def reduce_event(event):
    """Pickle a CUDA event as (rebuild_fn, (device, ipc_handle))."""
    return (rebuild_event, (event.device, event.ipc_handle()))
108
+
109
+
110
def rebuild_tensor(cls, storage, metadata):
    """Reconstruct a tensor of type *cls* from *storage* plus view metadata.

    *metadata* is ``(storage_offset, size, stride, requires_grad)``.
    """
    storage_offset, size, stride, requires_grad = metadata
    t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
    if cls != torch.nn.parameter.Parameter:
        t.requires_grad = requires_grad
        return t
    # requires_grad must go through the constructor, rather than being set
    # as an attribute later, because it's an important check for Integer
    # Tensors to have requires_grad=False (or else they raise an error).
    return torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
121
+
122
+
123
def rebuild_meta_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    dtype,
    storage_size_bytes,
    requires_grad,
):
    """Reconstruct a meta-device tensor from its serialized description."""
    storage = torch.TypedStorage(
        wrap_storage=torch.UntypedStorage(storage_size_bytes, device="meta"),
        dtype=dtype,
        _internal=True,
    )

    t = torch._utils._rebuild_tensor(
        storage,
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive requires_grad=False
        # as an argument in the constructor.
        return torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    t.requires_grad = requires_grad
    return t
153
+
154
+
155
def rebuild_cuda_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    storage_cls,
    dtype,
    storage_device,
    storage_handle,
    storage_size_bytes,
    storage_offset_bytes,
    requires_grad,
    ref_counter_handle,
    ref_counter_offset,
    event_handle,
    event_sync_required,
):
    """Reconstruct a CUDA tensor on the consumer side of a CUDA IPC transfer."""
    if storage_handle is None or storage_size_bytes == 0:
        # storage points to nullptr: make an empty storage of the right type.
        storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
    else:
        cache_key = (storage_handle, storage_offset_bytes)
        storage = storage_from_cache(storage_cls, cache_key)
        if storage is None:
            torch.cuda._lazy_init()
            storage = storage_cls._new_shared_cuda(
                storage_device,
                storage_handle,
                storage_size_bytes,
                storage_offset_bytes,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            )
            shared_cache[cache_key] = StorageWeakRef(storage)
        else:
            # We are already ref-counting this Storage, but the producer
            # needs its new ref-counters to be released.
            storage_cls._release_ipc_counter(
                ref_counter_handle, ref_counter_offset, device=storage_device
            )

    untyped = (
        storage
        if isinstance(storage, torch.UntypedStorage)
        else storage._untyped_storage
    )

    t = torch._utils._rebuild_tensor(
        torch.storage.TypedStorage(wrap_storage=untyped, dtype=dtype, _internal=True),
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive requires_grad=False
        # as an argument in the constructor.
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad

    return t
221
+
222
+
223
def reduce_tensor(tensor):
    """ForkingPickler reducer for tensors.

    Dispatches on the tensor's kind: non-subclass nested tensors and sparse
    tensors are reduced component-wise; CUDA tensors are shared via CUDA IPC
    handles; meta tensors are rebuilt from metadata alone; everything else
    falls back to sharing the underlying (typed) storage.

    Raises:
        RuntimeError: if ``tensor`` is a non-leaf tensor that requires grad,
            since autograd graphs cannot cross process boundaries.
    """
    if tensor.requires_grad and not tensor.is_leaf:
        raise RuntimeError(
            "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
            "since autograd does not support crossing process boundaries. "
            "If you just want to transfer the data, call detach() on the tensor "
            "before serializing (e.g., putting it on the queue)."
        )

    check_serializing_named_tensor(tensor)
    torch.utils.hooks.warn_if_has_hooks(tensor)

    # Note [CUDA IPC and the caching allocator]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # When you send a CUDA tensor over IPC, you might expect that you will
    # get out the same storage from the other end. However, the CUDA caching
    # allocator makes it difficult to preserve this invariant. Consider
    # the following situation: a tensor of size 0x100 points to offset 0x20 of
    # a storage at 0xA100 of size 0x100. (For simplicity, all of these
    # sizes are given in bytes). HOWEVER, with the caching allocator, this storage
    # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000.
    #
    # When we want to send this CUDA tensor over IPC, we must send the
    # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just
    # the storage 0xA100 (because that is what CUDA supports). So, on the
    # other end, there simply isn't any way to say, "Wait, you gave me
    # a bigger region (0xA000) than the one I wanted (0xA100)".
    #
    # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as
    # one storage itself? No, because this cudaMalloc allocation might contain
    # storages of mixed types: float, bytes, double... If you make the entire
    # allocation a single storage of a type A, we'll hit an error when constructing
    # a tensor of type B on the storage.
    #
    # cudaIpcMemHandle is an identifier to access the sender cudaMalloc allocation on the
    # receiver side. However, cudaIpcMemHandles from each device in a given process may
    # only be opened by one context per device per other process.
    # If we open and close a memory handle multiple times in a process, CUDA is allowed
    # to give it a different address; similarly, once we close the memory, we're not
    # allowed to access it (and the storage/tensor built on top of it), even if it is
    # still live in the original process. As we cannot make a cudaMalloc allocation
    # to a single storage in one go, this requires us to cache the device pointer for
    # each cudaIpcMemHandle on C++ side to reconstruct types of storages, while keeping
    # the old ones alive.
    # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html]
    #
    # This is fine, because all we need to do is to save our position in the allocation,
    # and reconstruct storage and tensor from it.
    # 0xA000 ->  -------CUDA Allocation------
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA100 ->  --------storage1 begin------
    #           |                            |
    # 0xA120 ->  --------tensor1 begin ------
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA160 ->  --------tensor1 end---------
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA200 ->  --------storage1 end--------
    #           |                            |
    # 0xE000 ->  --------CUDA allocation-----
    #
    # To send tensor1, the following info are required from sender to receiver for
    # storage reconstruction.
    #   1. cudaIpcMemHandle of 0xA000 (which can be mapped to a basePtr in receiver
    #      process). basePtr may not be exactly 0xA000 since it's a different process.
    #   2. offset (0xA100) of storage1 in the CUDA allocation.
    #   3. size of storage1 (0x100).
    #
    # On receiver side:
    #   1. Get the devPtr of the MemHandle to access the memory, reconstruct a storage
    #      of the same type using (basePtr, offset, size).
    #   2. we can reconstruct the tensor on top of the reconstructed storage
    #      Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100))
    #
    # This strategy has a few implications:
    #
    # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one
    #    go (non-compositionally), and this requires to have a global map
    #    memHandle -> devPtr for each process.
    #
    # 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize
    #    of the storage beyond 0x100 would merely have caused us to do a
    #    reallocation. You don't really want to do this, but if you did,
    #    all that would happen is that you would lose IPC sharing. But if
    #    you do this in the new world, we will happily let you write out of
    #    bounds of your "allocation", clobbering unrelated data in the cached
    #    allocator block. BAD!
    #
    # By the way, in old versions of PyTorch, we supported this situation
    # natively using a "storage view", which permitted multiple storages to be
    # views on each other. But this was the *only* use of storage views, so we
    # eliminated it so that we could just use tensor views to implement the same
    # thing.
    #

    # TODO: Handle distinguishing between subclass and non-subclass versions of NT better
    # https://github.com/pytorch/pytorch/issues/110543
    from torch.nested._internal.nested_tensor import NestedTensor

    if tensor.is_nested and not isinstance(tensor, NestedTensor):
        return reduce_nested_tensor(tensor)

    if tensor.layout in {
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_bsr,
        torch.sparse_csc,
        torch.sparse_bsc,
    }:
        return reduce_sparse_tensor(tensor)

    storage = tensor._typed_storage()

    if storage._untyped_storage.device.type == "cuda":
        (
            device,
            handle,
            storage_size_bytes,
            storage_offset_bytes,
            ref_counter_handle,
            ref_counter_offset,
            event_handle,
            event_sync_required,
        ) = storage._share_cuda_()
        tensor_offset = tensor.storage_offset()
        shared_cache[handle] = StorageWeakRef(storage)
        # _backward_hooks purposely omitted here, see
        # Note [Don't serialize hooks]
        return (
            rebuild_cuda_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor_offset,  # tensor offset in its storage
                type(storage),
                tensor.dtype,
                device,
                handle,  # identifier which CUDA allocation is the storage in.
                storage_size_bytes,  # size(in bytes) of the storage
                storage_offset_bytes,  # offset(in bytes) of the storage in the CUDA allocation
                tensor.requires_grad,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            ),
        )
    elif storage._untyped_storage.device.type == "meta":
        # Meta tensors have no data; metadata alone suffices to rebuild them.
        return (
            rebuild_meta_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor.storage_offset(),
                tensor.dtype,
                tensor.untyped_storage().size(),
                tensor.requires_grad,
            ),
        )

    # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
    metadata = (
        tensor.storage_offset(),
        tensor.size(),
        tensor.stride(),
        tensor.requires_grad,
    )
    return (rebuild_tensor, (type(tensor), storage, metadata))
401
+
402
+
403
+ def rebuild_nested_tensor(
404
+ rebuild_buffer_func,
405
+ rebuild_buffer_args,
406
+ rebuild_sizes_func,
407
+ rebuild_sizes_args,
408
+ rebuild_strides_func,
409
+ rebuild_strides_args,
410
+ rebuild_offsets_func,
411
+ rebuild_offsets_args,
412
+ ):
413
+ buffer = rebuild_buffer_func(*rebuild_buffer_args)
414
+ sizes = rebuild_sizes_func(*rebuild_sizes_args)
415
+ strides = rebuild_strides_func(*rebuild_strides_args)
416
+ offsets = rebuild_offsets_func(*rebuild_offsets_args)
417
+ return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets)
418
+
419
+
420
def reduce_nested_tensor(nt):
    """Pickle support for C++ (non-subclass) nested tensors.

    The flat buffer and the three metadata tensors are each reduced through
    the ordinary tensor path so they can use shared memory independently;
    ``rebuild_nested_tensor`` recombines them on the receiving side.
    """
    components = (
        nt.values(),
        nt._nested_tensor_size(),
        nt._nested_tensor_strides(),
        nt._nested_tensor_storage_offsets(),
    )
    rebuild_args = []
    for component in components:
        fn, args = reduce_tensor(component)
        rebuild_args.append(fn)
        rebuild_args.append(args)
    return (rebuild_nested_tensor, tuple(rebuild_args))
443
+
444
+
445
+ def rebuild_sparse_coo_tensor(
446
+ rebuild_indices_func,
447
+ rebuild_indices_args,
448
+ rebuild_values_func,
449
+ rebuild_values_args,
450
+ shape,
451
+ is_coalesced,
452
+ ):
453
+ indices = rebuild_indices_func(*rebuild_indices_args)
454
+ values = rebuild_values_func(*rebuild_values_args)
455
+ return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced)
456
+
457
+
458
def rebuild_sparse_compressed_tensor(
    rebuild_compressed_indices_func,
    rebuild_compressed_indices_args,
    rebuild_plain_indices_func,
    rebuild_plain_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    layout,
):
    """Recreate a sparse CSR/CSC/BSR/BSC tensor from rebuilt index/value parts."""
    return torch.sparse_compressed_tensor(
        rebuild_compressed_indices_func(*rebuild_compressed_indices_args),
        rebuild_plain_indices_func(*rebuild_plain_indices_args),
        rebuild_values_func(*rebuild_values_args),
        shape,
        layout=layout,
    )
476
+
477
+
478
def reduce_sparse_tensor(sparse):
    """Pickle support for sparse tensors.

    Reduces the constituent dense tensors (indices and values) separately so
    each can travel through the ordinary tensor-sharing path, then pairs them
    with the matching rebuild function.
    """
    if sparse.layout is torch.sparse_coo:
        indices_fn, indices_args = reduce_tensor(sparse._indices())
        values_fn, values_args = reduce_tensor(sparse._values())
        return (
            rebuild_sparse_coo_tensor,
            (
                indices_fn,
                indices_args,
                values_fn,
                values_args,
                sparse.shape,
                sparse.is_coalesced(),
            ),
        )

    # Compressed layouts: pick the compressed/plain index accessors per layout.
    if sparse.layout in (torch.sparse_csr, torch.sparse_bsr):
        compressed, plain = sparse.crow_indices(), sparse.col_indices()
    elif sparse.layout in (torch.sparse_csc, torch.sparse_bsc):
        compressed, plain = sparse.ccol_indices(), sparse.row_indices()
    else:
        raise NotImplementedError(sparse.layout)

    compressed_fn, compressed_args = reduce_tensor(compressed)
    plain_fn, plain_args = reduce_tensor(plain)
    values_fn, values_args = reduce_tensor(sparse.values())
    return (
        rebuild_sparse_compressed_tensor,
        (
            compressed_fn,
            compressed_args,
            plain_fn,
            plain_args,
            values_fn,
            values_args,
            sparse.shape,
            sparse.layout,
        ),
    )
523
+
524
+
525
def fd_id(fd):
    """Return an (inode, device) pair that uniquely identifies an open file.

    On macOS this does not work for shared memory handles, which is why the
    "file_descriptor" sharing strategy is unsupported on that platform.
    """
    info = os.fstat(fd)
    return (info.st_ino, info.st_dev)
531
+
532
+
533
def storage_from_cache(cls, key):
    """Look up a previously shared storage by cache key.

    Returns an ``UntypedStorage`` revived from the weak reference in
    ``shared_cache``, or ``None`` on a miss (or if the entry has expired).
    """
    ref = shared_cache.get(key)
    return None if ref is None else torch.UntypedStorage._new_with_weak_ptr(ref.cdata)
538
+
539
+
540
def rebuild_storage_fd(cls, df, size):
    """Materialize a CPU storage shared through a duplicated file descriptor."""
    fd = df.detach()
    try:
        cached = storage_from_cache(cls, fd_id(fd))
        if cached is not None:
            return cached
        fresh = cls._new_shared_fd_cpu(fd, size)
        shared_cache[fd_id(fd)] = StorageWeakRef(fresh)
        return fresh
    finally:
        # The dup'd descriptor is only needed while mapping; always close it.
        os.close(fd)
551
+
552
+
553
def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
    """Materialize a CPU storage shared via the file_system strategy.

    ``size`` is in bytes when ``dtype`` is None (untyped storage) and in
    elements otherwise (legacy typed storage).
    """
    cached = storage_from_cache(cls, handle)
    if cached is not None:
        return cached._shared_decref()

    if dtype is None:
        fresh = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
    else:
        nbytes = size * torch._utils._element_size(dtype)
        fresh = torch.TypedStorage(
            wrap_storage=torch.UntypedStorage._new_shared_filename_cpu(
                manager, handle, nbytes
            ),
            dtype=dtype,
            _internal=True,
        )
    shared_cache[handle] = StorageWeakRef(fresh)
    return fresh._shared_decref()
571
+
572
+
573
def rebuild_storage_empty(cls):
    """Recreate a zero-sized storage (size-0 storages cannot be mmapped)."""
    return cls()
575
+
576
+
577
def rebuild_typed_storage(storage, dtype):
    """Wrap an untyped storage back into a ``TypedStorage`` of ``dtype``."""
    return torch.storage.TypedStorage(
        wrap_storage=storage, dtype=dtype, _internal=True
    )
579
+
580
+
581
# Reducer for torch.storage.TypedStorage itself (not its legacy subclasses).
def reduce_typed_storage(storage):
    """Pickle a ``TypedStorage`` as its untyped payload plus its dtype."""
    payload = storage._untyped_storage
    return (rebuild_typed_storage, (payload, storage.dtype))
584
+
585
+
586
def rebuild_typed_storage_child(storage, storage_type):
    """Wrap an untyped storage into a ``TypedStorage`` subclass (e.g. FloatStorage)."""
    return storage_type(wrap_storage=storage, _internal=True)
588
+
589
+
590
# Reducer for legacy TypedStorage subclasses such as torch.FloatStorage:
# the subclass itself encodes the dtype, so only the payload travels.
def reduce_typed_storage_child(storage):
    """Pickle a typed-storage subclass as its untyped payload plus its class."""
    payload = storage._untyped_storage
    return (rebuild_typed_storage_child, (payload, type(storage)))
593
+
594
+
595
def reduce_storage(storage):
    """ForkingPickler reducer for CPU storages.

    CUDA and meta storages must be pickled as tensors instead (their sharing
    metadata lives on the tensor path).  CPU storages are shared either via
    the file_system strategy (named shared memory) or by duplicating a file
    descriptor, depending on the configured sharing strategy.

    Raises:
        RuntimeError: for CUDA or meta storages.
    """
    from . import get_sharing_strategy

    if storage.is_cuda:
        raise RuntimeError(
            "Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
        )
    elif storage.device.type == "meta":
        raise RuntimeError(
            "Cannot pickle meta storage; try pickling a meta tensor instead"
        )
    elif get_sharing_strategy() == "file_system":
        metadata = storage._share_filename_cpu_()
        # The second element of the shared-filename metadata is the handle;
        # it keys the shared cache.
        cache_key = metadata[1]
        rebuild = rebuild_storage_filename
        if isinstance(storage, torch.TypedStorage):
            metadata += (storage.dtype,)
        storage._shared_incref()
    elif storage.size() == 0:
        # This is special cased because Empty tensors
        # (with size 0) cannot be mmapped.
        return (rebuild_storage_empty, (type(storage),))
    else:
        fd, size = storage._share_fd_cpu_()
        df = multiprocessing.reduction.DupFd(fd)
        cache_key = fd_id(fd)
        metadata = (df, size)
        rebuild = rebuild_storage_fd  # type: ignore[assignment]

    shared_cache[cache_key] = StorageWeakRef(storage)
    return (rebuild, (type(storage),) + metadata)
626
+
627
+
628
def init_reductions():
    """Register this module's reducers with multiprocessing's ForkingPickler.

    After this runs, tensors, storages, Parameters and CUDA events sent
    through multiprocessing use the sharing paths defined in this module
    instead of plain byte-copy pickling.
    """
    ForkingPickler.register(torch.cuda.Event, reduce_event)

    for t in torch._storage_classes:
        if t.__name__ == "UntypedStorage":
            ForkingPickler.register(t, reduce_storage)
        else:
            # Legacy typed storages (FloatStorage, ...) share their untyped payload.
            ForkingPickler.register(t, reduce_typed_storage_child)

    ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)

    for t in torch._tensor_classes:
        ForkingPickler.register(t, reduce_tensor)

    # TODO: Maybe this should be in tensor_classes? :)
    ForkingPickler.register(torch.Tensor, reduce_tensor)

    from torch.nn.parameter import Parameter

    ForkingPickler.register(Parameter, reduce_tensor)
.venv/lib/python3.11/site-packages/torch/multiprocessing/spawn.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ import multiprocessing
4
+ import multiprocessing.connection
5
+ import os
6
+ import pickle
7
+ import signal
8
+ import sys
9
+ import tempfile
10
+ import time
11
+ import warnings
12
+ from concurrent.futures import as_completed, ThreadPoolExecutor
13
+ from typing import Optional
14
+
15
+ from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
16
+
17
+
18
+ ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"
19
+
20
+ log = logging.getLogger(__name__)
21
+
22
+ __all__ = [
23
+ "ProcessContext",
24
+ "ProcessException",
25
+ "ProcessExitedException",
26
+ "ProcessRaisedException",
27
+ "spawn",
28
+ "SpawnContext",
29
+ "start_processes",
30
+ ]
31
+
32
+
33
class ProcessException(Exception):
    """Base class for errors reported for a child process by this module.

    Carries the index of the failing process within the spawned group and
    its OS pid alongside the human-readable message.
    """

    # NOTE(review): __slots__ lists "error_pid" but __init__ stores self.pid;
    # instances still work because Exception provides a __dict__ — confirm
    # before relying on __slots__ here.
    __slots__ = ["error_index", "error_pid"]

    def __init__(self, msg: str, error_index: int, pid: int):
        super().__init__(msg)
        self.msg = msg
        self.error_index = error_index
        self.pid = pid

    def __reduce__(self):
        # Rebuild with the same constructor args when unpickled.
        return type(self), (self.msg, self.error_index, self.pid)
44
+
45
+
46
class ProcessRaisedException(ProcessException):
    """Raised when a child process fails with a Python exception.

    The child's formatted traceback is embedded in ``msg``.
    """

    def __init__(self, msg: str, error_index: int, error_pid: int):
        super().__init__(msg, error_index, error_pid)
56
+
57
+
58
class ProcessExitedException(ProcessException):
    """Exception raised when a process failed due to signal or exited with a specific code."""

    __slots__ = ["exit_code"]

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
        exit_code: int,
        # Present only when the process died from a signal.
        signal_name: Optional[str] = None,
    ):
        super().__init__(msg, error_index, error_pid)
        self.exit_code = exit_code
        self.signal_name = signal_name

    def __reduce__(self):
        # Rebuild with the same constructor args when unpickled.
        return (
            type(self),
            (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
        )
80
+
81
+
82
def _wrap(fn, i, args, error_file):
    """Child-process entry point: run ``fn(i, *args)``.

    On an uncaught exception, the formatted traceback is pickled to
    ``error_file`` (so the parent can re-raise it) and the child exits
    with status 1.  A KeyboardInterrupt is treated as a parent-initiated
    shutdown and ignored.
    """
    # prctl(2) is a Linux specific system call.
    # On other systems the following function call has no effect.
    # This is set to ensure that non-daemonic child processes can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
        pass  # SIGINT; Killed by parent, do nothing
    except Exception:
        # Propagate exception to parent process, keeping original traceback
        import traceback

        with open(error_file, "wb") as fh:
            pickle.dump(traceback.format_exc(), fh)
        sys.exit(1)
100
+
101
+
102
class ProcessContext:
    """Tracks a group of spawned worker processes and their error files.

    Returned by ``start_processes``/``spawn`` (when ``join=False``) so the
    caller can join the workers and surface child failures as exceptions.
    """

    def __init__(self, processes, error_files):
        # error_files[i] is the path where process i writes its pickled
        # traceback on failure (see _wrap); the file exists only on crash.
        self.error_files = error_files
        self.processes = processes
        # Map each process sentinel to its index so connection.wait results
        # can be traced back to a specific worker.
        self.sentinels = {
            process.sentinel: index for index, process in enumerate(processes)
        }

    def pids(self):
        """Return the OS pids of all managed processes, in spawn order."""
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""Join one or more processes within spawn context.

        Attempt to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Args:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        # Try SIGTERM then SIGKILL if the process isn't going down.
        # The reason is related to python signal handling is limited
        # to main thread and if that is in c/c++ land and stuck it won't
        # to handle it. We have seen processes getting stuck not handling
        # SIGTERM for the above reason.
        # NOTE(review): this rebinding shadows the `timeout` parameter for the
        # rest of the method — intentional here since the parameter was only
        # used for the wait above.
        timeout: int = 30
        for process in self.processes:
            if process.is_alive():
                log.warning("Terminating process %s via signal SIGTERM", process.pid)
                process.terminate()
        end = time.monotonic() + timeout
        for process in self.processes:
            # Share the 30s grace period across all processes.
            time_to_wait = max(0, end - time.monotonic())
            process.join(time_to_wait)
        for process in self.processes:
            if process.is_alive():
                log.warning(
                    "Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL",
                    process.pid,
                )
                process.kill()
            process.join()

        # The file will only be created if the process crashed.
        failed_process = self.processes[error_index]
        if not os.access(self.error_files[error_index], os.R_OK):
            # No traceback file: the child died from a signal or exited
            # non-zero without raising a Python exception.
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    name = f"<Unknown signal {-exitcode}>"
                raise ProcessExitedException(
                    "process %d terminated with signal %s" % (error_index, name),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                    signal_name=name,
                )
            else:
                raise ProcessExitedException(
                    "process %d terminated with exit code %d" % (error_index, exitcode),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                )

        with open(self.error_files[error_index], "rb") as fh:
            original_trace = pickle.load(fh)
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise ProcessRaisedException(msg, error_index, failed_process.pid)
204
+
205
+
206
class SpawnContext(ProcessContext):
    """Deprecated alias of :class:`ProcessContext`, kept for backward compatibility."""

    def __init__(self, processes, error_files):
        warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.")
        super().__init__(processes, error_files)
210
+
211
+
212
# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a
# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the
# CUDA compatible start_method. However, in environments like Ipython notebooks, 'fork'
# works better than 'spawn'. Every helper function we created for mp.spawn is indeed
# general enough, and backends like XLA can reuse them in Colab notebooks as well.
# Currently we only add this API first, we can consider adding it to documentation as
# needed in the future.
def start_processes(
    fn,
    args=(),
    nprocs=1,
    join=True,
    daemon=False,
    start_method="spawn",
):
    """Start ``nprocs`` processes running ``fn(i, *args)`` using ``start_method``.

    Args:
        fn: callable invoked as ``fn(i, *args)`` in each child, where ``i``
            is the process index.
        args: extra positional arguments forwarded to ``fn``.
        nprocs: number of processes to start.
        join: if True, block until all processes finish (raising on the first
            failure); if False, return a :class:`ProcessContext` immediately.
        daemon: daemon flag for the spawned processes.
        start_method: multiprocessing start method ('spawn', 'fork', ...).

    Returns:
        ``None`` when ``join`` is True, otherwise a :class:`ProcessContext`.
    """
    # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
    # this func will start processes in parallel if start_method is 'forkserver'.
    # Please opt in to this perf optimization by setting env var (TORCH_MP_PARALLEL_START) to 1.
    # todo: investigate why spawn does not work with threadpool and raises SIGINT
    if (
        start_method == "forkserver"
        and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
    ):
        log.info("Starting processes in parallel.")
        start_parallel = True
    else:
        # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
        start_parallel = False

    mp = multiprocessing.get_context(start_method)
    error_files = [None] * nprocs
    processes = [None] * nprocs

    def start_process(i):
        # Each process is assigned a file to write tracebacks to. We
        # use the file being non-empty to indicate an exception
        # occurred (vs an expected shutdown). Note: this previously
        # used a multiprocessing.Queue but that can be prone to
        # deadlocks, so we went with a simpler solution for a one-shot
        # message between processes.
        tf = tempfile.NamedTemporaryFile(
            prefix="pytorch-errorfile-", suffix=".pickle", delete=False
        )
        tf.close()
        # Unlink immediately; the child re-creates the file only on crash.
        os.unlink(tf.name)
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, tf.name),
            daemon=daemon,
        )
        process.start()
        return i, process, tf.name

    if not start_parallel:
        for i in range(nprocs):
            idx, process, tf_name = start_process(i)
            error_files[idx] = tf_name
            processes[idx] = process
    else:
        with ThreadPoolExecutor(max_workers=nprocs) as executor:
            futures = [executor.submit(start_process, i) for i in range(nprocs)]
            for fut in as_completed(futures):
                idx, process, tf_name = fut.result()
                # idx and process rank needs to be the same.
                error_files[idx] = tf_name
                processes[idx] = process
    context = ProcessContext(processes, error_files)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass
286
+
287
+
288
def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. In the case an exception was caught in the
    child process, it is forwarded and its traceback is included in
    the exception raised in the parent process.

    Args:
        fn (function): Function is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
                       daemonic processes will be created.
        start_method (str): (deprecated) this method will always use ``spawn``
                            as the start method; passing anything else emits a
                            ``FutureWarning`` and is ignored. To use a
                            different start method use ``start_processes()``.

    Returns:
        None if ``join`` is ``True``,
        :class:`~ProcessContext` if ``join`` is ``False``

    """
    if start_method != "spawn":
        msg = (
            f"This method only supports start_method=spawn (got: {start_method}).\n"
            "To use a different start_method use:\n\t\t"
            " torch.multiprocessing.start_processes(...)"
        )
        warnings.warn(msg, FutureWarning, stacklevel=2)
    # Always force 'spawn' regardless of the (deprecated) argument.
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
.venv/lib/python3.11/site-packages/torch/nn/quantizable/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (246 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from torch.ao.nn.quantizable.modules.activation import MultiheadAttention
2
+ from torch.ao.nn.quantizable.modules.rnn import LSTM, LSTMCell
3
+
4
+
5
+ __all__ = [
6
+ "LSTM",
7
+ "LSTMCell",
8
+ "MultiheadAttention",
9
+ ]
.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (454 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/activation.cpython-311.pyc ADDED
Binary file (669 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/nn/quantizable/modules/__pycache__/rnn.cpython-311.pyc ADDED
Binary file (665 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (253 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.37 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-311.pyc ADDED
Binary file (878 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-311.pyc ADDED
Binary file (679 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import parametrizations, rnn, stateless
2
+ from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_
3
+ from .convert_parameters import parameters_to_vector, vector_to_parameters
4
+ from .fusion import (
5
+ fuse_conv_bn_eval,
6
+ fuse_conv_bn_weights,
7
+ fuse_linear_bn_eval,
8
+ fuse_linear_bn_weights,
9
+ )
10
+ from .init import skip_init
11
+ from .memory_format import (
12
+ convert_conv2d_weight_memory_format,
13
+ convert_conv3d_weight_memory_format,
14
+ )
15
+ from .spectral_norm import remove_spectral_norm, spectral_norm
16
+ from .weight_norm import remove_weight_norm, weight_norm
17
+
18
+
19
+ __all__ = [
20
+ "clip_grad_norm",
21
+ "clip_grad_norm_",
22
+ "clip_grad_value_",
23
+ "convert_conv2d_weight_memory_format",
24
+ "convert_conv3d_weight_memory_format",
25
+ "fuse_conv_bn_eval",
26
+ "fuse_conv_bn_weights",
27
+ "fuse_linear_bn_eval",
28
+ "fuse_linear_bn_weights",
29
+ "parameters_to_vector",
30
+ "parametrizations",
31
+ "remove_spectral_norm",
32
+ "remove_weight_norm",
33
+ "rnn",
34
+ "skip_init",
35
+ "spectral_norm",
36
+ "stateless",
37
+ "vector_to_parameters",
38
+ "weight_norm",
39
+ ]
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.25 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_deprecation_utils.cpython-311.pyc ADDED
Binary file (2.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_named_member_accessor.cpython-311.pyc ADDED
Binary file (19.9 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/_per_sample_grad.cpython-311.pyc ADDED
Binary file (6.85 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/clip_grad.cpython-311.pyc ADDED
Binary file (9.91 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/convert_parameters.cpython-311.pyc ADDED
Binary file (3.61 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/fusion.cpython-311.pyc ADDED
Binary file (7.23 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/init.cpython-311.pyc ADDED
Binary file (2.75 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/memory_format.cpython-311.pyc ADDED
Binary file (8.39 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrizations.cpython-311.pyc ADDED
Binary file (26.7 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/parametrize.cpython-311.pyc ADDED
Binary file (35.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/prune.cpython-311.pyc ADDED
Binary file (59.1 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/rnn.cpython-311.pyc ADDED
Binary file (28.1 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/spectral_norm.cpython-311.pyc ADDED
Binary file (17 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/stateless.cpython-311.pyc ADDED
Binary file (13.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/__pycache__/weight_norm.cpython-311.pyc ADDED
Binary file (8.14 kB). View file
 
.venv/lib/python3.11/site-packages/torch/nn/utils/_deprecation_utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import importlib
3
+ import warnings
4
+ from typing import Callable, List
5
+
6
+
7
_MESSAGE_TEMPLATE = (
    r"Usage of '{old_location}' is deprecated; please use '{new_location}' instead."
)


def lazy_deprecated_import(
    all: List[str],
    old_module: str,
    new_module: str,
) -> Callable:
    r"""Build a module ``__getattr__`` that lazily forwards deprecated imports.

    Accessing any name listed in ``all`` on the old module emits a
    ``RuntimeWarning`` (so it is not silenced by default filters) and
    resolves the name from ``new_module`` instead.

    Args:
        all: The list of the functions that are imported. Generally, the
            module's __all__ list of the module.
        old_module: Old module location
        new_module: New module location / Migrated location

    Returns:
        Callable to assign to the ``__getattr__``

    Usage:

        # In the `torch/nn/quantized/functional.py`
        from torch.nn.utils._deprecation_utils import lazy_deprecated_import
        _MIGRATED_TO = "torch.ao.nn.quantized.functional"
        __getattr__ = lazy_deprecated_import(
            all=__all__,
            old_module=__name__,
            new_module=_MIGRATED_TO)
    """
    warning_message = _MESSAGE_TEMPLATE.format(
        old_location=old_module, new_location=new_module
    )

    def getattr_dunder(name):
        if name not in all:
            raise AttributeError(f"Module {new_module!r} has no attribute {name!r}.")
        # "RuntimeWarning" is used so the message is not ignored by default.
        warnings.warn(warning_message, RuntimeWarning)
        return getattr(importlib.import_module(new_module), name)

    return getattr_dunder