BryanW commited on
Commit
f5b5a3b
·
verified ·
1 Parent(s): 7f9dddc

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/cudagraphs.py +299 -0
  2. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/distributed.py +621 -0
  3. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/onnxrt.py +39 -0
  4. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/registry.py +179 -0
  5. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/torchxla.py +55 -0
  6. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/__init__.cpython-312.pyc +0 -0
  7. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/config.cpython-312.pyc +0 -0
  8. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/converter.cpython-312.pyc +0 -0
  9. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/error.cpython-312.pyc +0 -0
  10. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-312.pyc +0 -0
  11. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/pass_base.cpython-312.pyc +0 -0
  12. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/tools.cpython-312.pyc +0 -0
  13. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/utils.cpython-312.pyc +0 -0
  14. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/verifier.cpython-312.pyc +0 -0
  15. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/wrappers.cpython-312.pyc +0 -0
  16. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__init__.py +5 -0
  17. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/case.py +175 -0
  18. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/gen_example.py +21 -0
  19. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/logging.py +47 -0
  20. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__init__.py +0 -0
  21. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/node_metadata.py +32 -0
  22. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/proxy_value.py +45 -0
  23. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__init__.py +1 -0
  24. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/__init__.cpython-312.pyc +0 -0
  25. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-312.pyc +0 -0
  26. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-312.pyc +0 -0
  27. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-312.pyc +0 -0
  28. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-312.pyc +0 -0
  29. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-312.pyc +0 -0
  30. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-312.pyc +0 -0
  31. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-312.pyc +0 -0
  32. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-312.pyc +0 -0
  33. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-312.pyc +0 -0
  34. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-312.pyc +0 -0
  35. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/_node_metadata_hook.py +111 -0
  36. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +254 -0
  37. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/collect_tracepoints_pass.py +146 -0
  38. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/constant_folding.py +304 -0
  39. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py +99 -0
  40. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/insert_custom_op_guards.py +80 -0
  41. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/lift_constants_pass.py +417 -0
  42. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/remove_runtime_assertions.py +36 -0
  43. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py +189 -0
  44. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py +676 -0
  45. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py +121 -0
  46. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py +65 -0
  47. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_with_hop_pass_util.py +190 -0
  48. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__init__.py +0 -0
  49. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/dynamic_shapes.py +324 -0
  50. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/export_schema.thrift +377 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/cudagraphs.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements CUDA graphs support for TorchDynamo backends.
3
+
4
+ CUDA graphs allow for capturing and replaying GPU operations, which can significantly
5
+ reduce CPU overhead in GPU-accelerated PyTorch models. This module provides:
6
+
7
+ - CUDA graph creation and management for both forward and backward passes
8
+ - Input mutation detection and handling
9
+ - Device compatibility checking
10
+ - Stack trace management for debugging
11
+ - Integration with TorchInductor's cudagraph trees
12
+
13
+ The backend supports two main modes:
14
+ 1. cudagraphs: Full CUDA graph support with both forward and backward pass optimization
15
+ 2. cudagraphs_inner: Lower-level CUDA graph implementation used for benchmarking
16
+
17
+ Key components:
18
+ - CudagraphsBackend: Main backend class for CUDA graph integration
19
+ - Mutation detection utilities to ensure graph safety
20
+ - Device mapping and compatibility checks
21
+ - Stack trace collection for debugging
22
+ """
23
+
24
+ import functools
25
+ from collections import defaultdict
26
+ from collections.abc import Callable, Sequence
27
+ from typing import Any, Optional
28
+
29
+ import torch
30
+ import torch.fx
31
+ from torch._dynamo import config
32
+ from torch._dynamo.backends.common import aot_autograd
33
+ from torch._dynamo.backends.debugging import boxed_nop
34
+ from torch._inductor.cudagraph_utils import (
35
+ BoxedDeviceIndex,
36
+ check_multiple_devices_or_any_cpu_nodes,
37
+ format_default_skip_message,
38
+ get_mutation_stack_trace,
39
+ get_placeholder_info,
40
+ log_cudagraph_skip_and_bump_counter,
41
+ )
42
+ from torch._inductor.utils import (
43
+ BoxedBool,
44
+ count_tangents,
45
+ get_first_incompatible_cudagraph_node,
46
+ num_fw_fixed_arguments,
47
+ output_node,
48
+ )
49
+ from torch.multiprocessing.reductions import StorageWeakRef
50
+
51
+ from .registry import register_backend
52
+
53
+
54
+ def find_input_mutations(g: torch.fx.Graph) -> set[int]:
55
+ def meta_fk(meta: dict[str, Any]) -> Any:
56
+ return meta["val"] if "val" in meta else meta["fake_result"]
57
+
58
+ inputs = defaultdict(set)
59
+ input_idx = 0
60
+ mutated_inputs = set()
61
+ for n in g.nodes:
62
+ if n.op == "placeholder":
63
+ if isinstance(meta_fk(n.meta), torch.Tensor):
64
+ inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
65
+ input_idx += 1
66
+ elif n.op == "call_function":
67
+ if not hasattr(n.target, "_schema"):
68
+ continue
69
+
70
+ schema = n.target._schema
71
+ for i, arg in enumerate(schema.arguments):
72
+ if i < len(n.args):
73
+ argument = n.args[i]
74
+ else:
75
+ if arg.name not in n.kwargs:
76
+ continue
77
+ argument = n.kwargs[arg.name]
78
+ mut_arg = False
79
+ if arg.alias_info:
80
+ if arg.alias_info.is_write:
81
+ mut_arg = True
82
+ if mut_arg:
83
+ # TODO: not correct for args that contain tensors in a struct
84
+ # like list
85
+ mutated_inputs |= inputs[
86
+ StorageWeakRef(meta_fk(argument.meta)._typed_storage())
87
+ ]
88
+
89
+ # TODO: error on unrecognized nodes
90
+ return mutated_inputs
91
+
92
+
93
+ def get_device_node_mapping(
94
+ gm: torch.fx.GraphModule,
95
+ ) -> dict[torch.device, torch.fx.Node]:
96
+ device_node_mapping: dict[torch.device, torch.fx.Node] = {}
97
+ for n in gm.graph.nodes:
98
+ t = n.meta.get("val", None)
99
+ if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
100
+ device_node_mapping[t.device] = n
101
+ return device_node_mapping
102
+
103
+
104
+ def check_for_mutation_ignore_cuda_graph_managed_tensor(
105
+ aot_model: torch.fx.GraphModule, num_fixed: int
106
+ ) -> Optional[str]:
107
+ mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
108
+ if not mutation_indices:
109
+ return None
110
+
111
+ placeholders = get_placeholder_info(aot_model.graph)
112
+ return get_mutation_stack_trace(placeholders, mutation_indices)
113
+
114
+
115
+ def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
116
+ if not config.cudagraph_backend_support_input_mutation:
117
+ if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
118
+ aot_model, num_fixed
119
+ ):
120
+ return mut_skip
121
+
122
+ if skip := check_multiple_devices_or_any_cpu_nodes(
123
+ get_device_node_mapping(aot_model)
124
+ ):
125
+ return skip
126
+
127
+ if node := get_first_incompatible_cudagraph_node(aot_model):
128
+ return format_default_skip_message(f"incompatible op ({node.name})")
129
+
130
+ return None
131
+
132
+
133
+ def get_device_index(gm: torch.fx.GraphModule) -> int:
134
+ device = next(iter(get_device_node_mapping(gm)))
135
+ assert device.type == "cuda"
136
+ return device.index
137
+
138
+
139
+ def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]:
140
+ output = output_node(gm)
141
+ assert len(output.args) == 1
142
+ args = output.args[0]
143
+ if not hasattr(args, "__iter__"):
144
+ return []
145
+ return [
146
+ (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
147
+ for arg in args # type: ignore[union-attr]
148
+ ]
149
+
150
+
151
+ def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any:
152
+ from torch._inductor.cudagraph_trees import cudagraphify_impl
153
+
154
+ do_cudagraphs = BoxedBool(True)
155
+ boxed_device_index = BoxedDeviceIndex(None)
156
+
157
+ def forward_cudagraphs(
158
+ aot_model: torch.fx.GraphModule,
159
+ aot_inputs: list[Any],
160
+ is_inference: bool = False,
161
+ ) -> Any:
162
+ interp = boxed_nop(aot_model, aot_inputs)
163
+ fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
164
+ if skip_msg := check_for_skip(aot_model, fixed):
165
+ BoxedBool.disable(do_cudagraphs)
166
+ log_cudagraph_skip_and_bump_counter(
167
+ f"skipping cudagraphs due to {skip_msg}"
168
+ )
169
+ return interp
170
+
171
+ boxed_device_index.set(get_device_index(aot_model))
172
+ out = cudagraphify_impl(
173
+ interp,
174
+ aot_inputs,
175
+ range(fixed),
176
+ device_index=boxed_device_index.value,
177
+ is_backward=False,
178
+ is_inference=False, # Q: should forward is_inference here?
179
+ stack_traces=get_stack_traces(aot_model),
180
+ placeholders=get_placeholder_info(aot_model.graph),
181
+ mutated_input_idxs=find_input_mutations(aot_model.graph),
182
+ )
183
+ out._boxed_call = True # type: ignore[attr-defined]
184
+ return out
185
+
186
+ def backward_cudagraphs(
187
+ aot_model: torch.fx.GraphModule, aot_inputs: list[Any]
188
+ ) -> Any:
189
+ interp = boxed_nop(aot_model, aot_inputs)
190
+ if not do_cudagraphs:
191
+ return aot_model
192
+
193
+ fixed = count_tangents(aot_model)
194
+ if skip_msg := check_for_skip(aot_model, fixed):
195
+ log_cudagraph_skip_and_bump_counter(
196
+ f"skipping cudagraphs due to {skip_msg}"
197
+ )
198
+
199
+ # See [Backward Generation Handling]
200
+ device_idx = boxed_device_index.value
201
+ if device_idx is None:
202
+ device_idx = 0 # Default to device 0 if not set
203
+ manager = torch._inductor.cudagraph_trees.get_manager(
204
+ device_idx, create_if_none_exists=False
205
+ )
206
+ assert manager is not None
207
+
208
+ def fn(inputs: list[Any]) -> Any:
209
+ # pyrefly: ignore [missing-attribute]
210
+ manager.set_to_running_backward()
211
+ return aot_model(inputs)
212
+
213
+ fn._boxed_call = True # type: ignore[attr-defined]
214
+ return fn
215
+
216
+ out = cudagraphify_impl(
217
+ interp,
218
+ aot_inputs,
219
+ range(fixed),
220
+ device_index=get_device_index(aot_model),
221
+ is_backward=True,
222
+ is_inference=False,
223
+ stack_traces=get_stack_traces(aot_model),
224
+ placeholders=get_placeholder_info(aot_model.graph),
225
+ mutated_input_idxs=find_input_mutations(aot_model.graph),
226
+ )
227
+ out._boxed_call = True # type: ignore[attr-defined]
228
+ return out
229
+
230
+ aot_cudagraphs = aot_autograd(
231
+ fw_compiler=forward_cudagraphs,
232
+ bw_compiler=backward_cudagraphs,
233
+ inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
234
+ keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
235
+ )
236
+ return aot_cudagraphs(dynamo_model, dynamo_inputs)
237
+
238
+
239
+ class CudagraphsBackend:
240
+ compiler_name = "cudagraphs"
241
+
242
+ @staticmethod
243
+ def reset() -> None:
244
+ from torch._inductor.cudagraph_trees import reset_cudagraph_trees
245
+
246
+ reset_cudagraph_trees()
247
+
248
+ @staticmethod
249
+ def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any:
250
+ return cudagraphs(model, inputs)
251
+
252
+
253
+ # aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
254
+ # for debugging and can serve as a perf baseline.
255
+ register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
256
+
257
+
258
+ def cudagraphs_inner(
259
+ model: Callable[..., Any],
260
+ inputs: Sequence[Any],
261
+ copy_outputs: bool = True,
262
+ copy_inputs: bool = True,
263
+ ) -> Callable[..., Sequence[Any]]:
264
+ """This isn't registered as a backend, but is used in some benchmarks"""
265
+ assert isinstance(inputs, (list, tuple))
266
+ if copy_inputs:
267
+ static_inputs = [torch.zeros_like(x) for x in inputs]
268
+ else:
269
+ static_inputs = list(inputs)
270
+
271
+ # warmup
272
+ torch.cuda.synchronize()
273
+ stream = torch.cuda.Stream()
274
+ stream.wait_stream(torch.cuda.current_stream())
275
+ with torch.cuda.stream(stream):
276
+ model(*inputs)
277
+ stream.synchronize()
278
+ torch.cuda.current_stream().wait_stream(stream)
279
+ torch.cuda.synchronize()
280
+
281
+ # record
282
+ graph = torch.cuda.CUDAGraph()
283
+ with torch.cuda.graph(graph, stream=stream):
284
+ static_outputs = model(*static_inputs)
285
+ if not isinstance(static_outputs, (list, tuple)):
286
+ static_outputs = (static_outputs,)
287
+
288
+ def run(*new_inputs: Any) -> Sequence[Any]:
289
+ assert len(static_inputs) == len(new_inputs)
290
+ if copy_inputs:
291
+ for dst, src in zip(static_inputs, new_inputs):
292
+ dst.copy_(src)
293
+ graph.replay()
294
+ if copy_outputs:
295
+ return [x.clone() for x in static_outputs]
296
+ else:
297
+ return static_outputs
298
+
299
+ return run
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/distributed.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements distributed training optimizations for TorchDynamo backends.
3
+
4
+ It provides functionality to optimize models wrapped in DistributedDataParallel (DDP)
5
+ by intelligently splitting compiled graphs to align with DDP's gradient synchronization
6
+ boundaries. Key features include:
7
+
8
+ - Graph partitioning based on parameter bucket sizes
9
+ - Optimization of allreduce operations for distributed training
10
+ - Support for parameter ignoring and buffer handling
11
+ - Submodule compilation and management
12
+ - Debugging utilities for distributed training
13
+
14
+ The main component is the DDPOptimizer class, which handles graph splitting and
15
+ recompilation to enable efficient distributed training while maintaining the benefits
16
+ of compilation.
17
+ """
18
+
19
+ import logging
20
+ import traceback
21
+ from collections.abc import Callable
22
+ from dataclasses import dataclass, field
23
+ from typing import Any, Optional, TYPE_CHECKING
24
+ from unittest import mock
25
+
26
+ import torch
27
+ from torch import fx
28
+ from torch._dynamo.backends.registry import CompiledFn, CompilerFn
29
+ from torch._dynamo.output_graph import GraphCompileReason
30
+ from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
31
+ from torch._logging import trace_structured
32
+ from torch.fx.node import Node
33
+
34
+
35
+ if TYPE_CHECKING:
36
+ from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
37
+
38
+
39
+ # Regular log messages should go through 'log'.
40
+ # ddp_graph_log is a separate artifact logger reserved for dumping graphs.
41
+ # See docs/source/logging.rst for more info.
42
+ log = logging.getLogger(__name__)
43
+ ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")
44
+
45
+
46
+ def args_str(args: Any) -> str:
47
+ # a debug helper
48
+ if torch.is_tensor(args):
49
+ return f"T[{args.shape}]"
50
+ elif isinstance(args, tuple):
51
+ return f"tuple({', '.join([args_str(x) for x in args])})"
52
+ elif isinstance(args, list):
53
+ return f"list({', '.join([args_str(x) for x in args])})"
54
+ else:
55
+ return str(args)
56
+
57
+
58
+ @dataclass
59
+ class Bucket:
60
+ size: int = 0
61
+ params: list[str] = field(default_factory=list)
62
+ nodes: list[fx.Node] = field(default_factory=list)
63
+
64
+ # param_ids is just used for unit testing
65
+ param_ids: list[int] = field(default_factory=list)
66
+
67
+ # keep track of any buckets that were extended for logging purposes
68
+ opcount_increased_to_capture_external_output: int = 0
69
+ paramsize_before_opcount_increase: int = 0
70
+
71
+
72
+ def bucket_has_external_output(bucket: Bucket) -> bool:
73
+ nodes_in_bucket = set()
74
+ # we want to iterate in reverse order, but clumsi-luckily the bucket.nodes list was already created backwards
75
+ # so we don't reverse it here
76
+ for node in bucket.nodes:
77
+ # assume node.op != output, since those are filtered in the original iteration
78
+ nodes_in_bucket.add(node)
79
+ for user in node.users:
80
+ if user not in nodes_in_bucket:
81
+ return True
82
+ return False
83
+
84
+
85
+ def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int) -> None:
86
+ headers = ("Index", "Size (b)", "Param Names")
87
+ rows: list[tuple[Optional[int], Optional[int], str]] = []
88
+ extended_buckets = []
89
+ for idx, bucket in enumerate(reversed(buckets)):
90
+ if len(bucket.params) > 0:
91
+ rows.append((idx, bucket.size, bucket.params[0]))
92
+ rows.extend((None, None, param) for param in bucket.params[1:])
93
+ if bucket.opcount_increased_to_capture_external_output > 0:
94
+ extended_buckets.append(
95
+ (
96
+ idx,
97
+ bucket.opcount_increased_to_capture_external_output,
98
+ bucket.size - bucket.paramsize_before_opcount_increase,
99
+ )
100
+ )
101
+
102
+ if rows:
103
+ log.info(
104
+ "\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.",
105
+ bucket_bytes_cap,
106
+ len(buckets),
107
+ )
108
+
109
+ if extended_buckets:
110
+ log.warning(
111
+ "Some buckets were extended beyond their requested parameter capacities"
112
+ " in order to ensure each subgraph has an output node, required for fx graph partitioning."
113
+ " This can be the case when a subgraph would have only contained nodes performing inplace mutation,"
114
+ " and returning no logical outputs. This should not be a problem, unless it results in too few graph"
115
+ " partitions for optimal DDP performance."
116
+ )
117
+
118
+ try:
119
+ from tabulate import tabulate
120
+
121
+ log.debug(
122
+ "\nDDPOptimizer produced the following bucket assignments:\n%s",
123
+ tabulate(rows, headers=headers, tablefmt="simple_grid"),
124
+ )
125
+
126
+ if extended_buckets:
127
+ log.warning(
128
+ "DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s",
129
+ tabulate(
130
+ extended_buckets,
131
+ headers=("Index", "Extra Ops", "Extra Param Size (b)"),
132
+ tablefmt="simple_grid",
133
+ ),
134
+ )
135
+ except ImportError:
136
+ log.debug(
137
+ "Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information."
138
+ )
139
+ else:
140
+ log.debug("DDPOptimizer captured no parameters and did not split this graph.")
141
+
142
+
143
+ def has_higher_order_op(gm: fx.GraphModule) -> bool:
144
+ # Check if there is a higher order op in the graph
145
+ for node in gm.graph.nodes:
146
+ if node.op == "get_attr":
147
+ maybe_param = getattr(gm, node.target)
148
+ if isinstance(maybe_param, torch.fx.GraphModule):
149
+ return True
150
+ return False
151
+
152
+
153
+ def propagate_metadata(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
154
+ for name, module in split_gm.named_modules():
155
+ if "." not in name and len(name):
156
+ # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384
157
+ module.meta = orig_gm.meta
158
+ module._param_name_to_source = orig_gm._param_name_to_source
159
+
160
+
161
+ def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
162
+ name_to_dynamo_source = {}
163
+ for node in orig_gm.graph.find_nodes(op="placeholder"):
164
+ name_to_dynamo_source[node.name] = node._dynamo_source
165
+
166
+ for name, module in split_gm.named_modules():
167
+ if "." not in name and len(name):
168
+ for node in module.graph.find_nodes(op="placeholder"):
169
+ # non-placeholder in original_gm may become placeholder in submodules
170
+ node._dynamo_source = name_to_dynamo_source.get(node.name, None)
171
+
172
+
173
+ class DDPOptimizerContext:
174
+ def __init__(self) -> None:
175
+ self.curr_bucket: int = -1
176
+ self.metadata_per_bucket: list[ViewAndMutationMeta] = []
177
+
178
+
179
+ # compile each of the partitioned submodules using the user-provided compiler
180
+ class SubmodCompiler(torch.fx.interpreter.Interpreter):
181
+ def __init__(
182
+ self,
183
+ module: fx.GraphModule,
184
+ compiler: CompilerFn,
185
+ fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
186
+ ) -> None:
187
+ super().__init__(module)
188
+ self.compiler = compiler
189
+ self.fake_mode = fake_mode
190
+ # See Note [DDPOptimizer and fw_metadata]
191
+ ctx = torch._guards.TracingContext.try_get()
192
+ if ctx is not None:
193
+ ctx.ddp_optimizer_ctx = DDPOptimizerContext()
194
+
195
+ def compile_submod(
196
+ self, input_mod: fx.GraphModule, args: list[torch.Tensor], kwargs: Any
197
+ ) -> Any:
198
+ """
199
+ Compile the submodule,
200
+ using a wrapper to make sure its output is always a tuple,
201
+ which is required by AotAutograd based compilers
202
+ """
203
+ assert len(kwargs) == 0, "We assume only args for these modules"
204
+
205
+ class WrapperModule(torch.nn.Module):
206
+ def __init__(
207
+ self, submod: Callable[..., Any], unwrap_singleton_tuple: bool
208
+ ) -> None:
209
+ super().__init__()
210
+ self.submod = submod
211
+ self.unwrap_singleton_tuple = unwrap_singleton_tuple
212
+
213
+ def forward(self, *args: Any) -> Any:
214
+ x = self.submod(*args)
215
+ # TODO(whc)
216
+ # for some reason the isinstance check is necessary if I split one node per submod
217
+ # - even though I supposedly wrapped the output in a tuple in those cases, the real
218
+ # compiled module was still returning a tensor
219
+ if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
220
+ return x[0]
221
+ return x
222
+
223
+ unwrap_singleton_tuple = False
224
+ for sn in input_mod.graph.nodes:
225
+ if sn.op == "output":
226
+ if not isinstance(sn.args[0], tuple):
227
+ unwrap_singleton_tuple = True
228
+ sn.args = (sn.args,)
229
+
230
+ input_mod.recompile()
231
+ input_mod.compile_subgraph_reason = GraphCompileReason( # type: ignore[assignment]
232
+ "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
233
+ " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
234
+ [
235
+ # it's close to useless to get a real stacktrace here, and quite verbose.
236
+ traceback.FrameSummary(__file__, 0, "DDPOptimizer"),
237
+ ],
238
+ )
239
+
240
+ wrapper = WrapperModule(
241
+ self.compiler(input_mod, args),
242
+ unwrap_singleton_tuple,
243
+ )
244
+ return wrapper
245
+
246
+ # Note:
247
+ #
248
+ # The way distributed works today around fake tensors can be somewhat confusing.
249
+ # Some of these codepaths are shared in both runtime, and compile time. The presence
250
+ # of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
251
+ #
252
+ # A few things to keep in mind:
253
+ #
254
+ # 1) We invoke `compile_submod` with a real module. The output of that gets stored
255
+ # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
256
+ #
257
+ # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
258
+ # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it.
259
+ #
260
+ # 3) Fake tensors should always be around during compile time.
261
+ #
262
+ # 4) Fake tensors should never be around at runtime.
263
+ #
264
+ # 5) We end up with a compilation mode that takes a real submodule and fake tensors,
265
+ # to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd]
266
+ def run_node(self, n: Node) -> Any:
267
+ args, kwargs = self.fetch_args_kwargs_from_env(n)
268
+ new_args = []
269
+ assert self.fake_mode
270
+ for arg in args:
271
+ if isinstance(arg, torch.Tensor) and not isinstance(
272
+ arg, torch._subclasses.FakeTensor
273
+ ):
274
+ new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode))
275
+ else:
276
+ new_args.append(arg)
277
+
278
+ log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
279
+ assert isinstance(args, tuple)
280
+ assert isinstance(kwargs, dict)
281
+
282
+ if n.op == "call_module":
283
+ real_mod = self.fetch_attr(str(n.target))
284
+ if self.fake_mode:
285
+ curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
286
+ else:
287
+ curr_submod = real_mod
288
+
289
+ ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph)
290
+
291
+ # When calling the compiler on the submod, inputs (new_args) are expected to
292
+ # be FakeTensors already since Dynamo would have made them FakeTensors in the
293
+ # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors,
294
+ # since this wrapping happens during compilation
295
+
296
+ # Note: Returning Fake Tensors on First AOT Autograd Call
297
+ #
298
+ # Inductor will optimize strides of outputs when it deems it profitable.
299
+ # For instance, converting to channels last. When we split the graph here
300
+ # into multiple inductor compilations, we need to make sure that the
301
+ # output strides of one compilation is appropriately passed to the subsequent
302
+ # compilations. However, the mapping from inductor output to dynamo output
303
+ # is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing,
304
+ # subclass handling, etc. In order to replay all this logic we set a flag such that
305
+ # the first invocation of inductor in aot_autograd will return Fake Tensors with
306
+ # appropriate strides. Then, all of aot autograd's runtime logic is replayed.
307
+ # This gives us the appropriately strided outputs here which will reflect runtime strides.
308
+
309
+ class FakeifyFirstAOTInvocationGuard:
310
+ def __init__(self) -> None:
311
+ self.tc = torch._guards.TracingContext.try_get()
312
+ assert self.tc
313
+ self.tc.fakify_first_call = True
314
+
315
+ def __del__(self) -> None:
316
+ self.tc.fakify_first_call = False # type: ignore[union-attr]
317
+
318
+ # For aot_eager and other backends, tracing context is not set
319
+ has_tracing_context = torch._guards.TracingContext.try_get() is not None
320
+ if has_tracing_context:
321
+ g = FakeifyFirstAOTInvocationGuard() # noqa: F841
322
+
323
+ from torch._dynamo.utils import counters
324
+
325
+ init = counters["aot_autograd"]["total"]
326
+ compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs)
327
+
328
+ # TODO - better way of doing this?
329
+ # Only aot autograd handles fakifying first call
330
+ invoked_aot_autograd = init != counters["aot_autograd"]["total"]
331
+
332
+ # We update the original (outer) graph with a call into the compiled module
333
+ # instead of the uncompiled one.
334
+ self.module.delete_submodule(n.target) # type: ignore[operator]
335
+ n.target = "compiled_" + n.target # type: ignore[operator]
336
+ self.module.add_submodule(n.target, compiled_submod_real) # type: ignore[operator]
337
+
338
+ # Finally, we have to produce inputs for use compiling the next submodule,
339
+ # and these need to be FakeTensors, so we execute the module under fake_mode
340
+ # Because parameters are not fake we patch fake tensor mode to allow non fake inputs
341
+ with (
342
+ self.fake_mode,
343
+ mock.patch.object(self.fake_mode, "allow_non_fake_inputs", True),
344
+ ):
345
+ if has_tracing_context and invoked_aot_autograd:
346
+ tracing_ctx = torch._guards.TracingContext.try_get()
347
+ assert tracing_ctx is not None
348
+ # DDPOptimizer maintains 1 dynamo graph -> N AOT graphs
349
+ # Dynamo only has 1 tracing context, so it needs to maintain all N AOT metadata instances
350
+ ddp_ctx = tracing_ctx.ddp_optimizer_ctx
351
+ assert ddp_ctx is not None
352
+ assert tracing_ctx.fw_metadata is not None
353
+ ddp_ctx.curr_bucket += 1
354
+ ddp_ctx.metadata_per_bucket.append(tracing_ctx.fw_metadata)
355
+
356
+ out = compiled_submod_real(*new_args, **kwargs)
357
+ # output should be fake or subclass
358
+ assert all(
359
+ (not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor)
360
+ for t in (out if isinstance(out, (list, tuple)) else [out])
361
+ )
362
+ return out
363
+ else:
364
+ return curr_submod(*new_args, **kwargs)
365
+ else:
366
+ # placeholder or output nodes don't need to get compiled, just executed
367
+ return getattr(self, n.op)(n.target, new_args, kwargs)
368
+
369
+
370
+ class DDPOptimizer:
371
+ """Note [DDPOptimizer]
372
+ DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
373
+ breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
374
+ the boundaries of gradient-allreduce buckets chosen by DDP.
375
+
376
+ Background/Motivation
377
+ - DDP uses allreduce collectives to synchronize partial gradients computed on different workers
378
+ - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce
379
+ - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
380
+ at around the same time during backward and thus can share the same allreduce efficiently
381
+ - Allreduces must overlap with backward compute for optimal training performance
382
+ - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which
383
+ operates when individual grads become 'ready'
384
+ - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
385
+ autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole
386
+ fused backward function executes, preventing any overlap of compute and communication
387
+
388
+ Algorithm
389
+ - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse
390
+ this graph in reverse order to determine the true order that gradients will become ready during backward.
391
+ - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started
392
+ and a graph break introduced
393
+ - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together
394
+ into an outer module that is returned to the user
395
+
396
+ Notes
397
+ - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
398
+ and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
399
+ in eager.
400
+ - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently
401
+ produce splits that do not necessarily align with the buckets used by DDP. This should result in performance
402
+ degradation approaching the baseline case where graph-splits are not used, but not worse.
403
+ - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the
404
+ subgraphs being compiled
405
+ - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
406
+ left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are
407
+ also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP,
408
+ it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
409
+ - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients,
410
+ and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by
411
+ DDPOptimizer)
412
+
413
+ Debugging
414
+ - Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb.
415
+ - In many cases, the log messages are helpful (they show bucket size assignments)-
416
+ just set TORCH_LOGS env to include any of 'dynamo', 'distributed', or 'dist_ddp'.
417
+ - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model
418
+ in a single process (or with torchrun, in multiple processes)
419
+
420
+ Args:
421
+ bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks. Should be
422
+ set to match the equivalent parameter on the original DDP module.
423
+
424
+ backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.
425
+
426
+ first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP
427
+ special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.
428
+
429
+ """
430
+
431
+ def __init__(
432
+ self,
433
+ bucket_bytes_cap: int,
434
+ backend_compile_fn: CompilerFn,
435
+ first_bucket_cap: Optional[int] = None,
436
+ ) -> None:
437
+ if first_bucket_cap is not None:
438
+ self.first_bucket_cap = first_bucket_cap
439
+ elif torch.distributed.is_available():
440
+ # this constant comes from C10D lib which is not always built
441
+ self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
442
+ else:
443
+ self.first_bucket_cap = bucket_bytes_cap
444
+
445
+ self.bucket_bytes_cap = bucket_bytes_cap
446
+ assert self.first_bucket_cap <= self.bucket_bytes_cap, (
447
+ "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"
448
+ )
449
+
450
+ self.backend_compile_fn = backend_compile_fn
451
+
452
+ def _ignore_parameter(self, parameter: torch.nn.Parameter) -> bool:
453
+ return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored
454
+
455
+ def add_param(self, bucket: Bucket, param: torch.nn.Parameter, name: str) -> None:
456
+ bucket.size += param.untyped_storage().nbytes()
457
+ bucket.params.append(name)
458
+ bucket.param_ids.append(id(param))
459
+
460
+ def add_module_params_to_bucket(
461
+ self,
462
+ mod: torch.nn.Module,
463
+ bucket: Bucket,
464
+ processed_modules: set[torch.nn.Module],
465
+ prefix: str,
466
+ ) -> None:
467
+ processed_modules.add(mod)
468
+ for name, param in mod.named_parameters():
469
+ if param.requires_grad and not self._ignore_parameter(param):
470
+ self.add_param(bucket, param, f"{prefix}_{name}")
471
+
472
+ def add_param_args(self, bucket: Bucket, node: fx.Node) -> None:
473
+ for arg in node.args:
474
+ if not isinstance(arg, torch.fx.node.Node):
475
+ continue
476
+ if arg.op != "placeholder":
477
+ continue
478
+ param = arg.meta["example_value"]
479
+ if (
480
+ isinstance(param, torch.nn.Parameter)
481
+ and param.requires_grad
482
+ and not self._ignore_parameter(param)
483
+ ):
484
+ self.add_param(bucket, param, str(arg.target))
485
+
486
+ def compile_fn(
487
+ self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]
488
+ ) -> CompiledFn:
489
+ """
490
+ Implements graph splitting, first determining a set of of buckets by counting
491
+ parameter sizes in reverse graph order, then invoking the user/backend compiler
492
+ to compile each subgraph. Finally, stiches compiled graphs into one graphmodule
493
+ and returns its callable.
494
+ """
495
+ # 1: compute the partition map according to DDP bucket logic
496
+ buckets = [Bucket()] # (size, param_names)
497
+ processed_modules: set[torch.nn.Module] = set()
498
+ for node in reversed(gm.graph.nodes):
499
+ if node.op in ("output", "placeholder"):
500
+ continue
501
+
502
+ if (
503
+ buckets[0].size >= self.bucket_bytes_cap
504
+ or len(buckets) == 1
505
+ and buckets[0].size >= self.first_bucket_cap
506
+ ):
507
+ if bucket_has_external_output(buckets[0]):
508
+ buckets.insert(0, Bucket())
509
+ else:
510
+ # continue building this bucket past the point of filling its parameter capacity,
511
+ # to increase chances it contains at least one node that is either a global output or
512
+ # passed as input to a subsequent graph
513
+
514
+ if buckets[0].opcount_increased_to_capture_external_output == 0:
515
+ buckets[0].paramsize_before_opcount_increase = buckets[0].size
516
+ buckets[0].opcount_increased_to_capture_external_output += 1
517
+
518
+ if node.op == "call_function":
519
+ self.add_param_args(buckets[0], node)
520
+
521
+ elif node.op == "call_module":
522
+ target_mod = gm.get_submodule(node.target)
523
+ if target_mod not in processed_modules:
524
+ self.add_module_params_to_bucket(
525
+ target_mod, buckets[0], processed_modules, node.target
526
+ )
527
+ elif node.op == "call_method":
528
+ if isinstance(node.args[0].target, str):
529
+ target_mod = None
530
+ try:
531
+ target_mod = gm.get_submodule(node.args[0].target)
532
+ except AttributeError:
533
+ pass
534
+ if target_mod is not None and target_mod not in processed_modules:
535
+ self.add_module_params_to_bucket(
536
+ target_mod, buckets[0], processed_modules, node.target
537
+ )
538
+ # This handles situations like tmp = torch.mm(x, self.weight.t())
539
+ # t: "f32[512, 512]" = l_self_seq_2_weight.t(); l_self_seq_2_weight = None
540
+ # tmp: "f32[512, 512]" = torch.mm(input_2, t); input_2 = t = None
541
+ self.add_param_args(buckets[0], node)
542
+
543
+ elif node.op == "get_attr":
544
+ maybe_param = getattr(gm, node.target)
545
+ if (
546
+ isinstance(maybe_param, torch.nn.Parameter)
547
+ and maybe_param.requires_grad
548
+ and not self._ignore_parameter(maybe_param)
549
+ ):
550
+ self.add_param(buckets[0], maybe_param, node.target)
551
+
552
+ # All nodes have to be mapped to a bucket, even if they don't have their own params
553
+ # Ignored params still end up in buckets, we just don't count them towards the capacity
554
+ buckets[0].nodes.append(node)
555
+
556
+ if len(buckets) > 1 and buckets[0].size == 0:
557
+ # we collected a small preamble graph with ops that don't include parameters, fuse it back
558
+ buckets[1].nodes.extend(buckets[0].nodes)
559
+ assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
560
+ del buckets[0]
561
+
562
+ # stash buckets for testing/debugging purposes
563
+ self.buckets = buckets
564
+ pretty_print_buckets(buckets, self.bucket_bytes_cap)
565
+
566
+ if len(buckets) == 1:
567
+ # bypass split/fuse logic if there is only one bucket
568
+ return self.backend_compile_fn(gm, example_inputs)
569
+
570
+ # 2: partition the graphmodule according to bucket capacity
571
+ partition_map = {}
572
+ for idx, b in enumerate(buckets):
573
+ for node in b.nodes:
574
+ partition_map[node] = idx
575
+
576
+ split_gm = fx.passes.split_module.split_module(
577
+ gm,
578
+ None, # type: ignore[arg-type]
579
+ lambda node: partition_map[node],
580
+ )
581
+
582
+ # See note [Assumption on Dynamo Metadata]
583
+ propagate_dynamo_source(gm, split_gm)
584
+ propagate_metadata(gm, split_gm)
585
+
586
+ debug_str = (
587
+ f"\n---orig graph---\n{gm.graph}\n"
588
+ + f"\n---split graph---\n{split_gm.graph}\n"
589
+ )
590
+ for name, module in split_gm.named_modules():
591
+ if "." not in name and len(name):
592
+ # only print the submod graphs, not their children
593
+ debug_str += f"\n---{name} graph---\n{module.graph}\n"
594
+ debug_str += "\n---------------\n"
595
+ ddp_graph_log.debug(debug_str)
596
+
597
+ trace_structured(
598
+ "optimize_ddp_split_graph",
599
+ payload_fn=lambda: split_gm.print_readable(print_output=False),
600
+ )
601
+ for name, module in split_gm.named_modules():
602
+ if "." not in name and len(name):
603
+ trace_structured(
604
+ "optimize_ddp_split_child",
605
+ lambda: {"name": name},
606
+ payload_fn=lambda: module.print_readable(print_output=False),
607
+ )
608
+
609
+ fake_mode = detect_fake_mode(example_inputs)
610
+ if fake_mode is None:
611
+ fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
612
+
613
+ submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode)
614
+ with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
615
+ submod_compiler.run(*example_inputs)
616
+ split_gm.recompile()
617
+
618
+ ddp_graph_log.debug(
619
+ "\n---final graph---\n%s\n---------------\n", split_gm.graph
620
+ )
621
+ return split_gm
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/onnxrt.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This backend is maintained by ONNX team. To direct issues
2
+ # to the right people, please tag related GitHub issues with `module: onnx`.
3
+ #
4
+ # Maintainers' Github IDs: wschin, xadupre
5
+ # from torch.onnx._internal.onnxruntime import (
6
+ # is_onnxrt_backend_supported,
7
+ # torch_compile_backend,
8
+ # )
9
+
10
+ # from .registry import register_backend
11
+
12
+ """
13
+ Placeholder for onnxruntime backend for dynamo
14
+ """
15
+
16
+ # def has_onnxruntime():
17
+ # # FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported()
18
+ # return is_onnxrt_backend_supported()
19
+
20
+
21
+ # if is_onnxrt_backend_supported():
22
+ # register_backend(name="onnxrt", compiler_fn=torch_compile_backend)
23
+ # else:
24
+
25
+ # def information_displaying_backend(*args, **kwargs):
26
+ # raise ImportError(
27
+ # "onnxrt is not registered as a backend. "
28
+ # "Please make sure all dependencies such as "
29
+ # "numpy, onnx, onnxscript, and onnxruntime-training are installed. "
30
+ # "Suggested procedure to fix dependency problem:\n"
31
+ # " (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n"
32
+ # " (2) Open a new python terminal.\n"
33
+ # " (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n"
34
+ # " (4) If it returns `True`, then you can use `onnxrt` backend.\n"
35
+ # " (5) If it returns `False`, please execute the package importing section in "
36
+ # "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails."
37
+ # )
38
+
39
+ # register_backend(name="onnxrt", compiler_fn=information_displaying_backend)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/registry.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements TorchDynamo's backend registry system for managing compiler backends.
3
+
4
+ The registry provides a centralized way to register, discover and manage different compiler
5
+ backends that can be used with torch.compile(). It handles:
6
+
7
+ - Backend registration and discovery through decorators and entry points
8
+ - Lazy loading of backend implementations
9
+ - Lookup and validation of backend names
10
+ - Categorization of backends using tags (debug, experimental, etc.)
11
+
12
+ Key components:
13
+ - CompilerFn: Type for backend compiler functions that transform FX graphs
14
+ - _BACKENDS: Registry mapping backend names to entry points
15
+ - _COMPILER_FNS: Registry mapping backend names to loaded compiler functions
16
+
17
+ Example usage:
18
+ @register_backend
19
+ def my_compiler(fx_graph, example_inputs):
20
+ # Transform FX graph into optimized implementation
21
+ return compiled_fn
22
+
23
+ # Use registered backend
24
+ torch.compile(model, backend="my_compiler")
25
+
26
+ The registry also supports discovering backends through setuptools entry points
27
+ in the "torch_dynamo_backends" group. Example:
28
+ ```
29
+ setup.py
30
+ ---
31
+ from setuptools import setup
32
+
33
+ setup(
34
+ name='my_torch_backend',
35
+ version='0.1',
36
+ packages=['my_torch_backend'],
37
+ entry_points={
38
+ 'torch_dynamo_backends': [
39
+ # name = path to entry point of backend implementation
40
+ 'my_compiler = my_torch_backend.compiler:my_compiler_function',
41
+ ],
42
+ },
43
+ )
44
+ ```
45
+ ```
46
+ my_torch_backend/compiler.py
47
+ ---
48
+ def my_compiler_function(fx_graph, example_inputs):
49
+ # Transform FX graph into optimized implementation
50
+ return compiled_fn
51
+ ```
52
+ Using `my_compiler` backend:
53
+ ```
54
+ import torch
55
+
56
+ model = ... # Your PyTorch model
57
+ optimized_model = torch.compile(model, backend="my_compiler")
58
+ ```
59
+ """
60
+
61
+ import functools
62
+ import logging
63
+ from collections.abc import Callable, Sequence
64
+ from importlib.metadata import EntryPoint
65
+ from typing import Any, Optional, Protocol, Union
66
+
67
+ import torch
68
+ from torch import fx
69
+
70
+
71
+ log = logging.getLogger(__name__)
72
+
73
+
74
+ class CompiledFn(Protocol):
75
+ def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: ...
76
+
77
+
78
+ CompilerFn = Callable[[fx.GraphModule, list[torch.Tensor]], CompiledFn]
79
+
80
+ _BACKENDS: dict[str, Optional[EntryPoint]] = {}
81
+ _COMPILER_FNS: dict[str, CompilerFn] = {}
82
+
83
+
84
+ def register_backend(
85
+ compiler_fn: Optional[CompilerFn] = None,
86
+ name: Optional[str] = None,
87
+ tags: Sequence[str] = (),
88
+ ) -> Callable[..., Any]:
89
+ """
90
+ Decorator to add a given compiler to the registry to allow calling
91
+ `torch.compile` with string shorthand. Note: for projects not
92
+ imported by default, it might be easier to pass a function directly
93
+ as a backend and not use a string.
94
+
95
+ Args:
96
+ compiler_fn: Callable taking a FX graph and fake tensor inputs
97
+ name: Optional name, defaults to `compiler_fn.__name__`
98
+ tags: Optional set of string tags to categorize backend with
99
+ """
100
+ if compiler_fn is None:
101
+ # @register_backend(name="") syntax
102
+ return functools.partial(register_backend, name=name, tags=tags) # type: ignore[return-value]
103
+ assert callable(compiler_fn)
104
+ name = name or compiler_fn.__name__
105
+ assert name not in _COMPILER_FNS, f"duplicate name: {name}"
106
+ if compiler_fn not in _BACKENDS:
107
+ _BACKENDS[name] = None
108
+ _COMPILER_FNS[name] = compiler_fn
109
+ compiler_fn._tags = tuple(tags) # type: ignore[attr-defined]
110
+ return compiler_fn
111
+
112
+
113
+ register_debug_backend = functools.partial(register_backend, tags=("debug",))
114
+ register_experimental_backend = functools.partial(
115
+ register_backend, tags=("experimental",)
116
+ )
117
+
118
+
119
+ def lookup_backend(compiler_fn: Union[str, CompilerFn]) -> CompilerFn:
120
+ """Expand backend strings to functions"""
121
+ if isinstance(compiler_fn, str):
122
+ if compiler_fn not in _BACKENDS:
123
+ _lazy_import()
124
+ if compiler_fn not in _BACKENDS:
125
+ from ..exc import InvalidBackend
126
+
127
+ raise InvalidBackend(name=compiler_fn)
128
+
129
+ if compiler_fn not in _COMPILER_FNS:
130
+ entry_point = _BACKENDS[compiler_fn]
131
+ if entry_point is not None:
132
+ register_backend(compiler_fn=entry_point.load(), name=compiler_fn)
133
+ compiler_fn = _COMPILER_FNS[compiler_fn]
134
+ return compiler_fn
135
+
136
+
137
+ # NOTE: can't type this due to public api mismatch; follow up with dev team
138
+ def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: # type: ignore[no-untyped-def]
139
+ """
140
+ Return valid strings that can be passed to:
141
+
142
+ torch.compile(..., backend="name")
143
+ """
144
+ _lazy_import()
145
+ exclude_tags_set = set(exclude_tags or ())
146
+
147
+ backends = [
148
+ name
149
+ for name in _BACKENDS
150
+ if name not in _COMPILER_FNS
151
+ or not exclude_tags_set.intersection(_COMPILER_FNS[name]._tags) # type: ignore[attr-defined]
152
+ ]
153
+ return sorted(backends)
154
+
155
+
156
+ @functools.cache
157
+ def _lazy_import() -> None:
158
+ from .. import backends
159
+ from ..utils import import_submodule
160
+
161
+ import_submodule(backends)
162
+
163
+ from ..repro.after_dynamo import dynamo_minifier_backend
164
+
165
+ assert dynamo_minifier_backend is not None
166
+
167
+ _discover_entrypoint_backends()
168
+
169
+
170
+ @functools.cache
171
+ def _discover_entrypoint_backends() -> None:
172
+ # importing here so it will pick up the mocked version in test_backends.py
173
+ from importlib.metadata import entry_points
174
+
175
+ group_name = "torch_dynamo_backends"
176
+ eps = entry_points(group=group_name)
177
+ eps_dict = {name: eps[name] for name in eps.names}
178
+ for backend_name in eps_dict:
179
+ _BACKENDS[backend_name] = eps_dict[backend_name]
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/torchxla.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections.abc import Callable
3
+ from typing import Any
4
+
5
+ import torch
6
+ from functorch.compile import make_boxed_func
7
+ from torch import fx
8
+
9
+ from ..backends.common import aot_autograd
10
+ from .registry import CompiledFn, register_backend, register_experimental_backend
11
+
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
+ @register_experimental_backend
17
+ def openxla_eval(
18
+ model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
19
+ ) -> CompiledFn:
20
+ return xla_backend_helper(model, fake_tensor_inputs, boxed=False)
21
+
22
+
23
+ def openxla_eval_boxed(
24
+ model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
25
+ ) -> Callable[..., Any]:
26
+ return xla_backend_helper(model, fake_tensor_inputs, boxed=True)
27
+
28
+
29
+ def xla_backend_helper(
30
+ model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], boxed: bool = False
31
+ ) -> Callable[..., Any]:
32
+ try:
33
+ import torch_xla.core.dynamo_bridge as bridge
34
+ except ImportError as e:
35
+ raise ImportError(
36
+ "Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla"
37
+ ) from e
38
+
39
+ compiled_graph = None
40
+
41
+ def fwd(*args: torch.Tensor) -> Any:
42
+ nonlocal model
43
+ nonlocal compiled_graph
44
+ if compiled_graph is None:
45
+ compiled_graph = bridge.extract_compiled_graph(model, args)
46
+ del model
47
+ return compiled_graph(*args)
48
+
49
+ return make_boxed_func(fwd) if boxed else fwd
50
+
51
+
52
+ openxla = aot_autograd(
53
+ fw_compiler=openxla_eval_boxed,
54
+ )
55
+ register_backend(name="openxla", compiler_fn=openxla)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (8.73 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/config.cpython-312.pyc ADDED
Binary file (1.2 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/converter.cpython-312.pyc ADDED
Binary file (78.2 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/error.cpython-312.pyc ADDED
Binary file (2.42 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-312.pyc ADDED
Binary file (47.7 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/pass_base.cpython-312.pyc ADDED
Binary file (26 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/tools.cpython-312.pyc ADDED
Binary file (6.17 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/utils.cpython-312.pyc ADDED
Binary file (73 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/verifier.cpython-312.pyc ADDED
Binary file (27.2 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/wrappers.cpython-312.pyc ADDED
Binary file (16.3 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/case.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import inspect
3
+ import re
4
+ import string
5
+ from dataclasses import dataclass, field
6
+ from enum import Enum
7
+ from typing import Any, Optional
8
+ from types import ModuleType
9
+
10
+ import torch
11
+
12
+ _TAGS: dict[str, dict[str, Any]] = {
13
+ "torch": {
14
+ "cond": {},
15
+ "dynamic-shape": {},
16
+ "escape-hatch": {},
17
+ "map": {},
18
+ "dynamic-value": {},
19
+ "operator": {},
20
+ "mutation": {},
21
+ },
22
+ "python": {
23
+ "assert": {},
24
+ "builtin": {},
25
+ "closure": {},
26
+ "context-manager": {},
27
+ "control-flow": {},
28
+ "data-structure": {},
29
+ "standard-library": {},
30
+ "object-model": {},
31
+ },
32
+ }
33
+
34
+
35
+ class SupportLevel(Enum):
36
+ """
37
+ Indicates at what stage the feature
38
+ used in the example is handled in export.
39
+ """
40
+
41
+ SUPPORTED = 1
42
+ NOT_SUPPORTED_YET = 0
43
+
44
+
45
+ ArgsType = tuple[Any, ...]
46
+
47
+
48
+ def check_inputs_type(args, kwargs):
49
+ if not isinstance(args, tuple):
50
+ raise ValueError(
51
+ f"Expecting args type to be a tuple, got: {type(args)}"
52
+ )
53
+ if not isinstance(kwargs, dict):
54
+ raise ValueError(
55
+ f"Expecting kwargs type to be a dict, got: {type(kwargs)}"
56
+ )
57
+ for key in kwargs:
58
+ if not isinstance(key, str):
59
+ raise ValueError(
60
+ f"Expecting kwargs keys to be a string, got: {type(key)}"
61
+ )
62
+
63
+ def _validate_tag(tag: str):
64
+ parts = tag.split(".")
65
+ t = _TAGS
66
+ for part in parts:
67
+ assert set(part) <= set(
68
+ string.ascii_lowercase + "-"
69
+ ), f"Tag contains invalid characters: {part}"
70
+ if part in t:
71
+ t = t[part]
72
+ else:
73
+ raise ValueError(f"Tag {tag} is not found in registered tags.")
74
+
75
+
76
+ @dataclass(frozen=True)
77
+ class ExportCase:
78
+ example_args: ArgsType
79
+ description: str # A description of the use case.
80
+ model: torch.nn.Module
81
+ name: str
82
+ example_kwargs: dict[str, Any] = field(default_factory=dict)
83
+ extra_args: Optional[ArgsType] = None # For testing graph generalization.
84
+ # Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
85
+ tags: set[str] = field(default_factory=set)
86
+ support_level: SupportLevel = SupportLevel.SUPPORTED
87
+ dynamic_shapes: Optional[dict[str, Any]] = None
88
+
89
+ def __post_init__(self):
90
+ check_inputs_type(self.example_args, self.example_kwargs)
91
+ if self.extra_args is not None:
92
+ check_inputs_type(self.extra_args, {})
93
+
94
+ for tag in self.tags:
95
+ _validate_tag(tag)
96
+
97
+ if not isinstance(self.description, str) or len(self.description) == 0:
98
+ raise ValueError(f'Invalid description: "{self.description}"')
99
+
100
+
101
+ _EXAMPLE_CASES: dict[str, ExportCase] = {}
102
+ _MODULES: set[ModuleType] = set()
103
+ _EXAMPLE_CONFLICT_CASES: dict[str, list[ExportCase]] = {}
104
+ _EXAMPLE_REWRITE_CASES: dict[str, list[ExportCase]] = {}
105
+
106
+
107
+ def register_db_case(case: ExportCase) -> None:
108
+ """
109
+ Registers a user provided ExportCase into example bank.
110
+ """
111
+ if case.name in _EXAMPLE_CASES:
112
+ if case.name not in _EXAMPLE_CONFLICT_CASES:
113
+ _EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]]
114
+ _EXAMPLE_CONFLICT_CASES[case.name].append(case)
115
+ return
116
+
117
+ _EXAMPLE_CASES[case.name] = case
118
+
119
+
120
+ def to_snake_case(name):
121
+ name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
122
+ return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
123
+
124
+
125
+ def _make_export_case(m, name, configs):
126
+ if not isinstance(m, torch.nn.Module):
127
+ raise TypeError("Export case class should be a torch.nn.Module.")
128
+
129
+ if "description" not in configs:
130
+ # Fallback to docstring if description is missing.
131
+ assert (
132
+ m.__doc__ is not None
133
+ ), f"Could not find description or docstring for export case: {m}"
134
+ configs = {**configs, "description": m.__doc__}
135
+ # pyrefly: ignore [bad-argument-type]
136
+ return ExportCase(**{**configs, "model": m, "name": name})
137
+
138
+
139
+ def export_case(**kwargs):
140
+ """
141
+ Decorator for registering a user provided case into example bank.
142
+ """
143
+
144
+ def wrapper(m):
145
+ configs = kwargs
146
+ module = inspect.getmodule(m)
147
+ if module in _MODULES:
148
+ raise RuntimeError("export_case should only be used once per example file.")
149
+
150
+ assert module is not None
151
+ _MODULES.add(module)
152
+ module_name = module.__name__.split(".")[-1]
153
+ case = _make_export_case(m, module_name, configs)
154
+ register_db_case(case)
155
+ return case
156
+
157
+ return wrapper
158
+
159
+
160
+ def export_rewrite_case(**kwargs):
161
+ def wrapper(m):
162
+ configs = kwargs
163
+
164
+ parent = configs.pop("parent")
165
+ assert isinstance(parent, ExportCase)
166
+ key = parent.name
167
+ if key not in _EXAMPLE_REWRITE_CASES:
168
+ _EXAMPLE_REWRITE_CASES[key] = []
169
+
170
+ configs["example_args"] = parent.example_args
171
+ case = _make_export_case(m, to_snake_case(m.__name__), configs)
172
+ _EXAMPLE_REWRITE_CASES[key].append(case)
173
+ return case
174
+
175
+ return wrapper
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/gen_example.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import torch._export.db.examples as examples
5
+
6
+ TEMPLATE = '''import torch
7
+
8
+ def {case_name}(x):
9
+ """
10
+ """
11
+
12
+ return
13
+ '''
14
+
15
+ if __name__ == "__main__":
16
+ assert len(sys.argv) == 2
17
+ root_dir = examples.__name__.replace(".", "/")
18
+ assert os.path.exists(root_dir)
19
+ with open(os.path.join(root_dir, sys.argv[1] + ".py"), "w") as f:
20
+ print("Writing to", f.name, "...")
21
+ f.write(TEMPLATE.format(case_name=sys.argv[1]))
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/logging.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ def exportdb_error_message(case_name: str) -> str:
4
+ from .examples import all_examples
5
+ from torch._utils_internal import log_export_usage
6
+
7
+ ALL_EXAMPLES = all_examples()
8
+ # Detect whether case_name is really registered in exportdb.
9
+ if case_name in ALL_EXAMPLES:
10
+ url_case_name = case_name.replace("_", "-")
11
+ return f"See {case_name} in exportdb for unsupported case. \
12
+ https://pytorch.org/docs/main/generated/exportdb/index.html#{url_case_name}"
13
+ else:
14
+ log_export_usage(
15
+ event="export.error.casenotregistered",
16
+ message=case_name,
17
+ )
18
+ return f"{case_name} is unsupported."
19
+
20
+
21
+ def get_class_if_classified_error(e: Exception) -> Optional[str]:
22
+ """
23
+ Returns a string case name if the export error e is classified.
24
+ Returns None otherwise.
25
+ """
26
+
27
+ from torch._dynamo.exc import TorchRuntimeError, Unsupported, UserError
28
+
29
+ ALWAYS_CLASSIFIED = "always_classified"
30
+ DEFAULT_CLASS_SIGIL = "case_name"
31
+
32
+ # add error types that should be classified, along with any attribute name
33
+ # whose presence acts like a sigil to further distinguish which errors of
34
+ # that type should be classified. If the attribute name is None, then the
35
+ # error type is always classified.
36
+ _ALLOW_LIST = {
37
+ Unsupported: DEFAULT_CLASS_SIGIL,
38
+ UserError: DEFAULT_CLASS_SIGIL,
39
+ TorchRuntimeError: None,
40
+ }
41
+ if type(e) in _ALLOW_LIST:
42
+ # pyrefly: ignore [index-error]
43
+ attr_name = _ALLOW_LIST[type(e)]
44
+ if attr_name is None:
45
+ return ALWAYS_CLASSIFIED
46
+ return getattr(e, attr_name, None)
47
+ return None
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__init__.py ADDED
File without changes
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/node_metadata.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+
4
+ NodeMetadataValue = Any
5
+
6
+
7
+ PROTECTED_KEYS: set[str] = {
8
+ "val",
9
+ "stack_trace",
10
+ "nn_module_stack",
11
+ "debug_handle",
12
+ "tensor_meta",
13
+ }
14
+
15
+
16
+ class NodeMetadata:
17
+ def __init__(self, data: dict[str, Any]) -> None:
18
+ self.data: dict[str, Any] = data.copy()
19
+
20
+ def __getitem__(self, key: str) -> NodeMetadataValue:
21
+ return self.data[key]
22
+
23
+ def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue:
24
+ if key in PROTECTED_KEYS:
25
+ raise RuntimeError(f"Could not override node key: {key}")
26
+ self.data[key] = value
27
+
28
+ def __contains__(self, key: str) -> bool:
29
+ return key in self.data
30
+
31
+ def copy(self) -> "NodeMetadata":
32
+ return NodeMetadata(self.data.copy())
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/proxy_value.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pyre-strict
2
+ from collections.abc import Iterable, Iterator
3
+ from typing import Generic, TypeVar, Union
4
+
5
+ import torch
6
+
7
+
8
+ _T = TypeVar("_T")
9
+
10
+
11
+ class ProxyValue(Generic[_T]):
12
+ # pyre-ignore
13
+ def __init__(self, data: Iterable[_T], proxy: Union[torch.fx.Proxy, torch.fx.Node]):
14
+ # pyre-ignore
15
+ self.data = data
16
+ self.proxy_or_node = proxy
17
+
18
+ @property
19
+ def node(self) -> torch.fx.Node:
20
+ if isinstance(self.proxy_or_node, torch.fx.Node):
21
+ return self.proxy_or_node
22
+ assert isinstance(self.proxy_or_node, torch.fx.Proxy)
23
+ return self.proxy_or_node.node
24
+
25
+ @property
26
+ def proxy(self) -> torch.fx.Proxy:
27
+ if not isinstance(self.proxy_or_node, torch.fx.Proxy):
28
+ raise RuntimeError(
29
+ f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}"
30
+ )
31
+ return self.proxy_or_node
32
+
33
+ def to_tensor(self) -> torch.Tensor:
34
+ assert isinstance(self.data, torch.Tensor)
35
+ return self.data
36
+
37
+ def is_tensor(self) -> bool:
38
+ return isinstance(self.data, torch.Tensor)
39
+
40
+ # pyre-ignore
41
+ def __iter__(self) -> Iterator[_T]:
42
+ yield from self.data
43
+
44
+ def __bool__(self) -> bool:
45
+ return bool(self.data)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (327 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-312.pyc ADDED
Binary file (11.7 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-312.pyc ADDED
Binary file (6.99 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-312.pyc ADDED
Binary file (14.9 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-312.pyc ADDED
Binary file (4.92 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-312.pyc ADDED
Binary file (5.09 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-312.pyc ADDED
Binary file (2.27 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-312.pyc ADDED
Binary file (8.49 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-312.pyc ADDED
Binary file (6.05 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-312.pyc ADDED
Binary file (3.7 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-312.pyc ADDED
Binary file (8.68 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/_node_metadata_hook.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import contextlib
3
+ from typing import Any, Optional
4
+
5
+ import torch
6
+ import torch.utils._pytree as pytree
7
+ from torch._dispatch.python import enable_python_dispatcher
8
+ from torch._subclasses.fake_tensor import FakeTensorMode
9
+ from torch.fx.graph_module import GraphModule
10
+
11
+
12
+ _EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook"
13
+
14
+
15
+ def _node_metadata_hook(
16
+ node: torch.fx.Node,
17
+ metadata: Optional[dict[str, Any]] = None,
18
+ fake_mode: Optional[FakeTensorMode] = None,
19
+ ) -> None:
20
+ """
21
+ Hook for adding the appropriate metadata to nodes that are created during a
22
+ pass using graph.create_node. An example of how to use it:
23
+
24
+ ```
25
+ with _set_node_metadata_hook(gm,
26
+ functools.partial(_node_metadata_hook, metadata={"stack_trace": "file"})
27
+ ):
28
+ pass(gm)
29
+ ```
30
+
31
+ This hook should not work for all generic cases -- specifically it assumes
32
+ that nodes being added are only call_function nodes, and copies over the
33
+ first argument node's nn_module_stack.
34
+ """
35
+ # pyrefly: ignore [bad-assignment]
36
+ fake_mode = fake_mode or contextlib.nullcontext()
37
+
38
+ assert node.op == "call_function" and callable(node.target), (
39
+ f"node: {node}, target: {node.target}"
40
+ )
41
+
42
+ if (
43
+ isinstance(node.target, torch._ops.OpOverload)
44
+ and len(node.target._schema.returns) == 0
45
+ ):
46
+ node.meta["val"] = None
47
+ else:
48
+ fake_args, fake_kwargs = pytree.tree_map_only(
49
+ torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs)
50
+ )
51
+ # pyrefly: ignore [bad-context-manager]
52
+ with fake_mode, enable_python_dispatcher():
53
+ fake_res = node.target(*fake_args, **fake_kwargs)
54
+ node.meta["val"] = fake_res
55
+
56
+ if metadata is not None:
57
+ for k, v in metadata.items():
58
+ node.meta[k] = v
59
+
60
+ # Copy over metadata from argument nodes
61
+ arg_meta = [
62
+ arg.meta
63
+ for arg in pytree.tree_flatten((node.args, node.kwargs))[0]
64
+ if isinstance(arg, torch.fx.Node)
65
+ ]
66
+ if len(arg_meta) == 0:
67
+ return
68
+ arg_meta = arg_meta[0]
69
+
70
+ node.meta["nn_module_stack"] = node.meta.get(
71
+ "nn_module_stack",
72
+ arg_meta.get(
73
+ "nn_module_stack",
74
+ {
75
+ _EMPTY_NN_MODULE_STACK_KEY: (
76
+ _EMPTY_NN_MODULE_STACK_KEY,
77
+ _EMPTY_NN_MODULE_STACK_KEY,
78
+ )
79
+ },
80
+ ),
81
+ )
82
+
83
+ node.meta["torch_fn"] = node.meta.get(
84
+ "torch_fn",
85
+ (
86
+ f"{node.target.__name__}_0",
87
+ # pyrefly: ignore [missing-attribute]
88
+ f"{node.target.__class__.__name__}.{node.target.__name__}",
89
+ ),
90
+ )
91
+
92
+
93
+ @contextlib.contextmanager
94
+ def _set_node_metadata_hook(gm: torch.fx.GraphModule, f):
95
+ """
96
+ Takes a callable which will be called after we create a new node. The
97
+ callable takes the newly created node as input and returns None.
98
+ """
99
+ assert callable(f), "node_metadata_hook must be a callable."
100
+
101
+ # Add the hook to all submodules
102
+ for m in gm.modules():
103
+ if isinstance(m, GraphModule):
104
+ m._register_create_node_hook(f)
105
+ try:
106
+ yield
107
+ finally:
108
+ # Restore hook for all submodules
109
+ for m in gm.modules():
110
+ if isinstance(m, GraphModule):
111
+ m._unregister_create_node_hook(f)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import math
3
+ import operator
4
+ import traceback
5
+ from functools import partial
6
+ from typing import NamedTuple, TYPE_CHECKING
7
+
8
+ import sympy
9
+
10
+ import torch
11
+ import torch.fx
12
+ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
13
+ from torch.fx.passes.infra.pass_base import PassBase, PassResult
14
+ from torch.utils._sympy.numbers import int_oo
15
+ from torch.utils._sympy.value_ranges import ValueRanges
16
+
17
+
18
+ if TYPE_CHECKING:
19
+ from collections.abc import Callable
20
+
21
+
22
+ __all__ = ["InputDim"]
23
+
24
+
25
+ class InputDim(NamedTuple):
26
+ input_name: str
27
+ dim: int
28
+
29
+
30
+ def _convert_to_int(val):
31
+ # Convert simple sympy Integers into concrete int
32
+ if val in (sympy.oo, int_oo):
33
+ return math.inf
34
+ if val in (-sympy.oo, -int_oo):
35
+ return -math.inf
36
+ if isinstance(val, sympy.Integer):
37
+ return int(val)
38
+ raise RuntimeError("Export constraints cannot be non-integer expressions")
39
+
40
+
41
+ def _convert_range_to_int(range: ValueRanges):
42
+ assert isinstance(range, ValueRanges)
43
+ min_val = _convert_to_int(range.lower)
44
+ max_val = _convert_to_int(range.upper)
45
+ return min_val, max_val
46
+
47
+
48
+ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
49
+ def __init__(
50
+ self,
51
+ range_constraints: dict[sympy.Symbol, ValueRanges],
52
+ ):
53
+ super().__init__()
54
+ self.range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints
55
+ self._asserts_generated_unbacked_symbols: set[sympy.Symbol] = set()
56
+ self.counter = 0
57
+
58
+ def _assert_range_constraint(self, node, lower, upper, assert_msg):
59
+ last_node = node
60
+ if lower > -math.inf:
61
+ last_node = self._insert_assert_async(
62
+ last_node, operator.ge, node, lower, assert_msg
63
+ )
64
+
65
+ if upper < math.inf:
66
+ last_node = self._insert_assert_async(
67
+ last_node, operator.le, node, upper, assert_msg
68
+ )
69
+
70
+ def _insert_assert_async(self, last_node, op, lower, upper, assert_msg):
71
+ """
72
+ Inserts assert_async call_function nodes in the graph. This function is
73
+ called **during** the interpreter-based pass.
74
+ """
75
+ self.counter += 1
76
+ graph = last_node.graph
77
+ with graph.inserting_after(last_node):
78
+ cmp = graph.call_function(op, (lower, upper), {})
79
+ with graph.inserting_after(cmp):
80
+ cmp_tensor = graph.call_function(
81
+ torch.ops.aten.scalar_tensor.default, (cmp,), {}
82
+ )
83
+ with graph.inserting_after(cmp_tensor):
84
+ assert_async = graph.call_function(
85
+ torch.ops.aten._assert_async.msg,
86
+ (cmp_tensor, assert_msg),
87
+ {},
88
+ )
89
+ return assert_async
90
+
91
+ def call(self, graph_module) -> PassResult:
92
+ self.existing_inline_assertions = _get_existing_inline_assertions(
93
+ graph_module, self.range_constraints
94
+ )
95
+
96
+ for module in graph_module.modules():
97
+ if not isinstance(module, torch.fx.GraphModule):
98
+ continue
99
+ for node in module.graph.nodes:
100
+ if node.op != "call_function":
101
+ continue
102
+ if "val" not in node.meta:
103
+ continue
104
+
105
+ val = node.meta["val"]
106
+ # In general, we may have to deal the case such as: ret[1].shape[0].
107
+ # We need first find out what symbols require assertion, then we need to follow the path
108
+ # from ret to the symbol, construct the proxies along the way and construct the messages
109
+ # piece-wise at the same time.
110
+ #
111
+ # We use post-order traversal to collect all the proxies callbacks needed, construct
112
+ # the error message callbacks, and at the top-level traversal tree we execute all the callbacks.
113
+ # We need the callbacks because, in order to call the function to create a proxy for shape[0], we
114
+ # need the proxy for shape, which further requires the proxy for ret[1], etc.
115
+
116
+ def add_assertions(val):
117
+ call_backs: list[Callable] = []
118
+ messages: list[str] = []
119
+ if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
120
+ symbol = val.node.expr
121
+ if symbol in self.existing_inline_assertions:
122
+ return call_backs, messages
123
+ if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(
124
+ symbol
125
+ ):
126
+ if symbol in self._asserts_generated_unbacked_symbols:
127
+ return call_backs, messages
128
+ # We only care about unbacked symints for these inline
129
+ # constraints, which are prefixed with 'u'
130
+ constraint = self.range_constraints[symbol]
131
+ min_val, max_val = _convert_range_to_int(constraint)
132
+ assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
133
+ call_backs.append(
134
+ partial(
135
+ self._assert_range_constraint,
136
+ lower=min_val,
137
+ upper=max_val,
138
+ )
139
+ )
140
+ messages.append(assert_msg)
141
+ self._asserts_generated_unbacked_symbols.add(symbol)
142
+
143
+ elif isinstance(val, torch.Tensor):
144
+ for i, sym in enumerate(val.shape):
145
+ cbs, msgs = add_assertions(sym)
146
+ for cb, msg in zip(cbs, msgs):
147
+
148
+ def sym_size_cb(node, assert_msg, dim):
149
+ with node.graph.inserting_after(node):
150
+ dim_node = module.graph.call_function(
151
+ torch.ops.aten.sym_size.int,
152
+ (node, dim),
153
+ {},
154
+ )
155
+ cb(node=dim_node, assert_msg=assert_msg)
156
+
157
+ call_backs.append(partial(sym_size_cb, dim=i))
158
+ messages.append(f".shape[{i}]" + msg)
159
+ return call_backs, messages
160
+
161
+ callbacks, messages = add_assertions(val)
162
+ for cb, msg in zip(callbacks, messages):
163
+ cb(node=node, assert_msg=f"{node}" + msg)
164
+
165
+ module.recompile()
166
+
167
+ # Sometimes this pass would return a wrong graph where we have mismatched
168
+ # node names in signature. Before we fix it, let's just skip it.
169
+ if (
170
+ self.counter == 0
171
+ and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass
172
+ ):
173
+ return PassResult(graph_module, False)
174
+
175
+ # Populate the stack trace with dummy vals to respect IR
176
+ for node in graph_module.graph.nodes:
177
+ if not node.meta.get("stack_trace", None) and node.op not in [
178
+ "placeholder",
179
+ "output",
180
+ ]:
181
+ node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))
182
+ return PassResult(graph_module, True)
183
+
184
+
185
+ def _get_existing_inline_assertions(
186
+ graph_module: torch.fx.GraphModule,
187
+ range_constraints: dict[sympy.Symbol, ValueRanges],
188
+ ) -> dict[sympy.Symbol, ValueRanges]:
189
+ existing_inline_assertions: dict[sympy.Symbol, ValueRanges] = {}
190
+
191
+ for module in graph_module.modules():
192
+ if not isinstance(module, torch.fx.GraphModule):
193
+ continue
194
+
195
+ # Find all the existing inline assertions. They will look something like:
196
+ # %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {})
197
+ # %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {})
198
+ # %_assert_scalar = call_function[target=torch.ops.aten._assert_scalar.default](args = (%scalar_tensor, "..."), kwargs = {})
199
+ for node in module.graph.nodes:
200
+ if node.target != torch.ops.aten._assert_scalar.default:
201
+ continue
202
+
203
+ compare_arg = node.args[0]
204
+ if not (
205
+ isinstance(compare_arg, torch.fx.Node)
206
+ and compare_arg.op == "call_function"
207
+ and compare_arg.target in (operator.le, operator.ge)
208
+ and len(compare_arg.args) == 2
209
+ ):
210
+ continue
211
+
212
+ compare_op = compare_arg.target
213
+ lhs, rhs = compare_arg.args
214
+
215
+ def maybe_get_symint(x):
216
+ if (
217
+ isinstance(x, torch.fx.Node)
218
+ and "val" in x.meta
219
+ and isinstance(x.meta["val"], torch.SymInt)
220
+ ):
221
+ return x.meta["val"].node.expr
222
+ return x
223
+
224
+ lhs = maybe_get_symint(lhs)
225
+ rhs = maybe_get_symint(rhs)
226
+
227
+ if compare_op is operator.ge:
228
+ lhs, rhs = rhs, lhs
229
+
230
+ if isinstance(lhs, sympy.Symbol) and isinstance(rhs, int):
231
+ symint = lhs
232
+ scalar = rhs
233
+ elif isinstance(rhs, sympy.Symbol) and isinstance(lhs, int):
234
+ symint = rhs
235
+ scalar = lhs
236
+ else:
237
+ continue
238
+
239
+ if symint not in range_constraints:
240
+ raise RuntimeError(
241
+ f"Unable to find symint {symint} in {range_constraints}"
242
+ )
243
+
244
+ previous_range = existing_inline_assertions.get(
245
+ symint, ValueRanges(-math.inf, math.inf)
246
+ )
247
+
248
+ if symint is lhs:
249
+ bounds = ValueRanges(-math.inf, scalar)
250
+ else:
251
+ bounds = ValueRanges(scalar, math.inf)
252
+ existing_inline_assertions[symint] = previous_range & bounds
253
+
254
+ return existing_inline_assertions
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/collect_tracepoints_pass.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ import operator
5
+ from typing import TYPE_CHECKING
6
+
7
+ import torch
8
+ from torch.export.exported_program import ConstantArgument, TensorArgument
9
+ from torch.fx.passes.infra.pass_base import PassBase, PassResult
10
+
11
+
12
+ if TYPE_CHECKING:
13
+ from torch.export.exported_program import ModuleCallSignature
14
+ from torch.export.graph_signature import ExportGraphSignature
15
+
16
+
17
+ __all__ = ["CollectTracepointsPass"]
18
+
19
+
20
+ class CollectTracepointsPass(PassBase):
21
+ """
22
+ Performs constant folding and constant propagation.
23
+ """
24
+
25
+ def __init__(
26
+ self, specs: dict[str, ModuleCallSignature], sig: ExportGraphSignature
27
+ ) -> None:
28
+ super().__init__()
29
+ self.specs = specs
30
+ self.sig = sig
31
+
32
+ def call(self, gm: torch.fx.GraphModule) -> PassResult | None:
33
+ def get_arg_spec(arg) -> TensorArgument | ConstantArgument:
34
+ if isinstance(arg, torch.fx.Node):
35
+ if isinstance(arg.meta.get("val"), torch.Tensor):
36
+ return TensorArgument(name=arg.name)
37
+ else:
38
+ raise AssertionError(
39
+ "Symint input is not implemented yet for submodule call signature."
40
+ )
41
+ else:
42
+ return ConstantArgument(name="", value=arg)
43
+
44
+ for module in gm.modules():
45
+ if not isinstance(module, torch.fx.GraphModule):
46
+ continue
47
+ nn_module_stack = None
48
+ for node in module.graph.nodes:
49
+ if node.op != "call_function":
50
+ continue
51
+ if node.target is torch.ops.higher_order._export_tracepoint:
52
+ kind = node.kwargs["kind"]
53
+ if kind == "module_call_outputs":
54
+ nn_module_stack = node.meta["nn_module_stack"]
55
+ elif kind == "module_call_inputs":
56
+ nn_module_stack = None
57
+ else:
58
+ raise AssertionError(f"Unknown tracepoint kind: {kind}")
59
+ elif node.meta["nn_module_stack"] == nn_module_stack:
60
+ node.meta["nn_module_stack"].popitem()
61
+ else:
62
+ nn_module_stack = None
63
+ nn_module_stack = None
64
+ for node in reversed(module.graph.nodes):
65
+ if node.op != "call_function":
66
+ continue
67
+ if node.target is torch.ops.higher_order._export_tracepoint:
68
+ kind = node.kwargs["kind"]
69
+ if kind == "module_call_inputs":
70
+ nn_module_stack = node.meta["nn_module_stack"]
71
+ elif kind == "module_call_outputs":
72
+ nn_module_stack = None
73
+ else:
74
+ raise AssertionError(f"Unknown tracepoint kind: {kind}")
75
+ elif node.meta["nn_module_stack"] == nn_module_stack:
76
+ node.meta["nn_module_stack"].popitem()
77
+ else:
78
+ nn_module_stack = None
79
+
80
+ def copy_sig(sig) -> ModuleCallSignature:
81
+ from torch.export.exported_program import ModuleCallSignature
82
+
83
+ return ModuleCallSignature(
84
+ inputs=[],
85
+ outputs=[],
86
+ in_spec=sig.in_spec,
87
+ out_spec=sig.out_spec,
88
+ forward_arg_names=None,
89
+ )
90
+
91
+ for module in gm.modules():
92
+ if not isinstance(module, torch.fx.GraphModule):
93
+ continue
94
+ for node in module.graph.nodes:
95
+ if node.op != "call_function":
96
+ continue
97
+ if node.target is torch.ops.higher_order._export_tracepoint:
98
+ # There's some subtlety worth noting. Here fqn corresponds to
99
+ # the call name, whereas path corresponds to the module name.
100
+ # They are not necessarily the same! When a submodule is shared
101
+ # through different aliases, there are as many _export_tracepoint
102
+ # markers as there are aliases, since the shared submodule is
103
+ # wrapped once for each alias.
104
+ path = node.kwargs["path"]
105
+ fqn, _ = next(reversed(node.meta["nn_module_stack"].values()))
106
+
107
+ module_key = next(reversed(node.meta["nn_module_stack"]))
108
+ if "@" in module_key:
109
+ suffix = module_key.split("@")[-1]
110
+ path = f"{path}@{suffix}"
111
+
112
+ call_fqn = f"{fqn}@{suffix}"
113
+ if call_fqn not in self.specs:
114
+ self.specs[call_fqn] = copy_sig(self.specs[fqn])
115
+ fqn = call_fqn
116
+
117
+ kind = node.kwargs["kind"]
118
+ for i, arg in enumerate(node.args):
119
+ # We only update the signature of the alias used to call
120
+ # the submodule. Otherwise the signatures of all aliases
121
+ # would get conflated; the inputs/outputs of every call
122
+ # would be recorded in every other call as well.
123
+ if fqn == path:
124
+ if kind == "module_call_inputs":
125
+ self.specs[path].inputs.append(get_arg_spec(arg))
126
+ elif kind == "module_call_outputs":
127
+ self.specs[path].outputs.append(get_arg_spec(arg))
128
+ else:
129
+ raise AssertionError(f"Unknown tracepoint kind: {kind}")
130
+ if isinstance(arg, torch.fx.Node):
131
+ for user in node.users:
132
+ assert user.op == "call_function"
133
+ assert user.target is operator.getitem
134
+ assert isinstance(user.args[1], int)
135
+ if user.args[1] == i:
136
+ user.replace_all_uses_with(arg)
137
+ self.sig.replace_all_uses(user.name, arg.name)
138
+ break
139
+ users = list(node.users)
140
+ for user in users:
141
+ assert len(user.users) == 0
142
+ gm.graph.erase_node(user)
143
+ gm.graph.erase_node(node)
144
+ return PassResult(gm, True)
145
+
146
+ return None
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/constant_folding.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import collections
3
+ from collections import defaultdict
4
+ from collections.abc import Callable
5
+ from typing import Any, Optional
6
+
7
+ import torch
8
+ import torch.utils._pytree as pytree
9
+
10
+
11
+ aten = torch.ops.aten
12
+
13
+ # We would like to split modules into two subgraphs for runtime weight updates to work correctly.
14
+ # The use case and more information could be found at:
15
+ # https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
16
+ META_TAG = "MODULE_TYPE"
17
+ MODULE_TAG = "_MAIN_MODULE"
18
+ CONST_MODULE_TAG = "_CONST_MODULE"
19
+
20
+
21
+ def replace_node_with_constant(gm, node, constant, name=None):
22
+ g = gm.graph
23
+
24
+ if name:
25
+ qualname = name
26
+ else:
27
+ if not hasattr(gm, "_frozen_param_count"):
28
+ gm._frozen_param_count = 0
29
+ i = gm._frozen_param_count
30
+
31
+ while True:
32
+ qualname = f"_frozen_param{i}"
33
+ if not hasattr(gm, qualname):
34
+ break
35
+ i += 1
36
+
37
+ gm._frozen_param_count = i + 1
38
+
39
+ with g.inserting_before(node):
40
+ new_input_node = g.create_node("get_attr", qualname, (), {})
41
+ node.replace_all_uses_with(new_input_node)
42
+ new_input_node.meta.update(node.meta)
43
+ g.erase_node(node)
44
+
45
+ # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
46
+ gm.register_buffer(qualname, constant)
47
+ setattr(gm, qualname, constant)
48
+
49
+
50
+ class ConstantFolder(torch.fx.Interpreter):
51
+ def __init__(
52
+ self,
53
+ gm: torch.fx.GraphModule,
54
+ skip_constructors: bool = False,
55
+ ):
56
+ super().__init__(gm)
57
+ self.node_replacements: dict[torch.fx.Node, Any] = {}
58
+ self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter()
59
+ self.unknown_value = object()
60
+ self.skip_constructors: bool = skip_constructors
61
+
62
+ # overwrite this to deallocate env values if their only remaining use
63
+ # is the output
64
+ self.user_to_last_uses = self.node_to_last_non_output_use()
65
+
66
+ def is_impure(self, node: torch.fx.Node) -> bool:
67
+ if (
68
+ node.target is torch.ops.prims.convert_element_type.default
69
+ and node.args[0].op == "get_attr" # type: ignore[union-attr]
70
+ and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr]
71
+ and node.args[1] == torch.bfloat16
72
+ ):
73
+ # For int8_weight -> dq -> bf16_weight
74
+ return True
75
+ if node.target in [
76
+ torch.ops.quantized_decomposed.dequantize_per_channel.default,
77
+ torch.ops.quantized_decomposed.dequantize_per_tensor.default,
78
+ torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
79
+ torch.ops.pt2e_quant.dequantize_affine,
80
+ ]:
81
+ # For the pattern fp32_weight -> q -> dq
82
+ # We only folding fp32_weight -> q
83
+ # int8_weight and leave dq in graph to be fused
84
+ return True
85
+ return False
86
+
87
+ def node_to_last_non_output_use(self):
88
+ last_non_output_use = collections.defaultdict(list)
89
+ seen_uses = set()
90
+ output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr]
91
+
92
+ for node in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr]
93
+ if node.target == "output":
94
+ continue
95
+
96
+ def add_use(inp):
97
+ if inp in seen_uses:
98
+ return
99
+
100
+ seen_uses.add(inp)
101
+ last_non_output_use[node].append(inp)
102
+
103
+ # In-place is fine since we don't mutate
104
+ pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))
105
+
106
+ # if this node is only used in output, we want to gc it right away
107
+ if len(node.users) == 1 and output_node in node.users:
108
+ last_non_output_use[node].append(node)
109
+
110
+ return last_non_output_use
111
+
112
+ def run_node(self, node):
113
+ if node.target == "output":
114
+ # because we remove nodes from env on last non output use,
115
+ # re-define them now or we'll get error in interpreter
116
+ def set_env(arg):
117
+ self.env[arg] = self.unknown_value
118
+
119
+ # In-place is fine since we don't mutate
120
+ pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
121
+ return super().run_node(node)
122
+
123
+ args, kwargs = self.fetch_args_kwargs_from_env(node)
124
+ flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
125
+
126
+ # We need to do this weird thing because in cases where flattened_inputs
127
+ # contains a ScriptObject, equality checking results in a type error if
128
+ # the types are different.
129
+ if any(
130
+ type(self.unknown_value) is type(input_) and self.unknown_value == input_
131
+ for input_ in flattened_inputs
132
+ ):
133
+ return self.unknown_value
134
+
135
+ # TODO - fix errors with this
136
+ if (
137
+ node.op == "call_function"
138
+ and node.target is aten._efficientzerotensor.default
139
+ ):
140
+ return self.unknown_value
141
+
142
+ # TODO - constant folding triton kernel returns the inputs -- fix this
143
+ if (
144
+ node.op == "call_function"
145
+ and node.name == "triton_kernel_wrapper_functional_proxy"
146
+ ):
147
+ return self.unknown_value
148
+
149
+ # skip constructors, since inductor generates optimal code for them already
150
+ # and turning into tensor would result in an additional global memory read
151
+ # TODO - more complicated strategy
152
+ if (
153
+ self.skip_constructors
154
+ and node.op != "get_attr"
155
+ and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
156
+ ):
157
+ return self.unknown_value
158
+
159
+ # All mutations should either be removed or on inputs which we did not make constant
160
+ if (
161
+ isinstance(node.target, torch._ops.OpOverload)
162
+ and torch.Tag.nondeterministic_seeded in node.target.tags
163
+ ):
164
+ return self.unknown_value
165
+
166
+ out = super().run_node(node)
167
+
168
+ if node.op != "get_attr" and isinstance(out, torch.Tensor):
169
+ if out.device.type == "meta":
170
+ return out
171
+
172
+ if not self.insertable_tensor_check(out):
173
+ return out
174
+
175
+ if self.is_impure(node):
176
+ return self.unknown_value
177
+
178
+ self.add_node_replacement(node, out)
179
+
180
+ flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
181
+
182
+ for n in flattened_node_inps:
183
+ if not isinstance(n, torch.fx.Node):
184
+ continue
185
+
186
+ self.replaced_uses[n] += 1
187
+
188
+ for to_delete in self.user_to_last_uses.get(node, []):
189
+ if self.replaced_uses[to_delete] == len(to_delete.users):
190
+ self.node_replacements.pop(to_delete, None)
191
+
192
+ return out
193
+
194
+ def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
195
+ return True
196
+
197
+ def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
198
+ self.node_replacements[node] = tensor
199
+
200
+ def run(self): # type: ignore[override]
201
+ env = {}
202
+ for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
203
+ env[n] = self.unknown_value
204
+ return super().run(initial_env=env)
205
+
206
+
207
+ def constant_fold(
208
+ gm: torch.fx.GraphModule,
209
+ constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
210
+ ):
211
+ with torch.utils._python_dispatch._disable_current_modes():
212
+ cf = ConstantFolder(gm, skip_constructors=True)
213
+ cf.run()
214
+
215
+ for node, constant in cf.node_replacements.items():
216
+ if constraint_fn is not None and not constraint_fn(node):
217
+ continue
218
+ replace_node_with_constant(gm, node, constant)
219
+
220
+ erased_params = []
221
+ # Get all attr users by looking up the graph instead from node.users, because in this case
222
+ # _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor.
223
+
224
+ # opcode name target args kwargs
225
+ # ------------- ------------------- ---------------- --------------------------- --------
226
+ # placeholder arg0_1 arg0 () {}
227
+ # get_attr _tensor_constant0 state () {}
228
+ # call_function add aten.add.Tensor (arg0_1, _tensor_constant0) {}
229
+ # get_attr _tensor_constant0_1 state () {}
230
+ # call_function add_ aten.add_.Tensor (_tensor_constant0_1, 1) {}
231
+ # output output output ([add],) {}
232
+
233
+ get_attr_node_users = defaultdict(list)
234
+ for node in gm.graph.nodes:
235
+ if node.op == "get_attr":
236
+ get_attr_node_users[node.target].extend(node.users.keys())
237
+ for node in gm.graph.find_nodes(op="get_attr"):
238
+ if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0:
239
+ if hasattr(gm, node.target):
240
+ delattr(gm, node.target)
241
+ erased_params.append(node)
242
+ for node in erased_params:
243
+ gm.graph.erase_node(node)
244
+
245
+ gm.graph.eliminate_dead_code()
246
+ gm.graph.lint()
247
+ gm.recompile()
248
+
249
+
250
+ def constant_graph_tag(gm: torch.fx.GraphModule) -> None:
251
+ with torch.utils._python_dispatch._disable_current_modes():
252
+ cf = ConstantFolder(gm, skip_constructors=True)
253
+ cf.run()
254
+
255
+ for node in gm.graph.nodes:
256
+ if (
257
+ node.op == "get_attr"
258
+ or node in cf.node_replacements
259
+ or node in cf.replaced_uses
260
+ ):
261
+ node.meta[META_TAG] = CONST_MODULE_TAG
262
+ else:
263
+ node.meta[META_TAG] = MODULE_TAG
264
+
265
+
266
+ def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
267
+ """
268
+ Construct a GraphModule which corresponds to the part which could be
269
+ constant folded in provided gm.
270
+ """
271
+
272
+ constant_graph_tag(gm)
273
+ # We rewrite the tags, if it's a constant being directly consumed, without
274
+ # any folding opportunity, we keep it in main gm.
275
+ for node in gm.graph.find_nodes(op="get_attr"):
276
+ used_to_fold = False
277
+ for u in node.users:
278
+ if u.meta[META_TAG] == CONST_MODULE_TAG:
279
+ used_to_fold = True
280
+ break
281
+ if not used_to_fold:
282
+ node.meta[META_TAG] = MODULE_TAG
283
+
284
+ new_graph = torch.fx.Graph()
285
+
286
+ node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
287
+ output_nodes = []
288
+ for node in gm.graph.nodes:
289
+ if node.meta[META_TAG] == MODULE_TAG:
290
+ continue
291
+
292
+ new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
293
+ node_remapping[node] = new_node
294
+
295
+ for user in node.users:
296
+ if user.meta[META_TAG] == MODULE_TAG:
297
+ output_nodes.append(new_node)
298
+ break
299
+
300
+ new_graph.output(tuple(output_nodes))
301
+ new_graph.lint()
302
+ new_gm = torch.fx.GraphModule(gm, new_graph)
303
+
304
+ return new_gm
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch._export.pass_base import (
6
+ _ExportPassBaseDeprecatedDoNotUse,
7
+ Argument,
8
+ PassResult,
9
+ )
10
+ from torch._export.pass_infra.node_metadata import NodeMetadata
11
+ from torch._export.pass_infra.proxy_value import ProxyValue
12
+ from torch._ops import OpOverload
13
+
14
+
15
+ aten = torch.ops.aten
16
+
17
+ _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {
18
+ aten.sym_constrain_range.default: aten._functional_sym_constrain_range.default,
19
+ aten._assert_async.msg: aten._functional_assert_async.msg,
20
+ }
21
+
22
+
23
+ class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
24
+ """
25
+ Functionalize ops with side effect in graph module by replacing the op with
26
+ functional version of it. A new dependency token (`dep_token`) will be
27
+ created and propagated through functional ops to output.
28
+ For example:
29
+ ```
30
+ def f(x):
31
+ sym_constrain_range(x.shape[0], min=1, max=3)
32
+ return x.add(3)
33
+ ```
34
+ Will be transformed to:
35
+ ```
36
+ def f(x):
37
+ dep_token0 = _make_dep_token()
38
+ dep_token1 = _functional_sym_constrain_range(
39
+ x.shape[0], min=1, max=3, dep_token=dep_token0
40
+ )
41
+
42
+ return x.add(3), dep_token1
43
+ ```
44
+ """
45
+
46
+ def __init__(self) -> None:
47
+ super().__init__()
48
+ self._dep_token: Optional[ProxyValue] = None
49
+ self._next_dep_token_index: Optional[int] = None
50
+
51
+ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
52
+ # Early return if no non-functional assertions.
53
+ if not any(
54
+ n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS
55
+ for n in graph_module.graph.nodes
56
+ ):
57
+ return PassResult(graph_module=graph_module, modified=False)
58
+
59
+ gm = copy.deepcopy(graph_module)
60
+ self._dep_token = None
61
+ self._next_dep_token_index = None
62
+ return super().call(gm)
63
+
64
+ def call_operator(
65
+ self,
66
+ op: OpOverload,
67
+ args: tuple[Argument, ...],
68
+ kwargs: dict[str, Argument],
69
+ meta: NodeMetadata,
70
+ ) -> ProxyValue:
71
+ if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
72
+ return super().call_operator(op, args, kwargs, meta)
73
+
74
+ if self._dep_token is None:
75
+ self._dep_token = super().call_operator(
76
+ aten._make_dep_token,
77
+ args=(),
78
+ kwargs={},
79
+ meta=self._create_dummy_node_metadata(),
80
+ )
81
+ self._dep_token.node.name = "dep_token0"
82
+ self._next_dep_token_index = 1
83
+
84
+ self._dep_token = super().call_operator(
85
+ _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op],
86
+ args=args,
87
+ kwargs={**kwargs, "dep_token": self._dep_token},
88
+ meta=meta,
89
+ )
90
+ assert self._next_dep_token_index is not None
91
+ self._dep_token.node.name = f"dep_token{self._next_dep_token_index}"
92
+ self._next_dep_token_index += 1
93
+
94
+ return self._dep_token
95
+
96
+ def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue:
97
+ assert self._dep_token is not None
98
+
99
+ return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type]
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/insert_custom_op_guards.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from collections import defaultdict
3
+
4
+ import torch
5
+ from torch._export.passes._node_metadata_hook import (
6
+ _node_metadata_hook,
7
+ _set_node_metadata_hook,
8
+ )
9
+ from torch._library.fake_profile import OpProfile, TensorMetadata
10
+
11
+
12
+ def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: set[str]) -> None:
13
+ """
14
+ This is used by draft_export to insert guards in front of calls to custom
15
+ operators which have a generated fake kernel.
16
+ """
17
+ for node in gm.graph.nodes:
18
+ if node.op == "call_function" and str(node.target) in ops_to_guard:
19
+ with (
20
+ _set_node_metadata_hook(
21
+ gm,
22
+ functools.partial(
23
+ _node_metadata_hook,
24
+ metadata={"stack_trace": node.meta.get("stack_trace")},
25
+ ),
26
+ ),
27
+ gm.graph.inserting_before(node),
28
+ ):
29
+ for arg in (*node.args, *node.kwargs.values()):
30
+ if isinstance(arg, torch.fx.Node) and isinstance(
31
+ arg.meta.get("val"), torch.Tensor
32
+ ):
33
+ val = arg.meta["val"]
34
+ gm.graph.call_function(
35
+ torch.ops.aten._assert_tensor_metadata.default,
36
+ args=(arg,),
37
+ kwargs={
38
+ "dtype": val.dtype,
39
+ "device": val.device,
40
+ "layout": val.layout,
41
+ },
42
+ )
43
+
44
+ gm.recompile()
45
+
46
+
47
+ def get_op_profiles(
48
+ gm: torch.fx.GraphModule, ops_to_guard: set[str]
49
+ ) -> dict[str, set[OpProfile]]:
50
+ """
51
+ This is used by draft_export to get a list of custom operator profiles so
52
+ that we can generate fake kernels.
53
+ """
54
+
55
+ def _get_op_profile(node: torch.fx.Node) -> OpProfile:
56
+ args_profile = tuple(
57
+ TensorMetadata.maybe_from_tensor(arg.meta.get("val"))
58
+ if isinstance(arg, torch.fx.Node)
59
+ else None
60
+ for arg in (*node.args, *node.kwargs.values())
61
+ )
62
+
63
+ out_profile = None
64
+ meta = node.meta.get("val")
65
+ assert meta is not None
66
+ if isinstance(meta, torch.Tensor):
67
+ out_profile = TensorMetadata.maybe_from_tensor(meta)
68
+ elif isinstance(meta, (list, tuple)):
69
+ out_profile = tuple(TensorMetadata.maybe_from_tensor(m) for m in meta) # type: ignore[assignment]
70
+ assert out_profile is not None
71
+
72
+ return OpProfile(args_profile, out_profile) # type: ignore[arg-type]
73
+
74
+ op_profiles: dict[str, set[OpProfile]] = defaultdict(set)
75
+
76
+ for node in gm.graph.nodes:
77
+ if node.op == "call_function" and str(node.target) in ops_to_guard:
78
+ op_profiles[str(node.target)].add(_get_op_profile(node))
79
+
80
+ return op_profiles
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/lift_constants_pass.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import collections
3
+ import logging
4
+ from typing import Any, Optional, Union
5
+
6
+ import torch
7
+ from torch._export.verifier import SpecViolationError
8
+ from torch._guards import detect_fake_mode
9
+ from torch._library.fake_class_registry import FakeScriptObject
10
+ from torch._library.opaque_object import is_opaque_reference_type
11
+ from torch._subclasses.fake_tensor import unset_fake_temporarily
12
+ from torch.export.exported_program import (
13
+ ArgumentSpec,
14
+ CustomObjArgument,
15
+ ExportGraphSignature,
16
+ InputKind,
17
+ InputSpec,
18
+ TensorArgument,
19
+ )
20
+ from torch.fx._symbolic_trace import _ConstantAttributeType
21
+ from torch.fx.graph_module import _get_attr
22
+
23
+
24
+ log = logging.getLogger(__name__)
25
+
26
+
27
+ class ConstantAttrMap(collections.abc.MutableMapping):
28
+ """A mapping class that understands how to use module constants (tensors,
29
+ ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally,
30
+ but ScriptObjects are stored by hash, because different torch.ScriptObjects can point to
31
+ the same underlying value (but we guarantee that they will `hash()` to the same value
32
+ if that's the case).
33
+ """
34
+
35
+ def __init__(self) -> None:
36
+ # Underlying dict that we use to implement this mapping.
37
+ self._constant_attrs: dict[
38
+ Union[int, torch.Tensor, FakeScriptObject, torch.utils._pytree.TreeSpec],
39
+ list[Any],
40
+ ] = {}
41
+ # Map from the hash(ScriptObject) to the ScriptObject itself. Used for
42
+ # APIs like `__iter__` that should look like they're returning the
43
+ # original ScriptObjects.
44
+ self._script_object_map: dict[int, torch.ScriptObject] = {}
45
+
46
+ def __getitem__(self, key: _ConstantAttributeType) -> Any:
47
+ real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
48
+ assert isinstance(real_key, (int, torch.Tensor, FakeScriptObject))
49
+ return self._constant_attrs[real_key]
50
+
51
+ def __setitem__(self, key: _ConstantAttributeType, value):
52
+ # we shouldn't actually call this, should go to add() instead to handle aliasing
53
+ raise NotImplementedError(
54
+ """Directly setting values for ConstantAttrMap is not supported, please use add(key, value) instead.
55
+ The same key can be mapped to multiple values, for handling constant aliasing."""
56
+ )
57
+
58
+ def add(self, key: _ConstantAttributeType, value: Any) -> None:
59
+ if isinstance(key, torch.ScriptObject):
60
+ if hash(key) not in self._constant_attrs:
61
+ self._constant_attrs[hash(key)] = []
62
+ self._constant_attrs[hash(key)].append(value)
63
+ self._script_object_map[hash(key)] = key
64
+ elif isinstance(key, (torch.Tensor, FakeScriptObject)):
65
+ if key not in self._constant_attrs:
66
+ self._constant_attrs[key] = []
67
+ self._constant_attrs[key].append(value)
68
+ else:
69
+ raise TypeError(
70
+ f"Expected key to be a tensor or ScriptObject, got {type(key)}"
71
+ )
72
+
73
+ def __delitem__(self, key: _ConstantAttributeType):
74
+ real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
75
+
76
+ del self._constant_attrs[real_key]
77
+
78
+ def __iter__(self):
79
+ for key in self._constant_attrs:
80
+ if isinstance(key, int):
81
+ yield self._script_object_map[key]
82
+ else:
83
+ yield key
84
+
85
+ def __len__(self):
86
+ return len(self._constant_attrs)
87
+
88
+ def __contains__(self, key: object) -> bool:
89
+ real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
90
+ return real_key in self._constant_attrs
91
+
92
+
93
+ def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str:
94
+ # The FQN of the constant tensor in the state dict should
95
+ # correspond to the module where the constant tensor was
96
+ # originally used.
97
+ if len(node.meta["nn_module_stack"]) == 0:
98
+ return constant_name
99
+ parent_fqn = list(node.meta["nn_module_stack"].values())[-1][0]
100
+ if len(parent_fqn) > 0:
101
+ return f"{parent_fqn}.{constant_name}"
102
+ else:
103
+ return constant_name
104
+
105
+
106
+ def _get_first_fqn(
107
+ const_attrs: ConstantAttrMap,
108
+ key: _ConstantAttributeType,
109
+ ) -> Any:
110
+ fqns = const_attrs.get(key)
111
+ return fqns[0] if fqns else None
112
+
113
+
114
+ def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]:
115
+ """
116
+ If there is a tensor constant created while tracing, here is how the graph
117
+ looks like:
118
+
119
+ %_tensor_constant0 : [num_users=1] = get_attr[target=_tensor_constant0]
120
+ %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant0,))
121
+ %detach_ : [num_users=?] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,))
122
+
123
+ To check to see if the tensor constant is being used, we want to traverse to
124
+ the detach node to see if it's actually being used.
125
+
126
+ This function returns None if this constant is being used, otherwise it returns the
127
+ lift_fresh and detach node to be removed later.
128
+ """ # noqa: B950
129
+ if len(node.users) > 1:
130
+ return None
131
+
132
+ lift_fresh_node = next(iter(node.users.keys()))
133
+ if not (
134
+ lift_fresh_node.op == "call_function"
135
+ and lift_fresh_node.target
136
+ in (
137
+ torch.ops.aten.lift_fresh.default,
138
+ torch.ops.aten.lift_fresh_copy.default,
139
+ )
140
+ ):
141
+ return None
142
+
143
+ if len(lift_fresh_node.users) > 1:
144
+ return None
145
+
146
+ # Case 1: lift node is not used anywhere
147
+ if len(lift_fresh_node.users) == 0:
148
+ return [lift_fresh_node, node]
149
+
150
+ detach_node = next(iter(lift_fresh_node.users.keys()))
151
+ if not (
152
+ detach_node.op == "call_function"
153
+ and detach_node.target
154
+ in (
155
+ torch.ops.aten.detach_.default,
156
+ torch.ops.aten.detach.default,
157
+ )
158
+ ):
159
+ return None
160
+
161
+ if len(detach_node.users) > 0:
162
+ return None
163
+ else:
164
+ # Case 2: Lift node's child is not used anywhere
165
+ return [detach_node, lift_fresh_node, node]
166
+
167
+
168
+ def lift_constants_pass(
169
+ gm: torch.fx.GraphModule,
170
+ graph_signature: ExportGraphSignature,
171
+ constant_attrs: ConstantAttrMap,
172
+ ) -> dict[str, _ConstantAttributeType]:
173
+ """
174
+ Takes a graph module, graph signature, and modifies them inplace to lift any
175
+ constants (tensors or custom classes) as inputs to the graph. Returns a
176
+ dictionary of names to constants.
177
+
178
+ Arguments:
179
+ gm (torch.fx.GraphModule): The graph module containing the graph and constants to lift.
180
+ graph_signature (ExportGraphSignature): This graph signature will be
181
+ mutated to add additional CONSTANT_TENSOR and CUSTOM_OBJ inputs.
182
+ constant_attrs (ConstantAttr): A mapping from a constant value to its
183
+ fully-qualified path in `gm`. This is used to maintain consistent
184
+ location of constants between the original module and the exported
185
+ version.
186
+
187
+ Returns:
188
+ A dictionary of fqn => constant value.
189
+ """
190
+ all_constants: dict[str, _ConstantAttributeType] = {}
191
+
192
+ input_specs = graph_signature.input_specs
193
+ num_custom_obj = sum(
194
+ input_spec.kind == InputKind.CUSTOM_OBJ for input_spec in input_specs
195
+ )
196
+ num_tensor_constants = sum(
197
+ input_spec.kind == InputKind.CONSTANT_TENSOR for input_spec in input_specs
198
+ )
199
+
200
+ fake_mode = detect_fake_mode(
201
+ tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
202
+ )
203
+
204
+ first_user_input_loc, first_user_input = 0, next(iter(gm.graph.nodes))
205
+ used_target_names = set()
206
+
207
+ input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
208
+ assert len(input_nodes) == len(input_specs)
209
+ for i, (node, input_spec) in enumerate(zip(input_nodes, input_specs)):
210
+ used_target_names.add(input_spec.target)
211
+ if input_spec.kind == InputKind.USER_INPUT:
212
+ first_user_input = node
213
+ first_user_input_loc = i
214
+ break
215
+
216
+ lifted_objs = ConstantAttrMap()
217
+ renamed_targets = {}
218
+ for node in list(gm.graph.nodes):
219
+ if node.op == "get_attr":
220
+ if nodes_to_remove := _unused_constant(node):
221
+ # Remove the node if it's not being used
222
+ for node_rm in nodes_to_remove:
223
+ gm.graph.erase_node(node_rm)
224
+ continue
225
+
226
+ constant_val = _get_attr(gm, node.target)
227
+ # These are not hashable and not gonna be lifted
228
+ # so we can skip them earlier
229
+ if isinstance(constant_val, torch.fx.GraphModule):
230
+ continue
231
+ if "LoweredBackendModule" in type(constant_val).__name__:
232
+ continue
233
+ if "AOTInductorRunnerWrapper" in type(constant_val).__name__:
234
+ continue
235
+ if isinstance(constant_val, torch.utils._pytree.TreeSpec):
236
+ continue
237
+
238
+ if constant_val in lifted_objs:
239
+ # We already lifted this constant elsewhere. Just rewrite uses
240
+ # of this get_attr to point to the already-existing placeholder
241
+ # node.
242
+ const_placeholder_node = _get_first_fqn(lifted_objs, constant_val)
243
+ node.replace_all_uses_with(const_placeholder_node)
244
+ gm.graph.erase_node(node)
245
+ renamed_targets[node.name] = const_placeholder_node.name
246
+ continue
247
+
248
+ # For ScriptObject, Tensor and FakeScriptObject constants:
249
+ # First check if the constant was an attribute on some module by
250
+ # consulting `constant_attrs` map. If it is, use the fqn that keeps
251
+ # its location consistent with the eager module.
252
+ #
253
+ # If it's not in the `constant_attrs` map, that means it's an inline
254
+ # constant (e.g. x + torch.tensor(0)), and thus did not have a
255
+ # specific location in the eager module. In that case, just generate
256
+ # some name and attach it to the module in which it was used.
257
+ if isinstance(
258
+ constant_val, (torch.ScriptObject, FakeScriptObject)
259
+ ) or is_opaque_reference_type(type(constant_val)):
260
+ constant_kind = InputKind.CUSTOM_OBJ
261
+ constant_fqn = _get_first_fqn(constant_attrs, constant_val)
262
+ if constant_fqn is not None:
263
+ constant_name = constant_fqn.replace(".", "_")
264
+ else:
265
+ constant_name = f"lifted_custom_{num_custom_obj}"
266
+ constant_fqn = get_constant_fqn(node, constant_name)
267
+ while constant_fqn in used_target_names:
268
+ num_custom_obj += 1
269
+ constant_name = f"lifted_custom_{num_custom_obj}"
270
+ constant_fqn = get_constant_fqn(node, constant_name)
271
+ num_custom_obj += 1
272
+ elif isinstance(constant_val, torch.Tensor):
273
+ # Remove the parameterness of constant_val
274
+ if isinstance(constant_val, torch.nn.Parameter):
275
+ log.debug(
276
+ "%s created when tracing %s is a parameter. But "
277
+ "it's not registered with register_parameter(). export will treat it as a constant tensor",
278
+ str(node.target),
279
+ str(node.meta.get("stack_trace", "<unknown stack>")),
280
+ )
281
+ # We get the real data out of the parameter by disabling the surrounding fake mode.
282
+ with unset_fake_temporarily():
283
+ constant_val = constant_val.data
284
+ constant_kind = InputKind.CONSTANT_TENSOR
285
+ constant_fqn = _get_first_fqn(constant_attrs, constant_val)
286
+ if constant_fqn is not None:
287
+ constant_name = constant_fqn.replace(".", "_")
288
+ else:
289
+ constant_name = f"lifted_tensor_{num_tensor_constants}"
290
+ constant_fqn = get_constant_fqn(node, constant_name)
291
+ while constant_fqn in used_target_names:
292
+ num_tensor_constants += 1
293
+ constant_name = f"lifted_tensor_{num_tensor_constants}"
294
+ constant_fqn = get_constant_fqn(node, constant_name)
295
+ num_tensor_constants += 1
296
+ else:
297
+ raise SpecViolationError(
298
+ f"getattr node {node} referencing unsupported type {type(constant_val)}"
299
+ )
300
+
301
+ with gm.graph.inserting_before(first_user_input):
302
+ # Insert the constant node before the first user input
303
+ const_placeholder_node = gm.graph.placeholder(constant_name)
304
+ # match target name with its node name in case there is name collision
305
+ # and suffix is added to node name in fx
306
+ const_placeholder_node.target = const_placeholder_node.name
307
+
308
+ for k, v in node.meta.items():
309
+ const_placeholder_node.meta[k] = v
310
+
311
+ # Once the FQN has been used, remove nn_module_stack, stack_trace
312
+ const_placeholder_node.meta.pop("nn_module_stack")
313
+ const_placeholder_node.meta.pop("stack_trace", None)
314
+
315
+ input_spec_arg: ArgumentSpec
316
+ if isinstance(constant_val, torch.Tensor):
317
+ if fake_mode is not None:
318
+ const_placeholder_node.meta["val"] = fake_mode.from_tensor(
319
+ constant_val, static_shapes=True
320
+ )
321
+ const_placeholder_node.meta["val"].constant = constant_val
322
+ else:
323
+ const_placeholder_node.meta["val"] = constant_val
324
+ input_spec_arg = TensorArgument(name=const_placeholder_node.name)
325
+ elif isinstance(constant_val, torch._C.ScriptObject):
326
+ class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined]
327
+ const_placeholder_node.meta["val"] = CustomObjArgument(
328
+ constant_fqn, class_fqn
329
+ )
330
+ input_spec_arg = CustomObjArgument(
331
+ name=const_placeholder_node.name, class_fqn=class_fqn
332
+ )
333
+ elif isinstance(constant_val, FakeScriptObject):
334
+ class_fqn = constant_val.script_class_name
335
+ const_placeholder_node.meta["val"] = CustomObjArgument(
336
+ constant_fqn, class_fqn, constant_val
337
+ )
338
+ input_spec_arg = CustomObjArgument(
339
+ name=const_placeholder_node.name,
340
+ class_fqn=class_fqn,
341
+ fake_val=constant_val,
342
+ )
343
+ else:
344
+ raise SpecViolationError(
345
+ f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}"
346
+ )
347
+
348
+ lifted_objs.add(constant_val, const_placeholder_node)
349
+ node.replace_all_uses_with(const_placeholder_node)
350
+ gm.graph.erase_node(node)
351
+
352
+ renamed_targets[node.name] = const_placeholder_node.name
353
+
354
+ # Add the constant as a buffer to the graph signature
355
+ graph_signature.input_specs.insert(
356
+ first_user_input_loc,
357
+ InputSpec(
358
+ kind=constant_kind,
359
+ arg=input_spec_arg,
360
+ target=constant_fqn,
361
+ ),
362
+ )
363
+ if constant_val in constant_attrs:
364
+ for fqn in constant_attrs[constant_val]:
365
+ all_constants[fqn] = constant_val
366
+ else:
367
+ all_constants[constant_fqn] = constant_val
368
+ first_user_input_loc += 1
369
+
370
+ for spec in graph_signature.output_specs:
371
+ if spec.arg.name in renamed_targets:
372
+ spec.arg.name = renamed_targets[spec.arg.name]
373
+
374
+ return all_constants
375
+
376
+
377
+ def rewrite_script_object_meta(
378
+ gm: torch.fx.GraphModule,
379
+ ) -> dict[str, _ConstantAttributeType]:
380
+ """When tracing, we produce a graph with FakeScriptObject in the
381
+ meta["val"].
382
+
383
+ For now, we rewrie meta["val"] to be a placeholder CustomObjArgument
384
+ """
385
+ constants: dict[
386
+ str,
387
+ _ConstantAttributeType,
388
+ ] = {}
389
+ for node in gm.graph.nodes:
390
+ if "val" not in node.meta:
391
+ continue
392
+
393
+ old_meta = node.meta["val"]
394
+
395
+ if isinstance(old_meta, torch.ScriptObject):
396
+ class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined]
397
+ new_meta = CustomObjArgument(node.name, class_fqn)
398
+ constants[node.name] = old_meta
399
+ node.meta["val"] = new_meta
400
+
401
+ elif isinstance(old_meta, FakeScriptObject):
402
+ class_fqn = old_meta.script_class_name # type: ignore[attr-defined]
403
+ new_meta = CustomObjArgument(node.name, class_fqn, old_meta)
404
+ constants[node.name] = old_meta
405
+ node.meta["val"] = new_meta
406
+
407
+ return constants
408
+
409
+
410
+ def _materialize_and_lift_constants(
411
+ gm: torch.fx.GraphModule,
412
+ export_graph_signature: ExportGraphSignature,
413
+ constant_attrs: ConstantAttrMap,
414
+ ) -> dict[str, _ConstantAttributeType]:
415
+ constants = rewrite_script_object_meta(gm)
416
+ constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))
417
+ return constants
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/remove_runtime_assertions.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.fx.passes.infra.pass_base import PassBase, PassResult
3
+
4
+
5
+ class _RemoveRuntimeAssertionsPass(PassBase):
6
+ """
7
+ Remove runtime assertions inserted by the
8
+ _AddRuntimeAssertionsForInlineConstraintsPass.
9
+ """
10
+
11
+ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
12
+ modified = False
13
+ for module in graph_module.modules():
14
+ if not isinstance(module, torch.fx.GraphModule):
15
+ continue
16
+ for node in module.graph.nodes:
17
+ if node.target in [
18
+ torch.ops.aten._assert_async.msg,
19
+ torch.ops.aten._assert_scalar.default,
20
+ torch.ops.aten.sym_constrain_range_for_size.default,
21
+ torch.ops.aten.sym_constrain_range.default,
22
+ torch.ops.aten._assert_tensor_metadata.default,
23
+ ]:
24
+ assert_async_node = node
25
+ if len(assert_async_node.users) > 0:
26
+ continue
27
+ module.graph.erase_node(assert_async_node)
28
+ # the upstream scalar_tensor <- {le, ge} <- sym_size
29
+ # linear chain of nodes of nodes is removed by the
30
+ # downstream dead code elimination
31
+ modified = True
32
+
33
+ # We don't necessarily want to run DCE here because it could affect
34
+ # nodes that are in the module_call_graph attribute of the exported
35
+ # program. We will leave it to the pass caller to call DCE.
36
+ return PassResult(graph_module, modified)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ import torch
7
+ from torch._higher_order_ops.wrap import wrap_with_autocast
8
+
9
+ from ..utils import node_inline_, nodes_filter, nodes_first, sequential_split
10
+ from .replace_with_hop_pass_util import (
11
+ _replace_with_hop_helper,
12
+ _replace_with_hop_pass_helper,
13
+ _sequential_split_and_maybe_inline_subgraphs_helper,
14
+ )
15
+
16
+
17
+ if TYPE_CHECKING:
18
+ from torch.export.graph_signature import ExportGraphSignature
19
+
20
+
21
+ def _is_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool:
22
+ return (
23
+ node
24
+ and node.op == "call_function"
25
+ and node.target
26
+ in [
27
+ torch.amp.autocast_mode._enter_autocast,
28
+ torch.amp.autocast_mode._exit_autocast,
29
+ ]
30
+ )
31
+
32
+
33
+ def _is_enter_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool:
34
+ return (
35
+ node
36
+ and node.op == "call_function"
37
+ and node.target is torch.amp.autocast_mode._enter_autocast
38
+ )
39
+
40
+
41
+ def _is_exit_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool:
42
+ return (
43
+ node
44
+ and node.op == "call_function"
45
+ and node.target is torch.amp.autocast_mode._exit_autocast
46
+ )
47
+
48
+
49
+ def _is_autocast_sub_mod(node: torch.fx.Node) -> bool:
50
+ """
51
+ Check if the first non-placeholder node is `torch.amp.autocast_mode._enter_autocast`.
52
+ """
53
+ if node.op == "call_module":
54
+ assert isinstance(node.target, str)
55
+ subgm = getattr(node.graph.owning_module, node.target)
56
+ first_non_ph = nodes_first(
57
+ subgm.graph.nodes, lambda node: node.op != "placeholder"
58
+ )
59
+ if (
60
+ first_non_ph
61
+ and first_non_ph.op == "call_function"
62
+ and first_non_ph.target is torch.amp.autocast_mode._enter_autocast
63
+ ):
64
+ # TODO: check if current auto-cast type is the same as the args of
65
+ # _enter_autocast. If so, return False, i.e. do not create a submodule.
66
+ return True
67
+ return False
68
+
69
+
70
+ def _check_valid_autocast_block(
71
+ enter_autocast_node: torch.fx.Node, exit_autocast_node: torch.fx.Node
72
+ ) -> None:
73
+ assert _is_enter_autocast_node(enter_autocast_node)
74
+ assert _is_exit_autocast_node(exit_autocast_node)
75
+ assert exit_autocast_node.args[0] == enter_autocast_node
76
+
77
+
78
+ def _replace_with_hop(node: torch.fx.Node) -> None:
79
+ assert node.op == "call_module"
80
+ graph: torch.fx.Graph = node.graph
81
+ assert graph.owning_module is not None
82
+ gm: torch.fx.GraphModule = graph.owning_module
83
+ assert isinstance(node.target, str)
84
+ sub_gm = getattr(gm, node.target)
85
+ sub_graph = sub_gm.graph
86
+ autocast_nodes = nodes_filter(sub_graph.nodes, _is_autocast_node)
87
+ if len(autocast_nodes) > 0:
88
+ assert len(autocast_nodes) > 1 # need at least an enter node and an exist node
89
+ enter_autocast_node = autocast_nodes[0]
90
+ exit_autocast_node = autocast_nodes[-1]
91
+ _check_valid_autocast_block(enter_autocast_node, exit_autocast_node)
92
+
93
+ _replace_with_hop_helper(node, enter_autocast_node, wrap_with_autocast)
94
+ sub_graph.erase_node(exit_autocast_node)
95
+ sub_graph.erase_node(enter_autocast_node)
96
+
97
+
98
+ def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
99
+ """
100
+ split_autocast creates a new graph module that splits the input graph module into multiple submodules
101
+ based on the `_enter_autocast` and `_exit_autocast` nodes. It doesn't mutate the input graph module.
102
+
103
+ Nodes between the **outer-most** `_enter_autocast` and `_exit_autocast(_enter_autocast)` are split
104
+ into a submodule. Nested autocast regions are not split.
105
+ `_enter_autocast` and `_exit_autocast(_enter_autocast)` nodes are in the submodule as well.
106
+
107
+ Below is an example of splitting. A, B, C, D, E are blocks of non-autocast nodes in the original graph
108
+ module. Nodes marked with the same number are grouped into the same submodule.
109
+ A # 0
110
+ enter_autocast # 1
111
+ B # 1
112
+ exit_autocast # 1
113
+ C # 2
114
+ enter_autocast # 3
115
+ D # 3
116
+ exit_autocast # 3
117
+ E # 4
118
+ """
119
+ enter_autocast_node_stack: list[torch.fx.Node] = []
120
+ first_node_after_outer_most_exit: bool = False
121
+
122
+ def node_call_back(node: torch.fx.Node) -> bool:
123
+ nonlocal enter_autocast_node_stack, first_node_after_outer_most_exit
124
+ increment_id = False
125
+ if first_node_after_outer_most_exit or (
126
+ len(enter_autocast_node_stack) == 0 and _is_enter_autocast_node(node)
127
+ ):
128
+ assert len(enter_autocast_node_stack) == 0
129
+ first_node_after_outer_most_exit = False
130
+ increment_id = True
131
+ if _is_enter_autocast_node(node):
132
+ enter_autocast_node_stack.append(node)
133
+ elif _is_exit_autocast_node(node):
134
+ assert len(enter_autocast_node_stack) > 0
135
+ last_enter_autocast_node = enter_autocast_node_stack.pop()
136
+ assert node.args[0] == last_enter_autocast_node
137
+ if len(enter_autocast_node_stack) == 0:
138
+ # next node should be in the next submodule since
139
+ # autocast block ends
140
+ first_node_after_outer_most_exit = True
141
+ return increment_id
142
+
143
+ return sequential_split(gm, node_call_back)
144
+
145
+
146
+ def _sequential_split_and_maybe_inline_subgraphs(
147
+ gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None
148
+ ) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
149
+ """
150
+ Helper function for replace_autocast_with_hop_pass().
151
+ Split the graph module into multiple subgraphs based on the autocast nodes.
152
+ For each subgraph, decides whether to construct a HOO subgraph, or inline the calls
153
+ back into the parent graph module.
154
+ Nodes between `_enter_autocast` and `_exit_autocast(_enter_autocast)` are considered
155
+ as a subgraph.
156
+ """
157
+ need_replacing = any(_is_autocast_node(node) for node in gm.graph.nodes)
158
+ if not need_replacing:
159
+ return gm, graph_signature
160
+
161
+ # split_autocast returns a new graph module that could have different output
162
+ # args names. We need to fix the graph signature in `_sequential_split_and_maybe_inline_subgraphs_helper`.
163
+ new_gm = _split_autocast(gm)
164
+
165
+ def _maybe_inline_or_replace_with_hop(node: torch.fx.Node) -> None:
166
+ if _is_autocast_sub_mod(node):
167
+ _replace_with_hop(node)
168
+ else:
169
+ assert node.op == "call_module"
170
+ assert isinstance(node.target, str)
171
+ node_inline_(node)
172
+
173
+ return _sequential_split_and_maybe_inline_subgraphs_helper(
174
+ new_gm, graph_signature, _maybe_inline_or_replace_with_hop
175
+ )
176
+
177
+
178
+ def replace_autocast_with_hop_pass(
179
+ gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None
180
+ ) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
181
+ """
182
+ Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
183
+ then recursively call itself on each of the submodules.
184
+ """
185
+ return _replace_with_hop_pass_helper(
186
+ gm,
187
+ graph_signature,
188
+ _sequential_split_and_maybe_inline_subgraphs,
189
+ )
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py ADDED
@@ -0,0 +1,676 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import logging
3
+ import operator
4
+ from typing import Optional, Union
5
+
6
+ import torch
7
+ import torch.export._trace
8
+ from torch._ops import OpOverload
9
+ from torch.ao.quantization.fx._decomposed import (
10
+ dequantize_per_channel,
11
+ dequantize_per_tensor,
12
+ quantize_per_tensor,
13
+ )
14
+ from torch.ao.quantization.utils import calculate_qmin_qmax
15
+ from torch.fx.graph_module import _assign_attr
16
+
17
+
18
+ log = logging.getLogger(__name__)
19
+
20
+ # Those values will need to be carried over multiple operators.
21
+ _INPUT_Q_DTYPE: Optional[Union[torch.dtype, torch.fx.Node]] = None
22
+ _SCALE: Optional[Union[float, torch.fx.Node]] = None
23
+ _ZERO_POINT: Optional[Union[float, torch.fx.Node]] = None
24
+
25
+
26
+ def int_to_valid_dtype(val: int) -> torch.dtype:
27
+ from torch._export.converter import _TORCH_ENUM_TO_DTYPE # No circular import.
28
+
29
+ if isinstance(val, torch.dtype):
30
+ return val
31
+ dtype = _TORCH_ENUM_TO_DTYPE[val]
32
+ if dtype == torch.quint8:
33
+ return torch.uint8
34
+ elif dtype == torch.qint8:
35
+ return torch.int8
36
+ return dtype
37
+
38
+
39
+ def fx_enum_to_dtype(gm: torch.fx.GraphModule, val: int) -> torch.fx.Node:
40
+ return gm.graph.call_function(int_to_valid_dtype, (val,))
41
+
42
+
43
+ def insert_quantized_node(
44
+ gm: torch.fx.GraphModule,
45
+ val_node: torch.fx.Node,
46
+ scale_node: Union[float, torch.fx.Node],
47
+ zero_point_node: Union[float, torch.fx.Node],
48
+ qmin_node: Union[float, int, torch.fx.Node],
49
+ qmax_node: Union[float, int, torch.fx.Node],
50
+ dtype_node: Union[torch.dtype, torch.fx.Node],
51
+ qscheme: Optional[torch.qscheme],
52
+ ) -> torch.fx.Node:
53
+ return gm.graph.call_function(
54
+ quantize_per_tensor,
55
+ (
56
+ val_node,
57
+ scale_node,
58
+ zero_point_node,
59
+ qmin_node,
60
+ qmax_node,
61
+ dtype_node,
62
+ ),
63
+ )
64
+
65
+
66
+ def get_dequantized(
67
+ val: torch.Tensor,
68
+ scale: Union[float, torch.Tensor],
69
+ zero_point: Union[float, torch.Tensor],
70
+ qmin: Union[float, int],
71
+ qmax: Union[float, int],
72
+ dtype: torch.dtype,
73
+ axis: Optional[int],
74
+ qscheme: Optional[torch.qscheme],
75
+ ) -> torch.Tensor:
76
+ if qscheme is torch.per_tensor_affine:
77
+ return dequantize_per_tensor(
78
+ val,
79
+ scale, # type: ignore[arg-type]
80
+ zero_point, # type: ignore[arg-type]
81
+ qmin, # type: ignore[arg-type]
82
+ qmax, # type: ignore[arg-type]
83
+ dtype,
84
+ )
85
+ elif qscheme is torch.per_channel_affine:
86
+ return dequantize_per_channel(
87
+ val,
88
+ scale, # type: ignore[arg-type]
89
+ zero_point, # type: ignore[arg-type]
90
+ axis, # type: ignore[arg-type]
91
+ qmin, # type: ignore[arg-type]
92
+ qmax, # type: ignore[arg-type]
93
+ dtype,
94
+ )
95
+ else:
96
+ raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")
97
+
98
+
99
+ def insert_dequantized_node(
100
+ gm: torch.fx.GraphModule,
101
+ val_node: torch.fx.Node,
102
+ scale_node: Union[float, torch.fx.Node],
103
+ zero_point_node: Union[float, torch.fx.Node],
104
+ qmin_node: Union[float, int, torch.fx.Node],
105
+ qmax_node: Union[float, int, torch.fx.Node],
106
+ dtype_node: Union[torch.dtype, torch.fx.Node],
107
+ axis_node: Optional[Union[int, torch.fx.Node]],
108
+ qscheme: Optional[torch.qscheme],
109
+ ) -> torch.fx.Node:
110
+ if qscheme is torch.per_tensor_affine:
111
+ return gm.graph.call_function(
112
+ dequantize_per_tensor,
113
+ (
114
+ val_node,
115
+ scale_node,
116
+ zero_point_node,
117
+ qmin_node,
118
+ qmax_node,
119
+ dtype_node,
120
+ ),
121
+ )
122
+ elif qscheme is torch.per_channel_affine:
123
+ return gm.graph.call_function(
124
+ dequantize_per_channel,
125
+ (
126
+ val_node,
127
+ scale_node,
128
+ zero_point_node,
129
+ axis_node,
130
+ qmin_node,
131
+ qmax_node,
132
+ dtype_node,
133
+ ),
134
+ )
135
+ else:
136
+ raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")
137
+
138
+
139
+ def get_qmin_qmax(dtype: torch.dtype) -> tuple[Union[int, float], Union[int, float]]:
140
+ return calculate_qmin_qmax(None, None, False, dtype, False) # type: ignore[arg-type]
141
+
142
+
143
+ def insert_qmin_qmax_node(
144
+ gm: torch.fx.GraphModule, dtype_node: Union[torch.dtype, torch.fx.Node]
145
+ ) -> tuple[torch.fx.Node, torch.fx.Node]:
146
+ q_min_max_node = gm.graph.call_function(
147
+ calculate_qmin_qmax, (None, None, False, dtype_node, False)
148
+ )
149
+ qmin_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 0))
150
+ qmax_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 1))
151
+ return qmin_node, qmax_node
152
+
153
+
154
+ def get_script_object(
155
+ gm: torch.nn.Module, node: torch.fx.Node
156
+ ) -> torch._C.ScriptObject:
157
+ assert isinstance(node, torch.fx.Node)
158
+ assert node.op == "get_attr"
159
+ attr_name = node.target
160
+ assert isinstance(attr_name, str)
161
+
162
+ mod = gm
163
+ for attr in attr_name.split("."):
164
+ mod = getattr(mod, attr)
165
+ assert isinstance(mod, torch._C.ScriptObject)
166
+ return mod
167
+
168
+
169
+ def insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
170
+ gm: torch.fx.GraphModule,
171
+ param_node: torch.fx.Node,
172
+ ) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]:
173
+ """Directly inline tensor from a get_attr fx node."""
174
+ mod = get_script_object(gm, param_node)
175
+ w_qtensor, b_qtensor = mod.unpack() # type: ignore[attr-defined]
176
+ w_attr_name, b_attr_name = (
177
+ f"dequantized_{param_node.target}_w",
178
+ f"dequantized_{param_node.target}_b",
179
+ )
180
+ return insert_weight_and_bias_get_attr_node(
181
+ gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name
182
+ )
183
+
184
+
185
+ def insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
186
+ gm: torch.fx.GraphModule,
187
+ get_attr_to_weight_node: torch.fx.Node,
188
+ get_attr_to_bias_node: Optional[torch.fx.Node],
189
+ ) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]:
190
+ assert isinstance(get_attr_to_weight_node.target, str)
191
+ w_qtensor = getattr(gm, get_attr_to_weight_node.target)
192
+ w_attr_name = f"dequantized_{get_attr_to_weight_node.target}_w"
193
+
194
+ if get_attr_to_bias_node is not None:
195
+ assert isinstance(get_attr_to_bias_node.target, str)
196
+ b_qtensor = getattr(gm, get_attr_to_bias_node.target)
197
+ b_attr_name = f"dequantized_{get_attr_to_bias_node.target}_b"
198
+ else:
199
+ b_qtensor, b_attr_name = None, ""
200
+
201
+ return insert_weight_and_bias_get_attr_node(
202
+ gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name
203
+ )
204
+
205
+
206
+ def insert_weight_and_bias_get_attr_node(
207
+ gm: torch.fx.GraphModule,
208
+ w_qtensor: torch.Tensor,
209
+ b_qtensor: Optional[torch.Tensor],
210
+ w_attr_name: str,
211
+ b_attr_name: str,
212
+ ) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]:
213
+ w_tensor = get_tensor_from_qtensor(w_qtensor)
214
+ _assign_attr(w_tensor, gm, w_attr_name)
215
+ w_tensor_attr = gm.graph.get_attr(w_attr_name)
216
+
217
+ if b_qtensor is not None:
218
+ b_tensor = get_tensor_from_qtensor(b_qtensor, dequant=False)
219
+ _assign_attr(b_tensor, gm, b_attr_name)
220
+ b_tensor_attr = gm.graph.get_attr(b_attr_name)
221
+ else:
222
+ b_tensor_attr = None
223
+
224
+ return w_tensor_attr, b_tensor_attr
225
+
226
+
227
+ def get_tensor_from_qtensor(
228
+ qtensor: torch.Tensor, dequant: bool = True
229
+ ) -> torch.Tensor:
230
+ # Manual conversion because qint8 is not used anymore.
231
+ if qtensor.dtype in [torch.qint8, torch.quint8]:
232
+ tensor = qtensor.int_repr()
233
+ else:
234
+ tensor = qtensor
235
+
236
+ # Weights need dequantization with scaling and zero_point adjustment, but
237
+ # bias does not need that.
238
+ if dequant:
239
+ qscheme = qtensor.qscheme()
240
+ if qscheme == torch.per_channel_affine:
241
+ scale, zero_point, axis = (
242
+ qtensor.q_per_channel_scales(),
243
+ qtensor.q_per_channel_zero_points(),
244
+ qtensor.q_per_channel_axis(),
245
+ )
246
+ else:
247
+ scale, zero_point, axis = (
248
+ qtensor.q_scale(), # type: ignore[assignment]
249
+ qtensor.q_zero_point(), # type: ignore[assignment]
250
+ None,
251
+ )
252
+ dtype = tensor.dtype
253
+ qmin, qmax = get_qmin_qmax(dtype)
254
+ return get_dequantized(
255
+ tensor, scale, zero_point, qmin, qmax, dtype, axis, qscheme
256
+ )
257
+ return tensor
258
+
259
+
260
+ def insert_fused_activation_node(
261
+ gm: torch.fx.GraphModule, opname: str, fx_node: torch.fx.Node
262
+ ) -> torch.fx.Node:
263
+ if opname in ["conv1d_relu", "conv2d_relu", "linear_relu", "add_relu", "mul_relu"]:
264
+ fx_node = gm.graph.call_function(torch.ops.aten.relu, (fx_node,))
265
+ return fx_node
266
+
267
+
268
+ def _conv1d_op_with_squeeze(
269
+ inp: torch.Tensor,
270
+ weight: torch.Tensor,
271
+ bias: Optional[torch.Tensor],
272
+ stride: list[int],
273
+ padding: list[int],
274
+ dilation: list[int],
275
+ groups: int,
276
+ ) -> torch.Tensor:
277
+ # In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze
278
+ # operations before and after the conv2d operation to match the dimension of weights.
279
+ # Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/quantized/cpu/qconv.cpp#L1827 # noqa: B950
280
+ s_inp = torch.ops.aten.unsqueeze(inp, 2)
281
+ conv1d_res = torch.ops.aten.conv2d(
282
+ s_inp,
283
+ weight,
284
+ bias,
285
+ stride,
286
+ padding,
287
+ dilation,
288
+ groups,
289
+ )
290
+ uns_conv1d_res = torch.ops.aten.squeeze(conv1d_res, 2)
291
+ return uns_conv1d_res
292
+
293
+
294
+ def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node):
295
+ """Conv specific transformation function."""
296
+ assert isinstance(node.target, torch._ops.OpOverload)
297
+ opname = node.target._opname
298
+ scale_node, zero_point_node = node.args[2], node.args[3]
299
+
300
+ op_f = (
301
+ torch.ops.aten.conv2d
302
+ if opname in ["conv2d", "conv2d_relu"]
303
+ else _conv1d_op_with_squeeze
304
+ )
305
+
306
+ inp_node, param_node = node.args[0], node.args[1]
307
+ assert isinstance(inp_node, torch.fx.Node)
308
+ assert isinstance(param_node, torch.fx.Node)
309
+
310
+ if param_node.op == "call_function":
311
+ # Using Conv2dPrepackParam from conv_prepack.
312
+ # We directly skip the packing call and inline weights and bias.
313
+ w_node, b_node = param_node.args[0], param_node.args[1]
314
+ assert isinstance(w_node, torch.fx.Node)
315
+ assert b_node is None or isinstance(b_node, torch.fx.Node)
316
+ (
317
+ param_0,
318
+ param_1,
319
+ ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
320
+ gm, w_node, b_node
321
+ )
322
+ op_res_node = gm.graph.call_function(
323
+ op_f, (inp_node, param_0, param_1, *param_node.args[2:])
324
+ )
325
+ else:
326
+ # Using ConvPrepackedParam.
327
+ param = get_script_object(gm, param_node)
328
+ (
329
+ param_0,
330
+ param_1,
331
+ ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
332
+ gm, param_node
333
+ ) # type: ignore[assignment]
334
+ op_res_node = gm.graph.call_function(
335
+ op_f,
336
+ (
337
+ inp_node,
338
+ param_0,
339
+ param_1,
340
+ param.stride(), # type: ignore[attr-defined]
341
+ param.padding(), # type: ignore[attr-defined]
342
+ param.dilation(), # type: ignore[attr-defined]
343
+ param.groups(), # type: ignore[attr-defined]
344
+ ),
345
+ )
346
+ return op_res_node, scale_node, zero_point_node
347
+
348
+
349
+ def _transform_linear_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node):
350
+ """Linear specific transformation function."""
351
+ scale_node, zero_point_node = node.args[2], node.args[3]
352
+
353
+ inp_node, param_node = node.args[0], node.args[1]
354
+ assert isinstance(inp_node, torch.fx.Node)
355
+ assert isinstance(param_node, torch.fx.Node)
356
+
357
+ if param_node.op == "call_function":
358
+ # Using LinearPrepackParam from linear_prepack.
359
+ # We directly skip the packing call and inline weights and bias.
360
+ w_node, b_node = param_node.args[0], param_node.args[1]
361
+ assert isinstance(w_node, torch.fx.Node)
362
+ assert b_node is None or isinstance(b_node, torch.fx.Node)
363
+ (
364
+ param_0,
365
+ param_1,
366
+ ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
367
+ gm, w_node, b_node
368
+ )
369
+ op_res_node = gm.graph.call_function(
370
+ torch.ops.aten.linear, (inp_node, param_0, param_1, *param_node.args[2:])
371
+ )
372
+ else:
373
+ # Using LinearPackedParams.
374
+ (
375
+ param_0,
376
+ param_1,
377
+ ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
378
+ gm, param_node
379
+ ) # type: ignore[assignment]
380
+ op_res_node = gm.graph.call_function(
381
+ torch.ops.aten.linear, (inp_node, param_0, param_1)
382
+ )
383
+ return op_res_node, scale_node, zero_point_node
384
+
385
+
386
+ def _transform_op_where_last_two_arguments_are_scale_and_zero_point(
387
+ gm: torch.fx.GraphModule, node: torch.fx.Node
388
+ ):
389
+ """
390
+ This transformation function can be used for function where the last two
391
+ parameters are scale and zero point. Additionally, the function's parameters
392
+ do not need any unpacking.
393
+ """
394
+ to_standard_op = {
395
+ "mul": torch.ops.aten.mul,
396
+ "mul_relu": torch.ops.aten.mul,
397
+ "add": torch.ops.aten.add,
398
+ "add_relu": torch.ops.aten.add,
399
+ "softmax": torch.ops.aten.softmax,
400
+ "cat": torch.ops.aten.cat,
401
+ "hardswish": torch.ops.aten.hardswish,
402
+ }
403
+
404
+ assert isinstance(node.target, torch._ops.OpOverload)
405
+ opname, args = node.target._opname, node.args
406
+ scale_node, zero_point_node = args[-2], args[-1]
407
+ op_res_node = gm.graph.call_function(to_standard_op[opname], tuple(args[:-2]))
408
+ return op_res_node, scale_node, zero_point_node
409
+
410
+
411
+ def _transform_scalar_arithmetic(gm: torch.fx.GraphModule, node: torch.fx.Node):
412
+ """Transform scalar overload for basic arithmetic."""
413
+ to_standard_op = {
414
+ "mul": torch.ops.aten.mul.Scalar,
415
+ "add": torch.ops.aten.add.Scalar,
416
+ }
417
+ assert isinstance(node.target, torch._ops.OpOverload)
418
+ opname, args = node.target._opname, node.args
419
+ op_res_node = gm.graph.call_function(to_standard_op[opname], args)
420
+ return op_res_node, _SCALE, _ZERO_POINT
421
+
422
+
423
+ def _transform_prepacked_op(gm: torch.fx.GraphModule, node: torch.fx.Node):
424
+ """
425
+ Transformation for functions under prepacked namespace, where they share
426
+ the same handling logic that [...]OpContext contains all parameters.
427
+ """
428
+ assert isinstance(node.target, torch._ops.OpOverload)
429
+ opname, args = node.target._opname, node.args
430
+ op_f = None
431
+ if opname == "conv2d_clamp_run":
432
+ op_f = torch.ops.aten.conv2d
433
+ elif opname == "linear_clamp_run":
434
+ op_f = torch.ops.aten.linear
435
+ else:
436
+ raise RuntimeError(f"Invalid operator {opname}")
437
+
438
+ assert isinstance(args[1], torch.fx.Node)
439
+ so = get_script_object(gm, args[1])
440
+
441
+ func_args = []
442
+ func_args += [args[0]]
443
+ func_args += so.unpack()[:2] # type: ignore[attr-defined]
444
+ if opname == "conv2d_clamp_run":
445
+ func_args += torch.ops.prepacked.unpack_prepacked_sizes_conv2d(so)[2:]
446
+
447
+ op_res_node = gm.graph.call_function(op_f, tuple(func_args))
448
+ return op_res_node
449
+
450
+
451
+ def _transform_batch_norm(gm: torch.fx.GraphModule, node: torch.fx.Node):
452
+ args = node.args
453
+ scale_node, zero_point_node = args[-2], args[-1]
454
+ op_res_node = gm.graph.call_function(
455
+ torch.ops.aten.native_batch_norm, (*args[:-3], False, 0.1, args[-3])
456
+ )
457
+ op_res_node = gm.graph.call_function(operator.getitem, (op_res_node, 0))
458
+ return op_res_node, scale_node, zero_point_node
459
+
460
+
461
+ def fx_transform_quantized_op_to_standard_op(
462
+ gm: torch.fx.GraphModule, node: torch.fx.Node
463
+ ) -> torch.fx.Node:
464
+ global _SCALE, _ZERO_POINT, _INPUT_Q_DTYPE
465
+
466
+ assert isinstance(node.target, torch._ops.OpOverload)
467
+ opname, overload = node.target._opname, node.target._overloadname
468
+
469
+ key = f"{opname}.{overload}"
470
+ opname_to_transform_f = {
471
+ "conv1d.new": _transform_conv_with_packedparam,
472
+ "conv1d_relu.new": _transform_conv_with_packedparam,
473
+ "conv1d.default": _transform_conv_with_packedparam,
474
+ "conv1d_relu.default": _transform_conv_with_packedparam,
475
+ "conv2d.new": _transform_conv_with_packedparam,
476
+ "conv2d_relu.new": _transform_conv_with_packedparam,
477
+ "conv2d.default": _transform_conv_with_packedparam,
478
+ "conv2d_relu.default": _transform_conv_with_packedparam,
479
+ "linear.default": _transform_linear_with_packedparam,
480
+ "linear_relu.default": _transform_linear_with_packedparam,
481
+ "add.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
482
+ "add_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
483
+ "mul.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
484
+ "mul_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
485
+ "softmax.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
486
+ "cat.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
487
+ "hardswish.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
488
+ "batch_norm2d.default": _transform_batch_norm,
489
+ "mul.Scalar": _transform_scalar_arithmetic,
490
+ "add.Scalar": _transform_scalar_arithmetic,
491
+ }
492
+
493
+ if f"{key}" not in opname_to_transform_f:
494
+ raise RuntimeError(f"Unsupported quantized op during transformation: {key}")
495
+
496
+ op_res_node, scale_node, zero_point_node = opname_to_transform_f[f"{key}"](gm, node)
497
+
498
+ # Add fused activation layer.
499
+ op_res_node = insert_fused_activation_node(gm, opname, op_res_node)
500
+ _SCALE, _ZERO_POINT = scale_node, zero_point_node
501
+
502
+ assert _INPUT_Q_DTYPE is not None
503
+ qmin_node, qmax_node = insert_qmin_qmax_node(gm, _INPUT_Q_DTYPE)
504
+ q_fx_node = insert_quantized_node(
505
+ gm,
506
+ op_res_node,
507
+ scale_node,
508
+ zero_point_node,
509
+ qmin_node,
510
+ qmax_node,
511
+ _INPUT_Q_DTYPE,
512
+ torch.per_tensor_affine,
513
+ )
514
+ dq_fx_node = insert_dequantized_node(
515
+ gm,
516
+ q_fx_node,
517
+ scale_node,
518
+ zero_point_node,
519
+ qmin_node,
520
+ qmax_node,
521
+ _INPUT_Q_DTYPE,
522
+ None,
523
+ torch.per_tensor_affine,
524
+ )
525
+ return dq_fx_node
526
+
527
+
528
+ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
529
+ """
530
+ Replace legacy quantized ops (aten.quantize_per_tensor, quantized.conv) with
531
+ PT2 ops (quantize_decomposed.quantize_per_tensor, aten.conv).
532
+
533
+ Before: x || -> aten.q || -> quantized.conv2d || -> quantized.linear || -> aten.dq || -> y
534
+
535
+ After: x || -> qd.q -> qd.dq || -> aten.conv2d -> qd.q -> qd.dq || aten.linear -> qd.q -> qd.dq || -> y
536
+
537
+ (qd == quantized_decomposed library, q = quantize, dq = dequantize)
538
+ ^
539
+ |
540
+ getattr(w), getattr(b) from Conv2dParamPrepack
541
+
542
+ During each iteration, the transformation spits out the transformed operator, its quantized output,
543
+ and its dequantized value together. We did this because dequantization need to use the
544
+ scale and zero point parameters from the quantization to recover the approximate original value. After each
545
+ iteration, the new dequantization node will be used as the input to the next node (e.g., dq2 -> linear).
546
+
547
+ For operators like conv2d and linear, their weights and bias are packed in a quantized format in the ScriptObject.
548
+ During the transformation, we unpack those objects, get their dequantized tensor, populate those
549
+ as attributes to the module, and use getattr to access them.
550
+
551
+ One exception in the transformation is conv_prepack and linear_prepack. Those calls pack
552
+ weight and bias constant tensors into ScriptObject, which are then used by subsequent conv2d or linear calls.
553
+ During transformation, we directly skip transforming conv_prepack or linear_prepack. We check whether ScriptObject to the
554
+ quantized::conv2d or linear is from conv_prepack or linear_prepack. If it is, we then inline those parameters
555
+ to the operator by converting them to a getattr fx.node.
556
+
557
+ For prepacked::conv2d_clamp_run and prepacked::linear_clamp_run, we directly convert them to aten.conv2d and aten.linear
558
+ without the need of doing de/quantization.
559
+
560
+ Three global variables defined are _INPUT_Q_DTYPE, _SCALE, _ZERO_POINT. _INPUT_Q_DTYPE determines the de/quantization
561
+ data type, which is the same across the entire program, but it only shows up in the very first quantization
562
+ call. _SCALE and _ZERO_POINT are used only when operators do not have those specified. E.g., mul.Scalar.
563
+ """
564
+
565
+ global _INPUT_Q_DTYPE
566
+
567
+ quantized = False
568
+
569
+ last_quantized_node = None
570
+ # pyrefly: ignore [bad-assignment]
571
+ for node in gm.graph.nodes:
572
+ if isinstance(node.target, OpOverload):
573
+ with gm.graph.inserting_before(node):
574
+ namespace, opname = node.target.namespace, node.target._opname
575
+ if namespace == "quantized" and opname not in [
576
+ "conv_prepack",
577
+ "linear_prepack",
578
+ ]:
579
+ quantized = True
580
+ fx_node = fx_transform_quantized_op_to_standard_op(gm, node)
581
+ node.replace_all_uses_with(fx_node)
582
+ last_quantized_node = fx_node
583
+ elif namespace == "prepacked":
584
+ quantized = True
585
+ fx_node = _transform_prepacked_op(gm, node)
586
+ node.replace_all_uses_with(fx_node)
587
+ last_quantized_node = fx_node
588
+ elif namespace == "aten" and opname == "quantize_per_tensor":
589
+ inp_node, scale_node, zero_point_node, dtype_node = node.args
590
+ dtype_node = fx_enum_to_dtype(gm, dtype_node)
591
+ _INPUT_Q_DTYPE = dtype_node
592
+ qmin_node, qmax_node = insert_qmin_qmax_node(gm, dtype_node)
593
+ q_fx_node = insert_quantized_node(
594
+ gm,
595
+ inp_node,
596
+ scale_node,
597
+ zero_point_node,
598
+ qmin_node,
599
+ qmax_node,
600
+ dtype_node,
601
+ torch.per_tensor_affine,
602
+ )
603
+ dq_fx_node = insert_dequantized_node(
604
+ gm,
605
+ q_fx_node,
606
+ scale_node,
607
+ zero_point_node,
608
+ qmin_node,
609
+ qmax_node,
610
+ dtype_node,
611
+ None,
612
+ torch.per_tensor_affine,
613
+ )
614
+ node.replace_all_uses_with(dq_fx_node)
615
+ last_quantized_node = dq_fx_node
616
+ elif namespace == "aten" and opname == "dequantize":
617
+ assert last_quantized_node is not None
618
+ node.replace_all_uses_with(last_quantized_node)
619
+ else:
620
+ last_quantized_node = node
621
+
622
+ # Post-processing again to remove legacy ScriptObjects and quantizated tensors
623
+ # stored as attributes or in the buffer. This is used to clean up the GraphModule
624
+ # to not trigger tracing errors like missing __obj_flatten__ functions.
625
+ def _clean_attr(mod: torch.nn.Module):
626
+ for submod in mod.modules():
627
+ attr_names_to_clean = set()
628
+ for k, v in submod.__dict__.items():
629
+ if isinstance(v, torch.ScriptObject):
630
+ attr_names_to_clean.add(k)
631
+ if k == "_buffers":
632
+ buffer_name_to_clean = set()
633
+ # pyrefly: ignore [missing-attribute]
634
+ for b_name, b_value in v.items():
635
+ if isinstance(b_value, torch.Tensor) and b_value.dtype in [
636
+ torch.qint8,
637
+ torch.quint8,
638
+ ]:
639
+ buffer_name_to_clean.add(b_name)
640
+ for b_name in buffer_name_to_clean:
641
+ # pyrefly: ignore [missing-attribute]
642
+ v.pop(b_name, None)
643
+ for attr_name in attr_names_to_clean:
644
+ delattr(submod, attr_name)
645
+
646
+ if quantized:
647
+ """
648
+ TODO: SetAttr + quantized ops will result incorrect program. This flag is used to temporarily
649
+ bypass test cases.
650
+
651
+ The deadcode elimination pass is needed to remove legacy quantized ops. Otherwise, retracing
652
+ will throw errors. However, the current way of SetAttr does inplace update to attributes, so
653
+ this pass regard them as dead code and remove them. Below is an example of GraphModule before
654
+ and after the dead code elimination pass.
655
+
656
+ class GraphModule(torch.nn.Module):
657
+ def forward(self, x_1):
658
+ # No stacktrace found for following nodes
659
+ data = self.data; data = None
660
+ data_1 = self.data
661
+ add_tensor = torch.ops.aten.add.Tensor(data_1, x_1, alpha = 1); data_1 = None
662
+ data_2 = self.data
663
+ copy_ = torch_Tensor_copy_(data_2, add_tensor); data_2 = add_tensor = copy_ = None
664
+ data_3 = self.data
665
+ add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None
666
+ return add_tensor_1
667
+
668
+ class GraphModule(torch.nn.Module):
669
+ def forward(self, x_1):
670
+ # No stacktrace found for following nodes
671
+ data_3 = self.data
672
+ add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None
673
+ return add_tensor_1
674
+ """
675
+ gm.graph.eliminate_dead_code()
676
+ _clean_attr(gm)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ import torch
7
+ from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled
8
+
9
+ from ..utils import node_inline_, nodes_filter, nodes_first, nodes_map, sequential_split
10
+ from .replace_with_hop_pass_util import (
11
+ _replace_with_hop_helper,
12
+ _replace_with_hop_pass_helper,
13
+ _sequential_split_and_maybe_inline_subgraphs_helper,
14
+ )
15
+
16
+
17
+ if TYPE_CHECKING:
18
+ from torch.export.graph_signature import ExportGraphSignature
19
+
20
+
21
+ def _is_set_grad_enabled_node(node: torch.fx.Node) -> torch.fx.Node | bool:
22
+ return (
23
+ node
24
+ and node.op == "call_function"
25
+ and node.target is torch._C._set_grad_enabled
26
+ )
27
+
28
+
29
+ def _is_set_grad_enabled_sub_mod(
30
+ node: torch.fx.Node, omit_if_same_with_ambient: bool = False
31
+ ) -> bool | torch.Tensor:
32
+ if node.op == "call_module":
33
+ assert isinstance(node.target, str)
34
+ subgm = getattr(node.graph.owning_module, node.target)
35
+ first_non_ph = nodes_first(
36
+ subgm.graph.nodes, lambda node: node.op != "placeholder"
37
+ )
38
+ if (
39
+ first_non_ph
40
+ and first_non_ph.op == "call_function"
41
+ and first_non_ph.target is torch._C._set_grad_enabled
42
+ ):
43
+ return (
44
+ first_non_ph.args[0] != torch.is_grad_enabled()
45
+ if omit_if_same_with_ambient
46
+ else True
47
+ )
48
+ return False
49
+
50
+
51
+ def _replace_with_hop(node: torch.fx.Node) -> None:
52
+ assert node.op == "call_module"
53
+ graph: torch.fx.Graph = node.graph
54
+ assert graph.owning_module is not None
55
+ gm: torch.fx.GraphModule = graph.owning_module
56
+ assert isinstance(node.target, str)
57
+ sub_gm = getattr(gm, node.target)
58
+ sub_graph = sub_gm.graph
59
+ set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node)
60
+ if len(set_grad_nodes) > 0:
61
+ assert len(set_grad_nodes) == 1
62
+ set_grad_node = set_grad_nodes[0]
63
+ _replace_with_hop_helper(node, set_grad_node, wrap_with_set_grad_enabled)
64
+ sub_graph.erase_node(set_grad_node)
65
+
66
+
67
+ def _remove_set_grad_and_inline(node: torch.fx.Node) -> None:
68
+ assert node.op == "call_module"
69
+ graph: torch.fx.Graph = node.graph
70
+ assert graph.owning_module is not None
71
+ gm: torch.fx.GraphModule = graph.owning_module
72
+ assert isinstance(node.target, str)
73
+ sub_gm = getattr(gm, node.target)
74
+ sub_graph = sub_gm.graph
75
+ nodes_map(
76
+ sub_graph.nodes,
77
+ lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n,
78
+ )
79
+ node_inline_(node)
80
+
81
+
82
+ def _sequential_split_and_maybe_inline_subgraphs(
83
+ gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None
84
+ ) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
85
+ """
86
+ Helper function for replace_set_grad_with_hop_pass().
87
+ Split the graph module into multiple subgraphs based on the set_grad_enabled nodes.
88
+ For each subgraph, decides whether to construct a HOO subgraph, or inline the calls
89
+ back into the parent graph module.
90
+ """
91
+ need_replacing = any(_is_set_grad_enabled_node(node) for node in gm.graph.nodes)
92
+ if not need_replacing:
93
+ return gm, graph_signature
94
+
95
+ # sequential_split returns a new graph module that could have different output
96
+ # args names. We need to fix the graph signature.
97
+ new_gm = sequential_split(gm, _is_set_grad_enabled_node)
98
+
99
+ def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
100
+ if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
101
+ _replace_with_hop(node)
102
+ else:
103
+ _remove_set_grad_and_inline(node)
104
+
105
+ return _sequential_split_and_maybe_inline_subgraphs_helper(
106
+ new_gm, graph_signature, _maybe_inline_or_replace_with_hop
107
+ )
108
+
109
+
110
+ def replace_set_grad_with_hop_pass(
111
+ gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None
112
+ ) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
113
+ """
114
+ Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
115
+ then recursively call itself on each of the submodules.
116
+ """
117
+ return _replace_with_hop_pass_helper(
118
+ gm,
119
+ graph_signature,
120
+ _sequential_split_and_maybe_inline_subgraphs,
121
+ )
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch._export.error import InternalError
6
+ from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
7
+ from torch._ops import HigherOrderOperator, OpOverload
8
+
9
+
10
+ __all__ = ["ReplaceViewOpsWithViewCopyOpsPass"]
11
+
12
+
13
+ _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: dict[OpOverload, OpOverload] = {
14
+ torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
15
+ }
16
+
17
+
18
+ def is_view_op(schema: torch._C.FunctionSchema) -> bool:
19
+ if len(schema.arguments) == 0:
20
+ return False
21
+ alias_info = schema.arguments[0].alias_info
22
+ return (alias_info is not None) and (not alias_info.is_write)
23
+
24
+
25
+ def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]:
26
+ if is_view_op(schema) and schema.name.startswith("aten::"):
27
+ view_op_name = schema.name.split("::")[1]
28
+ view_op_overload = (
29
+ schema.overload_name if schema.overload_name != "" else "default"
30
+ )
31
+ view_copy_op_name = view_op_name + "_copy"
32
+ if not hasattr(torch.ops.aten, view_copy_op_name):
33
+ raise InternalError(f"{schema.name} is missing a view_copy variant")
34
+
35
+ view_copy_op_overload_packet = getattr(torch.ops.aten, view_copy_op_name)
36
+
37
+ if not hasattr(view_copy_op_overload_packet, view_op_overload):
38
+ raise InternalError(f"{schema.name} is missing a view_copy variant")
39
+
40
+ return getattr(view_copy_op_overload_packet, view_op_overload)
41
+
42
+ return None
43
+
44
+
45
+ class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse):
46
+ """
47
+ Our backend expects pure functional operators. For efficiency
48
+ purposes, we keep view ops around while functionalizing the exported
49
+ program. This pass replaces view ops with view copy ops for backends that
50
+ need AOT memory planning.
51
+ """
52
+
53
+ def call_operator(self, op, args, kwargs, meta):
54
+ if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
55
+ return super().call_operator(
56
+ (_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]), args, kwargs, meta
57
+ )
58
+
59
+ if isinstance(op, HigherOrderOperator):
60
+ return super().call_operator(op, args, kwargs, meta)
61
+
62
+ if view_copy_op := get_view_copy_of_view_op(op._schema):
63
+ return super().call_operator(view_copy_op, args, kwargs, meta)
64
+
65
+ return super().call_operator(op, args, kwargs, meta)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_with_hop_pass_util.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from __future__ import annotations
3
+
4
+ import contextlib
5
+ import copy
6
+ import operator
7
+ from typing import TYPE_CHECKING
8
+
9
+ import torch
10
+
11
+ from ..utils import node_replace_, nodes_map
12
+
13
+
14
+ if TYPE_CHECKING:
15
+ from collections.abc import Callable
16
+
17
+ from torch._ops import HigherOrderOperator
18
+ from torch.export.graph_signature import ExportGraphSignature
19
+
20
+
21
+ def _replace_with_hop_helper(
22
+ node: torch.fx.Node,
23
+ enter_block_node: torch.fx.Node,
24
+ wrap_hoo: HigherOrderOperator,
25
+ ) -> None:
26
+ graph: torch.fx.Graph = node.graph
27
+ assert graph.owning_module is not None
28
+ gm: torch.fx.GraphModule = graph.owning_module
29
+ assert isinstance(node.target, str)
30
+ sub_gm = getattr(gm, node.target)
31
+
32
+ def set_hoo_node_meta(call_func_node):
33
+ call_func_node.meta["nn_module_stack"] = copy.copy(
34
+ enter_block_node.meta.get("nn_module_stack", {})
35
+ )
36
+ call_func_node.meta["torch_fn"] = (
37
+ f"{wrap_hoo.__name__}",
38
+ # pyrefly: ignore [missing-attribute]
39
+ f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}",
40
+ )
41
+ if isinstance(output_args, (tuple, list)):
42
+ call_func_node.meta["val"] = tuple(arg.meta["val"] for arg in output_args)
43
+ elif isinstance(output_args, torch.fx.Node):
44
+ call_func_node.meta["val"] = (output_args.meta["val"],)
45
+
46
+ with graph.inserting_before(node):
47
+ get_attr_node = graph.get_attr(node.target)
48
+ get_attr_node.meta["nn_module_stack"] = copy.copy(
49
+ enter_block_node.meta.get("nn_module_stack", {})
50
+ )
51
+ output_node = next(iter(reversed(sub_gm.graph.nodes)), None)
52
+ # Split_module pass intentionally doesn't add output node
53
+ # if the graph doesn't return anything.
54
+ # TODO (tmanlaibaatar) Figure out if this is right behaviour
55
+ # for split_module
56
+ if isinstance(output_node, torch.fx.Node) and output_node.op != "output":
57
+ output_node = None
58
+ if output_node is not None:
59
+ assert len(output_node.args) == 1
60
+ output_args = output_node.args[0]
61
+ enter_block_node_args = enter_block_node.args
62
+ if isinstance(output_args, (tuple, list)):
63
+ call_func_node = graph.call_function(
64
+ wrap_hoo,
65
+ (*enter_block_node_args, get_attr_node, *node.args),
66
+ {},
67
+ )
68
+ # Create the metadata
69
+ set_hoo_node_meta(call_func_node)
70
+ node_replace_(node, call_func_node)
71
+
72
+ # Rename the name of getitem nodes to the actual name of its contents
73
+ # for passing verifier and better readability, also propagate metadata
74
+ for get_item_node in call_func_node.users:
75
+ idx: int = get_item_node.args[1] # type: ignore[assignment]
76
+ output_node = output_args[idx]
77
+ get_item_node._rename(output_node.name)
78
+ get_item_node.meta = output_node.meta
79
+
80
+ elif isinstance(output_args, torch.fx.Node):
81
+ call_func_node = graph.create_node(
82
+ "call_function",
83
+ wrap_hoo,
84
+ (*enter_block_node_args, get_attr_node, *node.args),
85
+ {},
86
+ output_args.name,
87
+ )
88
+ # Modify the subgraph to output a singleton list.
89
+ output_node.args = ((output_args,),)
90
+ # Add in an extra `getitem(wrap_hoo, 0)` node to the toplevel graph.
91
+ get_item_node = graph.create_node(
92
+ "call_function",
93
+ operator.getitem,
94
+ (call_func_node, 0),
95
+ {},
96
+ )
97
+ # Create the metadata
98
+ get_item_node.meta = output_args.meta
99
+ set_hoo_node_meta(call_func_node)
100
+ node_replace_(node, get_item_node)
101
+ else:
102
+ raise NotImplementedError(
103
+ f"replace_with_hop_pass doesn't support output type {type(output_args)}"
104
+ )
105
+ else:
106
+ # TODO (shangdiy): remove this line, since the export graph can be non-functional
107
+ node.graph.erase_node(node)
108
+
109
+
110
+ def _sequential_split_and_maybe_inline_subgraphs_helper(
111
+ new_gm: torch.fx.GraphModule,
112
+ graph_signature: ExportGraphSignature | None,
113
+ maybe_inline_or_replace_with_hop: Callable[[torch.fx.Node], None],
114
+ ) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
115
+ """
116
+ Helper function for replacing graph nodse with higher order nodes.
117
+ For each subgraph in `new_gm`, decides whether to construct a HOO subgraph, or inline the calls
118
+ back into the parent graph module, depending on `maybe_inline_or_replace_with_hop`.
119
+ """
120
+ # new_gm is a new graph module that could have different output args names.
121
+ # We need to fix the graph signature.
122
+ replace_ctx = contextlib.nullcontext()
123
+ new_signature = None
124
+ if graph_signature is not None:
125
+ # Cannot deep copy a real ScriptObject, which is referenced
126
+ # in the FakeScriptObject. Copy should be good enough to guard
127
+ # against accidental mutation to original graph_signature.
128
+ new_signature = copy.copy(graph_signature)
129
+ new_gm_out_node = next(reversed(new_gm.graph.find_nodes(op="output")))
130
+ assert new_gm_out_node.op == "output" and len(new_gm_out_node.args[0]) == len(
131
+ new_signature.output_specs
132
+ )
133
+ for arg_node, out_spec in zip(
134
+ new_gm_out_node.args[0], new_signature.output_specs
135
+ ):
136
+ if arg_node is None:
137
+ assert out_spec.arg.value is None # type: ignore[union-attr]
138
+ elif (
139
+ isinstance(arg_node, torch.fx.Node)
140
+ and out_spec.arg.name != arg_node.name
141
+ ):
142
+ out_spec.arg.name = arg_node.name
143
+
144
+ replace_ctx = new_gm._set_replace_hook(new_signature.get_replace_hook()) # type: ignore[assignment]
145
+
146
+ with replace_ctx:
147
+ nodes_map(
148
+ list(new_gm.graph.nodes),
149
+ lambda node: (
150
+ maybe_inline_or_replace_with_hop(node)
151
+ if node.op == "call_module"
152
+ else node
153
+ ),
154
+ )
155
+ new_gm.recompile()
156
+ new_gm.graph.lint()
157
+ return new_gm, new_signature
158
+
159
+
160
+ def _replace_with_hop_pass_helper(
161
+ gm: torch.fx.GraphModule,
162
+ graph_signature: ExportGraphSignature | None,
163
+ sequential_split_and_maybe_inline_subgraphs: Callable[
164
+ [torch.fx.GraphModule, ExportGraphSignature | None],
165
+ tuple[torch.fx.GraphModule, ExportGraphSignature | None],
166
+ ],
167
+ ) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
168
+ """
169
+ Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
170
+ then recursively call itself on each of the submodules.
171
+ """
172
+ new_gm, new_signature = sequential_split_and_maybe_inline_subgraphs(
173
+ gm, graph_signature
174
+ )
175
+ # recursively call
176
+ for node in new_gm.graph.nodes:
177
+ if node.op == "get_attr":
178
+ subgm = getattr(new_gm, node.target)
179
+ if not isinstance(subgm, torch.fx.GraphModule):
180
+ continue
181
+ new_subgm, _ = _replace_with_hop_pass_helper(
182
+ subgm,
183
+ None,
184
+ sequential_split_and_maybe_inline_subgraphs,
185
+ )
186
+ setattr(new_gm, node.target, new_subgm)
187
+
188
+ new_gm.recompile()
189
+ new_gm.graph.lint()
190
+ return new_gm, new_signature
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__init__.py ADDED
File without changes
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/dynamic_shapes.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from typing import Any, Optional, Union
3
+
4
+ import torch
5
+ from torch._dynamo.exc import UserError, UserErrorType
6
+ from torch.export.dynamic_shapes import (
7
+ _check_dynamic_shapes,
8
+ _DerivedDim,
9
+ _DimHint,
10
+ _tree_map_with_path,
11
+ Dim,
12
+ )
13
+ from torch.utils._pytree import tree_map
14
+
15
+ from .serialize import _dataclass_to_dict
16
+
17
+
18
+ @dataclasses.dataclass
19
+ class RootDim:
20
+ """
21
+ This represents a Dim object.
22
+ """
23
+
24
+ min: int
25
+ max: Union[int, None]
26
+ derived: list[str]
27
+
28
+
29
+ @dataclasses.dataclass
30
+ class DynamicShapesSpec:
31
+ """
32
+ This stores a dynamic_shapes spec for de/serialization.
33
+ """
34
+
35
+ dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None]
36
+ dims: dict[str, RootDim]
37
+
38
+
39
+ def _postprocess_serialized_shapes(
40
+ dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
41
+ dims: dict[str, dict[str, Union[int, list[str], None]]],
42
+ to_dict: Optional[bool] = False,
43
+ ) -> Union[DynamicShapesSpec, dict[str, Any]]:
44
+ """
45
+ Sorts dims and dumps to dictionary format.
46
+ """
47
+ from torch.utils._sympy.numbers import int_oo
48
+
49
+ dims = {
50
+ k: RootDim(
51
+ min=v["min"], # type: ignore[arg-type]
52
+ max=None if v["max"] is int_oo else v["max"], # type: ignore[arg-type]
53
+ derived=sorted(v["derived"]), # type: ignore[arg-type]
54
+ )
55
+ for k, v in sorted(dims.items())
56
+ }
57
+ # pyrefly: ignore [bad-argument-type]
58
+ spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims)
59
+ if to_dict:
60
+ return _dataclass_to_dict(spec)
61
+ else:
62
+ return spec
63
+
64
+
65
+ def _dump_dynamic_shapes(
66
+ dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
67
+ args: tuple[Any],
68
+ kwargs: Optional[dict[str, Any]] = None,
69
+ to_dict: Optional[bool] = False,
70
+ ) -> Union[DynamicShapesSpec, dict[str, Any]]:
71
+ """
72
+ Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec.
73
+ Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims".
74
+ Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones).
75
+
76
+ dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export():
77
+ - Each tensor input is represented with a list of values, non-tensor inputs with None.
78
+ - dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings.
79
+ - static dimensions are represented with ints.
80
+
81
+ dims: A dictionary mapping each symbol name to the min/max range and derived dim names.
82
+
83
+ For example:
84
+ ```
85
+ dx = Dim("dx", min=4, max=16)
86
+ dy = dx + 1
87
+
88
+ inputs = (
89
+ [
90
+ torch.randn(4, 4),
91
+ torch.randn(5, 4),
92
+ ],
93
+ torch.randn(4),
94
+ torch.randn(4, 4),
95
+ "hello",
96
+ )
97
+ dynamic_shapes = {
98
+ "a": [
99
+ (dx, 4),
100
+ (dy, 4),
101
+ ],
102
+ "b": (Dim.STATIC,),
103
+ "c": None,
104
+ "d": None,
105
+ }
106
+ out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True)
107
+ ```
108
+ would generate the following output:
109
+ ```
110
+ {
111
+ "dynamic_shapes": (
112
+ [
113
+ ["dx", 4],
114
+ ["dx + 1", 4],
115
+ ],
116
+ ["_DimHint.STATIC"],
117
+ ["_DimHint.STATIC", "_DimHint.STATIC"],
118
+ None,
119
+ ),
120
+ "dims": {
121
+ "dx": {
122
+ "min": 4,
123
+ "max": 16,
124
+ "derived": ["dx + 1"],
125
+ },
126
+ },
127
+ }
128
+ ```
129
+ """
130
+ dims: dict[str, dict[str, Any]] = {}
131
+
132
+ def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def]
133
+ """
134
+ Helps standardize the dynamic_shapes tree structure we serialize,
135
+ returning lists for each tensor shape, handling tensor-level Nones.
136
+ """
137
+ if not isinstance(tensor, torch.Tensor):
138
+ return None
139
+ if shape is None:
140
+ return [Dim.STATIC] * len(tensor.shape)
141
+
142
+ out = []
143
+ if isinstance(shape, dict):
144
+ for i, s in enumerate(tensor.shape):
145
+ out.append(s if shape.get(i) is None else shape.get(i))
146
+ else:
147
+ assert isinstance(shape, (tuple, list))
148
+ for i, s in enumerate(tensor.shape):
149
+ out.append(s if shape[i] is None else shape[i])
150
+ return out
151
+
152
+ def _track_dim_from_dims(
153
+ val: Union[None, int, _DimHint, Dim],
154
+ ) -> Union[None, int, str]:
155
+ """
156
+ Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
157
+ """
158
+ if val is None or isinstance(val, int): # non-tensor input or static
159
+ return val
160
+ if isinstance(val, _DimHint): # store enum as string
161
+ return val.__class__.__name__ + "." + val.type.name
162
+
163
+ assert isinstance(val, Dim)
164
+
165
+ # track root dim
166
+ root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined]
167
+ if root.__name__ not in dims:
168
+ dims[root.__name__] = {
169
+ "min": root.min, # type: ignore[attr-defined,union-attr]
170
+ "max": root.max, # type: ignore[attr-defined,union-attr]
171
+ "derived": set(),
172
+ }
173
+
174
+ # track derived dims
175
+ if isinstance(val, _DerivedDim):
176
+ dims[root.__name__]["derived"].add(val.__name__)
177
+
178
+ return val.__name__
179
+
180
+ if dynamic_shapes is None:
181
+ return {"dynamic_shapes": None, "dims": {}}
182
+
183
+ # convert to tuple of specs, for each arg/kwarg
184
+ kwargs = kwargs or {}
185
+ if isinstance(dynamic_shapes, dict):
186
+ dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment]
187
+ # pyrefly: ignore [bad-assignment, bad-argument-type]
188
+ dynamic_shapes = tuple(dynamic_shapes)
189
+ combined_args = tuple(args) + tuple(kwargs.values())
190
+
191
+ # run same check when we're processing shapes for export - is this too lazy?
192
+ _check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes) # type: ignore[arg-type]
193
+
194
+ tree_shapes = _tree_map_with_path(
195
+ _standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs"
196
+ )
197
+ serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes)
198
+ return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict)
199
+
200
+
201
+ def _load_dynamic_shapes(
202
+ spec: Union[DynamicShapesSpec, dict[str, Any]],
203
+ from_dict: Optional[bool] = False,
204
+ ) -> Union[dict[str, Any], tuple[Any], list[Any], None]:
205
+ """
206
+ Utility function for dynamic shapes serialization.
207
+ Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export().
208
+ """
209
+ import sympy
210
+
211
+ from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence
212
+
213
+ if from_dict:
214
+ if not isinstance(spec, dict):
215
+ raise UserError(
216
+ UserErrorType.INVALID_INPUT,
217
+ f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}",
218
+ )
219
+ if sorted(spec.keys()) != ["dims", "dynamic_shapes"]:
220
+ raise UserError(
221
+ UserErrorType.INVALID_INPUT,
222
+ "With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, "
223
+ f"instead found {spec.keys()}",
224
+ )
225
+ dims = {}
226
+ for k, v in spec["dims"].items():
227
+ if not isinstance(k, str):
228
+ raise UserError(
229
+ UserErrorType.INVALID_INPUT,
230
+ f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}",
231
+ )
232
+ if sorted(v.keys()) != ["derived", "max", "min"]:
233
+ raise UserError(
234
+ UserErrorType.INVALID_INPUT,
235
+ f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, "
236
+ f"instead found {v.keys()}",
237
+ )
238
+ if not isinstance(v["min"], int):
239
+ raise UserError(
240
+ UserErrorType.INVALID_INPUT,
241
+ f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}",
242
+ )
243
+ if not isinstance(v["max"], int) or v["max"] is None:
244
+ raise UserError(
245
+ UserErrorType.INVALID_INPUT,
246
+ f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}",
247
+ )
248
+ if not isinstance(v["derived"], list) or any(
249
+ not isinstance(d, str) for d in v["derived"]
250
+ ):
251
+ raise UserError(
252
+ UserErrorType.INVALID_INPUT,
253
+ "Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, "
254
+ f"got {k}: {v['derived']}",
255
+ )
256
+ dims[k] = RootDim(**v)
257
+ dynamic_shapes = spec["dynamic_shapes"]
258
+ else:
259
+ if not isinstance(spec, DynamicShapesSpec):
260
+ raise UserError(
261
+ UserErrorType.INVALID_INPUT,
262
+ f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}",
263
+ )
264
+ dims = spec.dims
265
+ dynamic_shapes = spec.dynamic_shapes
266
+
267
+ if dynamic_shapes is None:
268
+ return None
269
+
270
+ dim_cache = {}
271
+ for name, info in dims.items():
272
+ symbol = sympy.sympify(name)
273
+ if not isinstance(symbol, sympy.Symbol):
274
+ raise UserError(
275
+ UserErrorType.INVALID_INPUT,
276
+ f"Expected `spec['dims']` keys to be symbols, got {name}",
277
+ )
278
+ dim_cache[name] = Dim(name, min=info.min, max=info.max) # cache root dim
279
+ for _expr in info.derived:
280
+ expr = sympy.sympify(_expr)
281
+ if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols:
282
+ raise UserError(
283
+ UserErrorType.INVALID_INPUT,
284
+ f"Expected derived expressions in to have {name} as the only free symbol, got {expr}",
285
+ )
286
+ if not _is_supported_equivalence(expr):
287
+ raise UserError(
288
+ UserErrorType.INVALID_INPUT,
289
+ f"Expected derived expressions to be linear expressions, got {expr}",
290
+ )
291
+ modulus, remainder = sympy.polys.polytools.div(expr, symbol)
292
+ ddim = dim_cache[name]
293
+ if modulus != 1:
294
+ ddim = int(modulus) * ddim # type: ignore[assignment, operator]
295
+ if remainder != 0:
296
+ ddim = ddim + int(remainder) # type: ignore[assignment, operator]
297
+ dim_cache[_expr] = ddim # cache derived dims
298
+
299
+ def deserialize_shape(
300
+ val: Union[None, int, str],
301
+ ) -> Union[None, int, Dim, _DimHint]:
302
+ if val is None or isinstance(val, int):
303
+ return val
304
+ elif val == "_DimHint.AUTO":
305
+ return _DimHint.AUTO()
306
+ elif val == "_DimHint.DYNAMIC":
307
+ return _DimHint.DYNAMIC()
308
+ elif val == "_DimHint.STATIC":
309
+ return _DimHint.STATIC()
310
+ if not isinstance(val, str):
311
+ raise UserError(
312
+ UserErrorType.INVALID_INPUT,
313
+ "Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, "
314
+ f" or derived expressions, got {val}",
315
+ )
316
+ if val not in dim_cache:
317
+ raise UserError(
318
+ UserErrorType.INVALID_INPUT,
319
+ "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, "
320
+ f"got {val} which is not in {dims.keys()}",
321
+ )
322
+ return dim_cache[val] # type: ignore[return-value]
323
+
324
+ return tree_map(deserialize_shape, dynamic_shapes)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/export_schema.thrift ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // @generated by update_schema.py
2
+ // checksum<<0e870e558fb4362f69b825842ab606cf0becd10a008003ac676156becf20b65b>>
3
+
4
+ namespace py3 torch._export
5
+ namespace cpp2 torch._export.schema
6
+
7
+ enum ArgumentKind {
8
+ UNKNOWN = 0,
9
+ POSITIONAL = 1,
10
+ KEYWORD = 2,
11
+ }
12
+
13
+
14
+ enum Layout {
15
+ Unknown = 0,
16
+ SparseCoo = 1,
17
+ SparseCsr = 2,
18
+ SparseCsc = 3,
19
+ SparseBsr = 4,
20
+ SparseBsc = 5,
21
+ _mkldnn = 6,
22
+ Strided = 7,
23
+ }
24
+
25
+
26
+ enum MemoryFormat {
27
+ Unknown = 0,
28
+ ContiguousFormat = 1,
29
+ ChannelsLast = 2,
30
+ ChannelsLast3d = 3,
31
+ PreserveFormat = 4,
32
+ }
33
+
34
+
35
+ enum ScalarType {
36
+ UNKNOWN = 0,
37
+ BYTE = 1,
38
+ CHAR = 2,
39
+ SHORT = 3,
40
+ INT = 4,
41
+ LONG = 5,
42
+ HALF = 6,
43
+ FLOAT = 7,
44
+ DOUBLE = 8,
45
+ COMPLEXHALF = 9,
46
+ COMPLEXFLOAT = 10,
47
+ COMPLEXDOUBLE = 11,
48
+ BOOL = 12,
49
+ BFLOAT16 = 13,
50
+ UINT16 = 28,
51
+ FLOAT8E4M3FN = 29,
52
+ FLOAT8E5M2 = 30,
53
+ FLOAT8E4M3FNUZ = 31,
54
+ FLOAT8E5M2FNUZ = 32,
55
+ }
56
+
57
+
58
+ struct Device {
59
+ 10: string type;
60
+ 20: optional i64 index;
61
+ }
62
+
63
+ union SymExprHint {
64
+ 10: i64 as_int;
65
+ 20: bool as_bool;
66
+ 30: double as_float;
67
+ }
68
+
69
+ struct SymExpr {
70
+ 10: string expr_str;
71
+ 20: optional SymExprHint hint;
72
+ }
73
+
74
+ union SymInt {
75
+ 10: SymExpr as_expr;
76
+ 20: i64 as_int;
77
+ }
78
+
79
+ union SymFloat {
80
+ 10: SymExpr as_expr;
81
+ 20: double as_float;
82
+ }
83
+
84
+ union SymBool {
85
+ 10: SymExpr as_expr;
86
+ 20: bool as_bool;
87
+ }
88
+
89
+ struct TensorMeta {
90
+ 10: ScalarType dtype;
91
+ 20: list<SymInt> sizes;
92
+ 30: bool requires_grad;
93
+ 40: Device device;
94
+ 50: list<SymInt> strides;
95
+ 60: SymInt storage_offset;
96
+ 70: Layout layout;
97
+ }
98
+
99
+ union SymIntArgument {
100
+ 10: string as_name;
101
+ 20: i64 as_int;
102
+ }
103
+
104
+ union SymFloatArgument {
105
+ 10: string as_name;
106
+ 20: double as_float;
107
+ }
108
+
109
+ union SymBoolArgument {
110
+ 10: string as_name;
111
+ 20: bool as_bool;
112
+ }
113
+
114
+ struct TensorArgument {
115
+ 10: string name;
116
+ }
117
+
118
+ struct TokenArgument {
119
+ 10: string name;
120
+ }
121
+
122
+ union OptionalTensorArgument {
123
+ 20: TensorArgument as_tensor;
124
+ 10: bool as_none;
125
+ }
126
+
127
+ struct GraphArgument {
128
+ 10: string name;
129
+ 20: Graph graph;
130
+ }
131
+
132
+ struct CustomObjArgument {
133
+ 10: string name;
134
+ 20: string class_fqn;
135
+ }
136
+
137
+ struct ComplexValue {
138
+ 10: double real;
139
+ 20: double imag;
140
+ }
141
+
142
+ union Argument {
143
+ 10: bool as_none;
144
+ 20: TensorArgument as_tensor;
145
+ 30: list<TensorArgument> as_tensors;
146
+ 50: i64 as_int;
147
+ 70: list<i64> as_ints;
148
+ 80: double as_float;
149
+ 90: list<double> as_floats;
150
+ 100: string as_string;
151
+ 101: list<string> as_strings;
152
+ 110: SymIntArgument as_sym_int;
153
+ 120: list<SymIntArgument> as_sym_ints;
154
+ 130: ScalarType as_scalar_type;
155
+ 140: MemoryFormat as_memory_format;
156
+ 150: Layout as_layout;
157
+ 160: Device as_device;
158
+ 170: bool as_bool;
159
+ 180: list<bool> as_bools;
160
+ 182: SymBoolArgument as_sym_bool;
161
+ 184: list<SymBoolArgument> as_sym_bools;
162
+ 200: GraphArgument as_graph;
163
+ 190: list<OptionalTensorArgument> as_optional_tensors;
164
+ 210: CustomObjArgument as_custom_obj;
165
+ 220: string as_operator;
166
+ 230: SymFloatArgument as_sym_float;
167
+ 240: list<SymFloatArgument> as_sym_floats;
168
+ 250: OptionalTensorArgument as_optional_tensor;
169
+ 260: ComplexValue as_complex;
170
+ 280: list<list<i64>> as_int_lists;
171
+ 290: map<string, Argument> as_string_to_argument;
172
+ }
173
+
174
+ struct NamedArgument {
175
+ 10: string name;
176
+ 20: Argument arg;
177
+ 30: optional ArgumentKind kind;
178
+ }
179
+
180
+ struct Node {
181
+ 10: string target;
182
+ 20: list<NamedArgument> inputs;
183
+ 30: list<Argument> outputs;
184
+ 40: map<string, string> metadata;
185
+ 50: optional bool is_hop_single_tensor_return;
186
+ }
187
+
188
+ struct Graph {
189
+ 10: list<Argument> inputs;
190
+ 20: list<Argument> outputs;
191
+ 30: list<Node> nodes;
192
+ 40: map<string, TensorMeta> tensor_values;
193
+ 50: map<string, SymInt> sym_int_values;
194
+ 60: map<string, SymBool> sym_bool_values;
195
+ 70: bool is_single_tensor_return;
196
+ 80: map<string, CustomObjArgument> custom_obj_values;
197
+ 90: map<string, SymFloat> sym_float_values;
198
+ }
199
+
200
+ struct UserInputSpec {
201
+ 10: Argument arg;
202
+ }
203
+
204
+ union ConstantValue {
205
+ 10: bool as_none;
206
+ 20: i64 as_int;
207
+ 30: double as_float;
208
+ 40: string as_string;
209
+ 50: bool as_bool;
210
+ }
211
+
212
+ struct InputToConstantInputSpec {
213
+ 10: string name;
214
+ 20: ConstantValue value;
215
+ }
216
+
217
+ struct InputToParameterSpec {
218
+ 10: TensorArgument arg;
219
+ 20: string parameter_name;
220
+ }
221
+
222
+ struct InputToBufferSpec {
223
+ 10: TensorArgument arg;
224
+ 20: string buffer_name;
225
+ 30: bool persistent;
226
+ }
227
+
228
+ struct InputToTensorConstantSpec {
229
+ 10: TensorArgument arg;
230
+ 20: string tensor_constant_name;
231
+ }
232
+
233
+ struct InputToCustomObjSpec {
234
+ 10: CustomObjArgument arg;
235
+ 20: string custom_obj_name;
236
+ }
237
+
238
+ struct InputTokenSpec {
239
+ 10: TokenArgument arg;
240
+ }
241
+
242
+ union InputSpec {
243
+ 10: UserInputSpec user_input;
244
+ 20: InputToParameterSpec parameter;
245
+ 30: InputToBufferSpec buffer;
246
+ 40: InputToTensorConstantSpec tensor_constant;
247
+ 50: InputToCustomObjSpec custom_obj;
248
+ 70: InputTokenSpec token;
249
+ 60: InputToConstantInputSpec constant_input;
250
+ }
251
+
252
+ struct UserOutputSpec {
253
+ 10: Argument arg;
254
+ }
255
+
256
+ struct LossOutputSpec {
257
+ 10: TensorArgument arg;
258
+ }
259
+
260
+ struct BufferMutationSpec {
261
+ 10: TensorArgument arg;
262
+ 20: string buffer_name;
263
+ }
264
+
265
+ struct ParameterMutationSpec {
266
+ 10: TensorArgument arg;
267
+ 20: string parameter_name;
268
+ }
269
+
270
+ struct GradientToParameterSpec {
271
+ 10: TensorArgument arg;
272
+ 20: string parameter_name;
273
+ }
274
+
275
+ struct GradientToUserInputSpec {
276
+ 10: TensorArgument arg;
277
+ 20: string user_input_name;
278
+ }
279
+
280
+ struct UserInputMutationSpec {
281
+ 10: TensorArgument arg;
282
+ 20: string user_input_name;
283
+ }
284
+
285
+ struct OutputTokenSpec {
286
+ 10: TokenArgument arg;
287
+ }
288
+
289
+ union OutputSpec {
290
+ 10: UserOutputSpec user_output;
291
+ 20: LossOutputSpec loss_output;
292
+ 30: BufferMutationSpec buffer_mutation;
293
+ 40: GradientToParameterSpec gradient_to_parameter;
294
+ 50: GradientToUserInputSpec gradient_to_user_input;
295
+ 60: UserInputMutationSpec user_input_mutation;
296
+ 70: OutputTokenSpec token;
297
+ 80: ParameterMutationSpec parameter_mutation;
298
+ }
299
+
300
+ struct GraphSignature {
301
+ 10: list<InputSpec> input_specs;
302
+ 20: list<OutputSpec> output_specs;
303
+ }
304
+
305
+ struct RangeConstraint {
306
+ 10: optional i64 min_val;
307
+ 20: optional i64 max_val;
308
+ }
309
+
310
+ struct ModuleCallSignature {
311
+ 10: list<Argument> inputs;
312
+ 20: list<Argument> outputs;
313
+ 30: string in_spec;
314
+ 40: string out_spec;
315
+ 50: optional list<string> forward_arg_names;
316
+ }
317
+
318
+ struct ModuleCallEntry {
319
+ 10: string fqn;
320
+ 30: optional ModuleCallSignature signature;
321
+ }
322
+
323
+ struct NamedTupleDef {
324
+ 10: list<string> field_names;
325
+ }
326
+
327
+ struct GraphModule {
328
+ 10: Graph graph;
329
+ 50: GraphSignature signature;
330
+ 60: list<ModuleCallEntry> module_call_graph;
331
+ 40: map<string, string> metadata;
332
+ 70: map<string, NamedTupleDef> treespec_namedtuple_fields;
333
+ }
334
+
335
+ struct SchemaVersion {
336
+ 10: i64 major;
337
+ 20: i64 minor;
338
+ }
339
+
340
+ struct ExportedProgram {
341
+ 10: GraphModule graph_module;
342
+ 20: map<string, i64> opset_version;
343
+ 30: map<string, RangeConstraint> range_constraints;
344
+ 60: SchemaVersion schema_version;
345
+ 70: list<string> verifiers;
346
+ 80: string torch_version;
347
+ 90: list<string> guards_code;
348
+ }
349
+
350
+ struct PayloadMeta {
351
+ 10: string path_name;
352
+ 20: bool is_param;
353
+ 30: bool use_pickle;
354
+ 40: optional TensorMeta tensor_meta;
355
+ }
356
+
357
+ struct PayloadConfig {
358
+ 10: map<string, PayloadMeta> config;
359
+ }
360
+
361
+ struct AOTInductorModelPickleData {
362
+ 1: string library_basename;
363
+ 2: list<string> input_names;
364
+ 3: list<string> output_names;
365
+ 4: optional i64 floating_point_input_dtype;
366
+ 5: optional i64 floating_point_output_dtype;
367
+ 6: optional bool aot_inductor_model_is_cpu;
368
+ }
369
+
370
+ struct ExternKernelNode {
371
+ 10: string name;
372
+ 20: Node node;
373
+ }
374
+
375
+ struct ExternKernelNodes {
376
+ 10: list<ExternKernelNode> nodes;
377
+ }