Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/cudagraphs.py +299 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/distributed.py +621 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/onnxrt.py +39 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/registry.py +179 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/torchxla.py +55 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/config.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/converter.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/error.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/pass_base.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/tools.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/utils.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/verifier.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/wrappers.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__init__.py +5 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/case.py +175 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/gen_example.py +21 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/logging.py +47 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__init__.py +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/node_metadata.py +32 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/proxy_value.py +45 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__init__.py +1 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/__init__.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-312.pyc +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/_node_metadata_hook.py +111 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +254 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/collect_tracepoints_pass.py +146 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/constant_folding.py +304 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py +99 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/insert_custom_op_guards.py +80 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/lift_constants_pass.py +417 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/remove_runtime_assertions.py +36 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py +189 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py +676 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py +121 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py +65 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_with_hop_pass_util.py +190 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__init__.py +0 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/dynamic_shapes.py +324 -0
- Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/export_schema.thrift +377 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/cudagraphs.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module implements CUDA graphs support for TorchDynamo backends.
|
| 3 |
+
|
| 4 |
+
CUDA graphs allow for capturing and replaying GPU operations, which can significantly
|
| 5 |
+
reduce CPU overhead in GPU-accelerated PyTorch models. This module provides:
|
| 6 |
+
|
| 7 |
+
- CUDA graph creation and management for both forward and backward passes
|
| 8 |
+
- Input mutation detection and handling
|
| 9 |
+
- Device compatibility checking
|
| 10 |
+
- Stack trace management for debugging
|
| 11 |
+
- Integration with TorchInductor's cudagraph trees
|
| 12 |
+
|
| 13 |
+
The backend supports two main modes:
|
| 14 |
+
1. cudagraphs: Full CUDA graph support with both forward and backward pass optimization
|
| 15 |
+
2. cudagraphs_inner: Lower-level CUDA graph implementation used for benchmarking
|
| 16 |
+
|
| 17 |
+
Key components:
|
| 18 |
+
- CudagraphsBackend: Main backend class for CUDA graph integration
|
| 19 |
+
- Mutation detection utilities to ensure graph safety
|
| 20 |
+
- Device mapping and compatibility checks
|
| 21 |
+
- Stack trace collection for debugging
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import functools
|
| 25 |
+
from collections import defaultdict
|
| 26 |
+
from collections.abc import Callable, Sequence
|
| 27 |
+
from typing import Any, Optional
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.fx
|
| 31 |
+
from torch._dynamo import config
|
| 32 |
+
from torch._dynamo.backends.common import aot_autograd
|
| 33 |
+
from torch._dynamo.backends.debugging import boxed_nop
|
| 34 |
+
from torch._inductor.cudagraph_utils import (
|
| 35 |
+
BoxedDeviceIndex,
|
| 36 |
+
check_multiple_devices_or_any_cpu_nodes,
|
| 37 |
+
format_default_skip_message,
|
| 38 |
+
get_mutation_stack_trace,
|
| 39 |
+
get_placeholder_info,
|
| 40 |
+
log_cudagraph_skip_and_bump_counter,
|
| 41 |
+
)
|
| 42 |
+
from torch._inductor.utils import (
|
| 43 |
+
BoxedBool,
|
| 44 |
+
count_tangents,
|
| 45 |
+
get_first_incompatible_cudagraph_node,
|
| 46 |
+
num_fw_fixed_arguments,
|
| 47 |
+
output_node,
|
| 48 |
+
)
|
| 49 |
+
from torch.multiprocessing.reductions import StorageWeakRef
|
| 50 |
+
|
| 51 |
+
from .registry import register_backend
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def find_input_mutations(g: torch.fx.Graph) -> set[int]:
|
| 55 |
+
def meta_fk(meta: dict[str, Any]) -> Any:
|
| 56 |
+
return meta["val"] if "val" in meta else meta["fake_result"]
|
| 57 |
+
|
| 58 |
+
inputs = defaultdict(set)
|
| 59 |
+
input_idx = 0
|
| 60 |
+
mutated_inputs = set()
|
| 61 |
+
for n in g.nodes:
|
| 62 |
+
if n.op == "placeholder":
|
| 63 |
+
if isinstance(meta_fk(n.meta), torch.Tensor):
|
| 64 |
+
inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
|
| 65 |
+
input_idx += 1
|
| 66 |
+
elif n.op == "call_function":
|
| 67 |
+
if not hasattr(n.target, "_schema"):
|
| 68 |
+
continue
|
| 69 |
+
|
| 70 |
+
schema = n.target._schema
|
| 71 |
+
for i, arg in enumerate(schema.arguments):
|
| 72 |
+
if i < len(n.args):
|
| 73 |
+
argument = n.args[i]
|
| 74 |
+
else:
|
| 75 |
+
if arg.name not in n.kwargs:
|
| 76 |
+
continue
|
| 77 |
+
argument = n.kwargs[arg.name]
|
| 78 |
+
mut_arg = False
|
| 79 |
+
if arg.alias_info:
|
| 80 |
+
if arg.alias_info.is_write:
|
| 81 |
+
mut_arg = True
|
| 82 |
+
if mut_arg:
|
| 83 |
+
# TODO: not correct for args that contain tensors in a struct
|
| 84 |
+
# like list
|
| 85 |
+
mutated_inputs |= inputs[
|
| 86 |
+
StorageWeakRef(meta_fk(argument.meta)._typed_storage())
|
| 87 |
+
]
|
| 88 |
+
|
| 89 |
+
# TODO: error on unrecognized nodes
|
| 90 |
+
return mutated_inputs
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_device_node_mapping(
|
| 94 |
+
gm: torch.fx.GraphModule,
|
| 95 |
+
) -> dict[torch.device, torch.fx.Node]:
|
| 96 |
+
device_node_mapping: dict[torch.device, torch.fx.Node] = {}
|
| 97 |
+
for n in gm.graph.nodes:
|
| 98 |
+
t = n.meta.get("val", None)
|
| 99 |
+
if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
|
| 100 |
+
device_node_mapping[t.device] = n
|
| 101 |
+
return device_node_mapping
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def check_for_mutation_ignore_cuda_graph_managed_tensor(
|
| 105 |
+
aot_model: torch.fx.GraphModule, num_fixed: int
|
| 106 |
+
) -> Optional[str]:
|
| 107 |
+
mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
|
| 108 |
+
if not mutation_indices:
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
placeholders = get_placeholder_info(aot_model.graph)
|
| 112 |
+
return get_mutation_stack_trace(placeholders, mutation_indices)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed: int) -> Optional[str]:
|
| 116 |
+
if not config.cudagraph_backend_support_input_mutation:
|
| 117 |
+
if mut_skip := check_for_mutation_ignore_cuda_graph_managed_tensor(
|
| 118 |
+
aot_model, num_fixed
|
| 119 |
+
):
|
| 120 |
+
return mut_skip
|
| 121 |
+
|
| 122 |
+
if skip := check_multiple_devices_or_any_cpu_nodes(
|
| 123 |
+
get_device_node_mapping(aot_model)
|
| 124 |
+
):
|
| 125 |
+
return skip
|
| 126 |
+
|
| 127 |
+
if node := get_first_incompatible_cudagraph_node(aot_model):
|
| 128 |
+
return format_default_skip_message(f"incompatible op ({node.name})")
|
| 129 |
+
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def get_device_index(gm: torch.fx.GraphModule) -> int:
|
| 134 |
+
device = next(iter(get_device_node_mapping(gm)))
|
| 135 |
+
assert device.type == "cuda"
|
| 136 |
+
return device.index
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_stack_traces(gm: torch.fx.GraphModule) -> list[Optional[str]]:
|
| 140 |
+
output = output_node(gm)
|
| 141 |
+
assert len(output.args) == 1
|
| 142 |
+
args = output.args[0]
|
| 143 |
+
if not hasattr(args, "__iter__"):
|
| 144 |
+
return []
|
| 145 |
+
return [
|
| 146 |
+
(arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
|
| 147 |
+
for arg in args # type: ignore[union-attr]
|
| 148 |
+
]
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def cudagraphs(dynamo_model: torch.fx.GraphModule, dynamo_inputs: Sequence[Any]) -> Any:
|
| 152 |
+
from torch._inductor.cudagraph_trees import cudagraphify_impl
|
| 153 |
+
|
| 154 |
+
do_cudagraphs = BoxedBool(True)
|
| 155 |
+
boxed_device_index = BoxedDeviceIndex(None)
|
| 156 |
+
|
| 157 |
+
def forward_cudagraphs(
|
| 158 |
+
aot_model: torch.fx.GraphModule,
|
| 159 |
+
aot_inputs: list[Any],
|
| 160 |
+
is_inference: bool = False,
|
| 161 |
+
) -> Any:
|
| 162 |
+
interp = boxed_nop(aot_model, aot_inputs)
|
| 163 |
+
fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
|
| 164 |
+
if skip_msg := check_for_skip(aot_model, fixed):
|
| 165 |
+
BoxedBool.disable(do_cudagraphs)
|
| 166 |
+
log_cudagraph_skip_and_bump_counter(
|
| 167 |
+
f"skipping cudagraphs due to {skip_msg}"
|
| 168 |
+
)
|
| 169 |
+
return interp
|
| 170 |
+
|
| 171 |
+
boxed_device_index.set(get_device_index(aot_model))
|
| 172 |
+
out = cudagraphify_impl(
|
| 173 |
+
interp,
|
| 174 |
+
aot_inputs,
|
| 175 |
+
range(fixed),
|
| 176 |
+
device_index=boxed_device_index.value,
|
| 177 |
+
is_backward=False,
|
| 178 |
+
is_inference=False, # Q: should forward is_inference here?
|
| 179 |
+
stack_traces=get_stack_traces(aot_model),
|
| 180 |
+
placeholders=get_placeholder_info(aot_model.graph),
|
| 181 |
+
mutated_input_idxs=find_input_mutations(aot_model.graph),
|
| 182 |
+
)
|
| 183 |
+
out._boxed_call = True # type: ignore[attr-defined]
|
| 184 |
+
return out
|
| 185 |
+
|
| 186 |
+
def backward_cudagraphs(
|
| 187 |
+
aot_model: torch.fx.GraphModule, aot_inputs: list[Any]
|
| 188 |
+
) -> Any:
|
| 189 |
+
interp = boxed_nop(aot_model, aot_inputs)
|
| 190 |
+
if not do_cudagraphs:
|
| 191 |
+
return aot_model
|
| 192 |
+
|
| 193 |
+
fixed = count_tangents(aot_model)
|
| 194 |
+
if skip_msg := check_for_skip(aot_model, fixed):
|
| 195 |
+
log_cudagraph_skip_and_bump_counter(
|
| 196 |
+
f"skipping cudagraphs due to {skip_msg}"
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# See [Backward Generation Handling]
|
| 200 |
+
device_idx = boxed_device_index.value
|
| 201 |
+
if device_idx is None:
|
| 202 |
+
device_idx = 0 # Default to device 0 if not set
|
| 203 |
+
manager = torch._inductor.cudagraph_trees.get_manager(
|
| 204 |
+
device_idx, create_if_none_exists=False
|
| 205 |
+
)
|
| 206 |
+
assert manager is not None
|
| 207 |
+
|
| 208 |
+
def fn(inputs: list[Any]) -> Any:
|
| 209 |
+
# pyrefly: ignore [missing-attribute]
|
| 210 |
+
manager.set_to_running_backward()
|
| 211 |
+
return aot_model(inputs)
|
| 212 |
+
|
| 213 |
+
fn._boxed_call = True # type: ignore[attr-defined]
|
| 214 |
+
return fn
|
| 215 |
+
|
| 216 |
+
out = cudagraphify_impl(
|
| 217 |
+
interp,
|
| 218 |
+
aot_inputs,
|
| 219 |
+
range(fixed),
|
| 220 |
+
device_index=get_device_index(aot_model),
|
| 221 |
+
is_backward=True,
|
| 222 |
+
is_inference=False,
|
| 223 |
+
stack_traces=get_stack_traces(aot_model),
|
| 224 |
+
placeholders=get_placeholder_info(aot_model.graph),
|
| 225 |
+
mutated_input_idxs=find_input_mutations(aot_model.graph),
|
| 226 |
+
)
|
| 227 |
+
out._boxed_call = True # type: ignore[attr-defined]
|
| 228 |
+
return out
|
| 229 |
+
|
| 230 |
+
aot_cudagraphs = aot_autograd(
|
| 231 |
+
fw_compiler=forward_cudagraphs,
|
| 232 |
+
bw_compiler=backward_cudagraphs,
|
| 233 |
+
inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
|
| 234 |
+
keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
|
| 235 |
+
)
|
| 236 |
+
return aot_cudagraphs(dynamo_model, dynamo_inputs)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class CudagraphsBackend:
|
| 240 |
+
compiler_name = "cudagraphs"
|
| 241 |
+
|
| 242 |
+
@staticmethod
|
| 243 |
+
def reset() -> None:
|
| 244 |
+
from torch._inductor.cudagraph_trees import reset_cudagraph_trees
|
| 245 |
+
|
| 246 |
+
reset_cudagraph_trees()
|
| 247 |
+
|
| 248 |
+
@staticmethod
|
| 249 |
+
def __call__(model: torch.fx.GraphModule, inputs: Sequence[Any]) -> Any:
|
| 250 |
+
return cudagraphs(model, inputs)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
|
| 254 |
+
# for debugging and can serve as a perf baseline.
|
| 255 |
+
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def cudagraphs_inner(
|
| 259 |
+
model: Callable[..., Any],
|
| 260 |
+
inputs: Sequence[Any],
|
| 261 |
+
copy_outputs: bool = True,
|
| 262 |
+
copy_inputs: bool = True,
|
| 263 |
+
) -> Callable[..., Sequence[Any]]:
|
| 264 |
+
"""This isn't registered as a backend, but is used in some benchmarks"""
|
| 265 |
+
assert isinstance(inputs, (list, tuple))
|
| 266 |
+
if copy_inputs:
|
| 267 |
+
static_inputs = [torch.zeros_like(x) for x in inputs]
|
| 268 |
+
else:
|
| 269 |
+
static_inputs = list(inputs)
|
| 270 |
+
|
| 271 |
+
# warmup
|
| 272 |
+
torch.cuda.synchronize()
|
| 273 |
+
stream = torch.cuda.Stream()
|
| 274 |
+
stream.wait_stream(torch.cuda.current_stream())
|
| 275 |
+
with torch.cuda.stream(stream):
|
| 276 |
+
model(*inputs)
|
| 277 |
+
stream.synchronize()
|
| 278 |
+
torch.cuda.current_stream().wait_stream(stream)
|
| 279 |
+
torch.cuda.synchronize()
|
| 280 |
+
|
| 281 |
+
# record
|
| 282 |
+
graph = torch.cuda.CUDAGraph()
|
| 283 |
+
with torch.cuda.graph(graph, stream=stream):
|
| 284 |
+
static_outputs = model(*static_inputs)
|
| 285 |
+
if not isinstance(static_outputs, (list, tuple)):
|
| 286 |
+
static_outputs = (static_outputs,)
|
| 287 |
+
|
| 288 |
+
def run(*new_inputs: Any) -> Sequence[Any]:
|
| 289 |
+
assert len(static_inputs) == len(new_inputs)
|
| 290 |
+
if copy_inputs:
|
| 291 |
+
for dst, src in zip(static_inputs, new_inputs):
|
| 292 |
+
dst.copy_(src)
|
| 293 |
+
graph.replay()
|
| 294 |
+
if copy_outputs:
|
| 295 |
+
return [x.clone() for x in static_outputs]
|
| 296 |
+
else:
|
| 297 |
+
return static_outputs
|
| 298 |
+
|
| 299 |
+
return run
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/distributed.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module implements distributed training optimizations for TorchDynamo backends.
|
| 3 |
+
|
| 4 |
+
It provides functionality to optimize models wrapped in DistributedDataParallel (DDP)
|
| 5 |
+
by intelligently splitting compiled graphs to align with DDP's gradient synchronization
|
| 6 |
+
boundaries. Key features include:
|
| 7 |
+
|
| 8 |
+
- Graph partitioning based on parameter bucket sizes
|
| 9 |
+
- Optimization of allreduce operations for distributed training
|
| 10 |
+
- Support for parameter ignoring and buffer handling
|
| 11 |
+
- Submodule compilation and management
|
| 12 |
+
- Debugging utilities for distributed training
|
| 13 |
+
|
| 14 |
+
The main component is the DDPOptimizer class, which handles graph splitting and
|
| 15 |
+
recompilation to enable efficient distributed training while maintaining the benefits
|
| 16 |
+
of compilation.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
import traceback
|
| 21 |
+
from collections.abc import Callable
|
| 22 |
+
from dataclasses import dataclass, field
|
| 23 |
+
from typing import Any, Optional, TYPE_CHECKING
|
| 24 |
+
from unittest import mock
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
from torch import fx
|
| 28 |
+
from torch._dynamo.backends.registry import CompiledFn, CompilerFn
|
| 29 |
+
from torch._dynamo.output_graph import GraphCompileReason
|
| 30 |
+
from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
|
| 31 |
+
from torch._logging import trace_structured
|
| 32 |
+
from torch.fx.node import Node
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if TYPE_CHECKING:
|
| 36 |
+
from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Regular log messages should go through 'log'.
|
| 40 |
+
# ddp_graph_log is a separate artifact logger reserved for dumping graphs.
|
| 41 |
+
# See docs/source/logging.rst for more info.
|
| 42 |
+
log = logging.getLogger(__name__)
|
| 43 |
+
ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def args_str(args: Any) -> str:
|
| 47 |
+
# a debug helper
|
| 48 |
+
if torch.is_tensor(args):
|
| 49 |
+
return f"T[{args.shape}]"
|
| 50 |
+
elif isinstance(args, tuple):
|
| 51 |
+
return f"tuple({', '.join([args_str(x) for x in args])})"
|
| 52 |
+
elif isinstance(args, list):
|
| 53 |
+
return f"list({', '.join([args_str(x) for x in args])})"
|
| 54 |
+
else:
|
| 55 |
+
return str(args)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
|
| 59 |
+
class Bucket:
|
| 60 |
+
size: int = 0
|
| 61 |
+
params: list[str] = field(default_factory=list)
|
| 62 |
+
nodes: list[fx.Node] = field(default_factory=list)
|
| 63 |
+
|
| 64 |
+
# param_ids is just used for unit testing
|
| 65 |
+
param_ids: list[int] = field(default_factory=list)
|
| 66 |
+
|
| 67 |
+
# keep track of any buckets that were extended for logging purposes
|
| 68 |
+
opcount_increased_to_capture_external_output: int = 0
|
| 69 |
+
paramsize_before_opcount_increase: int = 0
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def bucket_has_external_output(bucket: Bucket) -> bool:
|
| 73 |
+
nodes_in_bucket = set()
|
| 74 |
+
# we want to iterate in reverse order, but clumsi-luckily the bucket.nodes list was already created backwards
|
| 75 |
+
# so we don't reverse it here
|
| 76 |
+
for node in bucket.nodes:
|
| 77 |
+
# assume node.op != output, since those are filtered in the original iteration
|
| 78 |
+
nodes_in_bucket.add(node)
|
| 79 |
+
for user in node.users:
|
| 80 |
+
if user not in nodes_in_bucket:
|
| 81 |
+
return True
|
| 82 |
+
return False
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int) -> None:
|
| 86 |
+
headers = ("Index", "Size (b)", "Param Names")
|
| 87 |
+
rows: list[tuple[Optional[int], Optional[int], str]] = []
|
| 88 |
+
extended_buckets = []
|
| 89 |
+
for idx, bucket in enumerate(reversed(buckets)):
|
| 90 |
+
if len(bucket.params) > 0:
|
| 91 |
+
rows.append((idx, bucket.size, bucket.params[0]))
|
| 92 |
+
rows.extend((None, None, param) for param in bucket.params[1:])
|
| 93 |
+
if bucket.opcount_increased_to_capture_external_output > 0:
|
| 94 |
+
extended_buckets.append(
|
| 95 |
+
(
|
| 96 |
+
idx,
|
| 97 |
+
bucket.opcount_increased_to_capture_external_output,
|
| 98 |
+
bucket.size - bucket.paramsize_before_opcount_increase,
|
| 99 |
+
)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
if rows:
|
| 103 |
+
log.info(
|
| 104 |
+
"\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.",
|
| 105 |
+
bucket_bytes_cap,
|
| 106 |
+
len(buckets),
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
if extended_buckets:
|
| 110 |
+
log.warning(
|
| 111 |
+
"Some buckets were extended beyond their requested parameter capacities"
|
| 112 |
+
" in order to ensure each subgraph has an output node, required for fx graph partitioning."
|
| 113 |
+
" This can be the case when a subgraph would have only contained nodes performing inplace mutation,"
|
| 114 |
+
" and returning no logical outputs. This should not be a problem, unless it results in too few graph"
|
| 115 |
+
" partitions for optimal DDP performance."
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
try:
|
| 119 |
+
from tabulate import tabulate
|
| 120 |
+
|
| 121 |
+
log.debug(
|
| 122 |
+
"\nDDPOptimizer produced the following bucket assignments:\n%s",
|
| 123 |
+
tabulate(rows, headers=headers, tablefmt="simple_grid"),
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
if extended_buckets:
|
| 127 |
+
log.warning(
|
| 128 |
+
"DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s",
|
| 129 |
+
tabulate(
|
| 130 |
+
extended_buckets,
|
| 131 |
+
headers=("Index", "Extra Ops", "Extra Param Size (b)"),
|
| 132 |
+
tablefmt="simple_grid",
|
| 133 |
+
),
|
| 134 |
+
)
|
| 135 |
+
except ImportError:
|
| 136 |
+
log.debug(
|
| 137 |
+
"Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information."
|
| 138 |
+
)
|
| 139 |
+
else:
|
| 140 |
+
log.debug("DDPOptimizer captured no parameters and did not split this graph.")
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def has_higher_order_op(gm: fx.GraphModule) -> bool:
|
| 144 |
+
# Check if there is a higher order op in the graph
|
| 145 |
+
for node in gm.graph.nodes:
|
| 146 |
+
if node.op == "get_attr":
|
| 147 |
+
maybe_param = getattr(gm, node.target)
|
| 148 |
+
if isinstance(maybe_param, torch.fx.GraphModule):
|
| 149 |
+
return True
|
| 150 |
+
return False
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def propagate_metadata(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
|
| 154 |
+
for name, module in split_gm.named_modules():
|
| 155 |
+
if "." not in name and len(name):
|
| 156 |
+
# TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384
|
| 157 |
+
module.meta = orig_gm.meta
|
| 158 |
+
module._param_name_to_source = orig_gm._param_name_to_source
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
|
| 162 |
+
name_to_dynamo_source = {}
|
| 163 |
+
for node in orig_gm.graph.find_nodes(op="placeholder"):
|
| 164 |
+
name_to_dynamo_source[node.name] = node._dynamo_source
|
| 165 |
+
|
| 166 |
+
for name, module in split_gm.named_modules():
|
| 167 |
+
if "." not in name and len(name):
|
| 168 |
+
for node in module.graph.find_nodes(op="placeholder"):
|
| 169 |
+
# non-placeholder in original_gm may become placeholder in submodules
|
| 170 |
+
node._dynamo_source = name_to_dynamo_source.get(node.name, None)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class DDPOptimizerContext:
|
| 174 |
+
def __init__(self) -> None:
|
| 175 |
+
self.curr_bucket: int = -1
|
| 176 |
+
self.metadata_per_bucket: list[ViewAndMutationMeta] = []
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# compile each of the partitioned submodules using the user-provided compiler
|
| 180 |
+
class SubmodCompiler(torch.fx.interpreter.Interpreter):
|
| 181 |
+
def __init__(
|
| 182 |
+
self,
|
| 183 |
+
module: fx.GraphModule,
|
| 184 |
+
compiler: CompilerFn,
|
| 185 |
+
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
|
| 186 |
+
) -> None:
|
| 187 |
+
super().__init__(module)
|
| 188 |
+
self.compiler = compiler
|
| 189 |
+
self.fake_mode = fake_mode
|
| 190 |
+
# See Note [DDPOptimizer and fw_metadata]
|
| 191 |
+
ctx = torch._guards.TracingContext.try_get()
|
| 192 |
+
if ctx is not None:
|
| 193 |
+
ctx.ddp_optimizer_ctx = DDPOptimizerContext()
|
| 194 |
+
|
| 195 |
+
def compile_submod(
|
| 196 |
+
self, input_mod: fx.GraphModule, args: list[torch.Tensor], kwargs: Any
|
| 197 |
+
) -> Any:
|
| 198 |
+
"""
|
| 199 |
+
Compile the submodule,
|
| 200 |
+
using a wrapper to make sure its output is always a tuple,
|
| 201 |
+
which is required by AotAutograd based compilers
|
| 202 |
+
"""
|
| 203 |
+
assert len(kwargs) == 0, "We assume only args for these modules"
|
| 204 |
+
|
| 205 |
+
class WrapperModule(torch.nn.Module):
|
| 206 |
+
def __init__(
|
| 207 |
+
self, submod: Callable[..., Any], unwrap_singleton_tuple: bool
|
| 208 |
+
) -> None:
|
| 209 |
+
super().__init__()
|
| 210 |
+
self.submod = submod
|
| 211 |
+
self.unwrap_singleton_tuple = unwrap_singleton_tuple
|
| 212 |
+
|
| 213 |
+
def forward(self, *args: Any) -> Any:
|
| 214 |
+
x = self.submod(*args)
|
| 215 |
+
# TODO(whc)
|
| 216 |
+
# for some reason the isinstance check is necessary if I split one node per submod
|
| 217 |
+
# - even though I supposedly wrapped the output in a tuple in those cases, the real
|
| 218 |
+
# compiled module was still returning a tensor
|
| 219 |
+
if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
|
| 220 |
+
return x[0]
|
| 221 |
+
return x
|
| 222 |
+
|
| 223 |
+
unwrap_singleton_tuple = False
|
| 224 |
+
for sn in input_mod.graph.nodes:
|
| 225 |
+
if sn.op == "output":
|
| 226 |
+
if not isinstance(sn.args[0], tuple):
|
| 227 |
+
unwrap_singleton_tuple = True
|
| 228 |
+
sn.args = (sn.args,)
|
| 229 |
+
|
| 230 |
+
input_mod.recompile()
|
| 231 |
+
input_mod.compile_subgraph_reason = GraphCompileReason( # type: ignore[assignment]
|
| 232 |
+
"DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
|
| 233 |
+
" Set `torch._dynamo.config.optimize_ddp = False` to disable.",
|
| 234 |
+
[
|
| 235 |
+
# it's close to useless to get a real stacktrace here, and quite verbose.
|
| 236 |
+
traceback.FrameSummary(__file__, 0, "DDPOptimizer"),
|
| 237 |
+
],
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
wrapper = WrapperModule(
|
| 241 |
+
self.compiler(input_mod, args),
|
| 242 |
+
unwrap_singleton_tuple,
|
| 243 |
+
)
|
| 244 |
+
return wrapper
|
| 245 |
+
|
| 246 |
+
# Note:
|
| 247 |
+
#
|
| 248 |
+
# The way distributed works today around fake tensors can be somewhat confusing.
|
| 249 |
+
# Some of these codepaths are shared in both runtime, and compile time. The presence
|
| 250 |
+
# of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
|
| 251 |
+
#
|
| 252 |
+
# A few things to keep in mind:
|
| 253 |
+
#
|
| 254 |
+
# 1) We invoke `compile_submod` with a real module. The output of that gets stored
|
| 255 |
+
# on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
|
| 256 |
+
#
|
| 257 |
+
# 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
|
| 258 |
+
# module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it.
|
| 259 |
+
#
|
| 260 |
+
# 3) Fake tensors should always be around during compile time.
|
| 261 |
+
#
|
| 262 |
+
# 4) Fake tensors should never be around at runtime.
|
| 263 |
+
#
|
| 264 |
+
# 5) We end up with a compilation mode that takes a real submodule and fake tensors,
|
| 265 |
+
# to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd]
|
| 266 |
+
def run_node(self, n: Node) -> Any:
|
| 267 |
+
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
| 268 |
+
new_args = []
|
| 269 |
+
assert self.fake_mode
|
| 270 |
+
for arg in args:
|
| 271 |
+
if isinstance(arg, torch.Tensor) and not isinstance(
|
| 272 |
+
arg, torch._subclasses.FakeTensor
|
| 273 |
+
):
|
| 274 |
+
new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode))
|
| 275 |
+
else:
|
| 276 |
+
new_args.append(arg)
|
| 277 |
+
|
| 278 |
+
log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
|
| 279 |
+
assert isinstance(args, tuple)
|
| 280 |
+
assert isinstance(kwargs, dict)
|
| 281 |
+
|
| 282 |
+
if n.op == "call_module":
|
| 283 |
+
real_mod = self.fetch_attr(str(n.target))
|
| 284 |
+
if self.fake_mode:
|
| 285 |
+
curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
|
| 286 |
+
else:
|
| 287 |
+
curr_submod = real_mod
|
| 288 |
+
|
| 289 |
+
ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph)
|
| 290 |
+
|
| 291 |
+
# When calling the compiler on the submod, inputs (new_args) are expected to
|
| 292 |
+
# be FakeTensors already since Dynamo would have made them FakeTensors in the
|
| 293 |
+
# non-DDP flow. However, the parameters are _not_ expected to be FakeTensors,
|
| 294 |
+
# since this wrapping happens during compilation
|
| 295 |
+
|
| 296 |
+
# Note: Returning Fake Tensors on First AOT Autograd Call
|
| 297 |
+
#
|
| 298 |
+
# Inductor will optimize strides of outputs when it deems it profitable.
|
| 299 |
+
# For instance, converting to channels last. When we split the graph here
|
| 300 |
+
# into multiple inductor compilations, we need to make sure that the
|
| 301 |
+
# output strides of one compilation is appropriately passed to the subsequent
|
| 302 |
+
# compilations. However, the mapping from inductor output to dynamo output
|
| 303 |
+
# is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing,
|
| 304 |
+
# subclass handling, etc. In order to replay all this logic we set a flag such that
|
| 305 |
+
# the first invocation of inductor in aot_autograd will return Fake Tensors with
|
| 306 |
+
# appropriate strides. Then, all of aot autograd's runtime logic is replayed.
|
| 307 |
+
# This gives us the appropriately strided outputs here which will reflect runtime strides.
|
| 308 |
+
|
| 309 |
+
class FakeifyFirstAOTInvocationGuard:
|
| 310 |
+
def __init__(self) -> None:
|
| 311 |
+
self.tc = torch._guards.TracingContext.try_get()
|
| 312 |
+
assert self.tc
|
| 313 |
+
self.tc.fakify_first_call = True
|
| 314 |
+
|
| 315 |
+
def __del__(self) -> None:
|
| 316 |
+
self.tc.fakify_first_call = False # type: ignore[union-attr]
|
| 317 |
+
|
| 318 |
+
# For aot_eager and other backends, tracing context is not set
|
| 319 |
+
has_tracing_context = torch._guards.TracingContext.try_get() is not None
|
| 320 |
+
if has_tracing_context:
|
| 321 |
+
g = FakeifyFirstAOTInvocationGuard() # noqa: F841
|
| 322 |
+
|
| 323 |
+
from torch._dynamo.utils import counters
|
| 324 |
+
|
| 325 |
+
init = counters["aot_autograd"]["total"]
|
| 326 |
+
compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs)
|
| 327 |
+
|
| 328 |
+
# TODO - better way of doing this?
|
| 329 |
+
# Only aot autograd handles fakifying first call
|
| 330 |
+
invoked_aot_autograd = init != counters["aot_autograd"]["total"]
|
| 331 |
+
|
| 332 |
+
# We update the original (outer) graph with a call into the compiled module
|
| 333 |
+
# instead of the uncompiled one.
|
| 334 |
+
self.module.delete_submodule(n.target) # type: ignore[operator]
|
| 335 |
+
n.target = "compiled_" + n.target # type: ignore[operator]
|
| 336 |
+
self.module.add_submodule(n.target, compiled_submod_real) # type: ignore[operator]
|
| 337 |
+
|
| 338 |
+
# Finally, we have to produce inputs for use compiling the next submodule,
|
| 339 |
+
# and these need to be FakeTensors, so we execute the module under fake_mode
|
| 340 |
+
# Because parameters are not fake we patch fake tensor mode to allow non fake inputs
|
| 341 |
+
with (
|
| 342 |
+
self.fake_mode,
|
| 343 |
+
mock.patch.object(self.fake_mode, "allow_non_fake_inputs", True),
|
| 344 |
+
):
|
| 345 |
+
if has_tracing_context and invoked_aot_autograd:
|
| 346 |
+
tracing_ctx = torch._guards.TracingContext.try_get()
|
| 347 |
+
assert tracing_ctx is not None
|
| 348 |
+
# DDPOptimizer maintains 1 dynamo graph -> N AOT graphs
|
| 349 |
+
# Dynamo only has 1 tracing context, so it needs to maintain all N AOT metadata instances
|
| 350 |
+
ddp_ctx = tracing_ctx.ddp_optimizer_ctx
|
| 351 |
+
assert ddp_ctx is not None
|
| 352 |
+
assert tracing_ctx.fw_metadata is not None
|
| 353 |
+
ddp_ctx.curr_bucket += 1
|
| 354 |
+
ddp_ctx.metadata_per_bucket.append(tracing_ctx.fw_metadata)
|
| 355 |
+
|
| 356 |
+
out = compiled_submod_real(*new_args, **kwargs)
|
| 357 |
+
# output should be fake or subclass
|
| 358 |
+
assert all(
|
| 359 |
+
(not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor)
|
| 360 |
+
for t in (out if isinstance(out, (list, tuple)) else [out])
|
| 361 |
+
)
|
| 362 |
+
return out
|
| 363 |
+
else:
|
| 364 |
+
return curr_submod(*new_args, **kwargs)
|
| 365 |
+
else:
|
| 366 |
+
# placeholder or output nodes don't need to get compiled, just executed
|
| 367 |
+
return getattr(self, n.op)(n.target, new_args, kwargs)
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class DDPOptimizer:
|
| 371 |
+
"""Note [DDPOptimizer]
|
| 372 |
+
DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
|
| 373 |
+
breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
|
| 374 |
+
the boundaries of gradient-allreduce buckets chosen by DDP.
|
| 375 |
+
|
| 376 |
+
Background/Motivation
|
| 377 |
+
- DDP uses allreduce collectives to synchronize partial gradients computed on different workers
|
| 378 |
+
- DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce
|
| 379 |
+
- Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
|
| 380 |
+
at around the same time during backward and thus can share the same allreduce efficiently
|
| 381 |
+
- Allreduces must overlap with backward compute for optimal training performance
|
| 382 |
+
- DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which
|
| 383 |
+
operates when individual grads become 'ready'
|
| 384 |
+
- Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
|
| 385 |
+
autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole
|
| 386 |
+
fused backward function executes, preventing any overlap of compute and communication
|
| 387 |
+
|
| 388 |
+
Algorithm
|
| 389 |
+
- DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse
|
| 390 |
+
this graph in reverse order to determine the true order that gradients will become ready during backward.
|
| 391 |
+
- Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started
|
| 392 |
+
and a graph break introduced
|
| 393 |
+
- Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together
|
| 394 |
+
into an outer module that is returned to the user
|
| 395 |
+
|
| 396 |
+
Notes
|
| 397 |
+
- It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
|
| 398 |
+
and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
|
| 399 |
+
in eager.
|
| 400 |
+
- If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently
|
| 401 |
+
produce splits that do not necessarily align with the buckets used by DDP. This should result in performance
|
| 402 |
+
degradation approaching the baseline case where graph-splits are not used, but not worse.
|
| 403 |
+
- If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the
|
| 404 |
+
subgraphs being compiled
|
| 405 |
+
- DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
|
| 406 |
+
left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are
|
| 407 |
+
also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP,
|
| 408 |
+
it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
|
| 409 |
+
- DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients,
|
| 410 |
+
and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by
|
| 411 |
+
DDPOptimizer)
|
| 412 |
+
|
| 413 |
+
Debugging
|
| 414 |
+
- Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb.
|
| 415 |
+
- In many cases, the log messages are helpful (they show bucket size assignments)-
|
| 416 |
+
just set TORCH_LOGS env to include any of 'dynamo', 'distributed', or 'dist_ddp'.
|
| 417 |
+
- See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model
|
| 418 |
+
in a single process (or with torchrun, in multiple processes)
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks. Should be
|
| 422 |
+
set to match the equivalent parameter on the original DDP module.
|
| 423 |
+
|
| 424 |
+
backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.
|
| 425 |
+
|
| 426 |
+
first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP
|
| 427 |
+
special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.
|
| 428 |
+
|
| 429 |
+
"""
|
| 430 |
+
|
| 431 |
+
def __init__(
|
| 432 |
+
self,
|
| 433 |
+
bucket_bytes_cap: int,
|
| 434 |
+
backend_compile_fn: CompilerFn,
|
| 435 |
+
first_bucket_cap: Optional[int] = None,
|
| 436 |
+
) -> None:
|
| 437 |
+
if first_bucket_cap is not None:
|
| 438 |
+
self.first_bucket_cap = first_bucket_cap
|
| 439 |
+
elif torch.distributed.is_available():
|
| 440 |
+
# this constant comes from C10D lib which is not always built
|
| 441 |
+
self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
|
| 442 |
+
else:
|
| 443 |
+
self.first_bucket_cap = bucket_bytes_cap
|
| 444 |
+
|
| 445 |
+
self.bucket_bytes_cap = bucket_bytes_cap
|
| 446 |
+
assert self.first_bucket_cap <= self.bucket_bytes_cap, (
|
| 447 |
+
"First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
self.backend_compile_fn = backend_compile_fn
|
| 451 |
+
|
| 452 |
+
def _ignore_parameter(self, parameter: torch.nn.Parameter) -> bool:
|
| 453 |
+
return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored
|
| 454 |
+
|
| 455 |
+
def add_param(self, bucket: Bucket, param: torch.nn.Parameter, name: str) -> None:
|
| 456 |
+
bucket.size += param.untyped_storage().nbytes()
|
| 457 |
+
bucket.params.append(name)
|
| 458 |
+
bucket.param_ids.append(id(param))
|
| 459 |
+
|
| 460 |
+
def add_module_params_to_bucket(
|
| 461 |
+
self,
|
| 462 |
+
mod: torch.nn.Module,
|
| 463 |
+
bucket: Bucket,
|
| 464 |
+
processed_modules: set[torch.nn.Module],
|
| 465 |
+
prefix: str,
|
| 466 |
+
) -> None:
|
| 467 |
+
processed_modules.add(mod)
|
| 468 |
+
for name, param in mod.named_parameters():
|
| 469 |
+
if param.requires_grad and not self._ignore_parameter(param):
|
| 470 |
+
self.add_param(bucket, param, f"{prefix}_{name}")
|
| 471 |
+
|
| 472 |
+
def add_param_args(self, bucket: Bucket, node: fx.Node) -> None:
|
| 473 |
+
for arg in node.args:
|
| 474 |
+
if not isinstance(arg, torch.fx.node.Node):
|
| 475 |
+
continue
|
| 476 |
+
if arg.op != "placeholder":
|
| 477 |
+
continue
|
| 478 |
+
param = arg.meta["example_value"]
|
| 479 |
+
if (
|
| 480 |
+
isinstance(param, torch.nn.Parameter)
|
| 481 |
+
and param.requires_grad
|
| 482 |
+
and not self._ignore_parameter(param)
|
| 483 |
+
):
|
| 484 |
+
self.add_param(bucket, param, str(arg.target))
|
| 485 |
+
|
| 486 |
+
def compile_fn(
|
| 487 |
+
self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]
|
| 488 |
+
) -> CompiledFn:
|
| 489 |
+
"""
|
| 490 |
+
Implements graph splitting, first determining a set of of buckets by counting
|
| 491 |
+
parameter sizes in reverse graph order, then invoking the user/backend compiler
|
| 492 |
+
to compile each subgraph. Finally, stiches compiled graphs into one graphmodule
|
| 493 |
+
and returns its callable.
|
| 494 |
+
"""
|
| 495 |
+
# 1: compute the partition map according to DDP bucket logic
|
| 496 |
+
buckets = [Bucket()] # (size, param_names)
|
| 497 |
+
processed_modules: set[torch.nn.Module] = set()
|
| 498 |
+
for node in reversed(gm.graph.nodes):
|
| 499 |
+
if node.op in ("output", "placeholder"):
|
| 500 |
+
continue
|
| 501 |
+
|
| 502 |
+
if (
|
| 503 |
+
buckets[0].size >= self.bucket_bytes_cap
|
| 504 |
+
or len(buckets) == 1
|
| 505 |
+
and buckets[0].size >= self.first_bucket_cap
|
| 506 |
+
):
|
| 507 |
+
if bucket_has_external_output(buckets[0]):
|
| 508 |
+
buckets.insert(0, Bucket())
|
| 509 |
+
else:
|
| 510 |
+
# continue building this bucket past the point of filling its parameter capacity,
|
| 511 |
+
# to increase chances it contains at least one node that is either a global output or
|
| 512 |
+
# passed as input to a subsequent graph
|
| 513 |
+
|
| 514 |
+
if buckets[0].opcount_increased_to_capture_external_output == 0:
|
| 515 |
+
buckets[0].paramsize_before_opcount_increase = buckets[0].size
|
| 516 |
+
buckets[0].opcount_increased_to_capture_external_output += 1
|
| 517 |
+
|
| 518 |
+
if node.op == "call_function":
|
| 519 |
+
self.add_param_args(buckets[0], node)
|
| 520 |
+
|
| 521 |
+
elif node.op == "call_module":
|
| 522 |
+
target_mod = gm.get_submodule(node.target)
|
| 523 |
+
if target_mod not in processed_modules:
|
| 524 |
+
self.add_module_params_to_bucket(
|
| 525 |
+
target_mod, buckets[0], processed_modules, node.target
|
| 526 |
+
)
|
| 527 |
+
elif node.op == "call_method":
|
| 528 |
+
if isinstance(node.args[0].target, str):
|
| 529 |
+
target_mod = None
|
| 530 |
+
try:
|
| 531 |
+
target_mod = gm.get_submodule(node.args[0].target)
|
| 532 |
+
except AttributeError:
|
| 533 |
+
pass
|
| 534 |
+
if target_mod is not None and target_mod not in processed_modules:
|
| 535 |
+
self.add_module_params_to_bucket(
|
| 536 |
+
target_mod, buckets[0], processed_modules, node.target
|
| 537 |
+
)
|
| 538 |
+
# This handles situations like tmp = torch.mm(x, self.weight.t())
|
| 539 |
+
# t: "f32[512, 512]" = l_self_seq_2_weight.t(); l_self_seq_2_weight = None
|
| 540 |
+
# tmp: "f32[512, 512]" = torch.mm(input_2, t); input_2 = t = None
|
| 541 |
+
self.add_param_args(buckets[0], node)
|
| 542 |
+
|
| 543 |
+
elif node.op == "get_attr":
|
| 544 |
+
maybe_param = getattr(gm, node.target)
|
| 545 |
+
if (
|
| 546 |
+
isinstance(maybe_param, torch.nn.Parameter)
|
| 547 |
+
and maybe_param.requires_grad
|
| 548 |
+
and not self._ignore_parameter(maybe_param)
|
| 549 |
+
):
|
| 550 |
+
self.add_param(buckets[0], maybe_param, node.target)
|
| 551 |
+
|
| 552 |
+
# All nodes have to be mapped to a bucket, even if they don't have their own params
|
| 553 |
+
# Ignored params still end up in buckets, we just don't count them towards the capacity
|
| 554 |
+
buckets[0].nodes.append(node)
|
| 555 |
+
|
| 556 |
+
if len(buckets) > 1 and buckets[0].size == 0:
|
| 557 |
+
# we collected a small preamble graph with ops that don't include parameters, fuse it back
|
| 558 |
+
buckets[1].nodes.extend(buckets[0].nodes)
|
| 559 |
+
assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
|
| 560 |
+
del buckets[0]
|
| 561 |
+
|
| 562 |
+
# stash buckets for testing/debugging purposes
|
| 563 |
+
self.buckets = buckets
|
| 564 |
+
pretty_print_buckets(buckets, self.bucket_bytes_cap)
|
| 565 |
+
|
| 566 |
+
if len(buckets) == 1:
|
| 567 |
+
# bypass split/fuse logic if there is only one bucket
|
| 568 |
+
return self.backend_compile_fn(gm, example_inputs)
|
| 569 |
+
|
| 570 |
+
# 2: partition the graphmodule according to bucket capacity
|
| 571 |
+
partition_map = {}
|
| 572 |
+
for idx, b in enumerate(buckets):
|
| 573 |
+
for node in b.nodes:
|
| 574 |
+
partition_map[node] = idx
|
| 575 |
+
|
| 576 |
+
split_gm = fx.passes.split_module.split_module(
|
| 577 |
+
gm,
|
| 578 |
+
None, # type: ignore[arg-type]
|
| 579 |
+
lambda node: partition_map[node],
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
# See note [Assumption on Dynamo Metadata]
|
| 583 |
+
propagate_dynamo_source(gm, split_gm)
|
| 584 |
+
propagate_metadata(gm, split_gm)
|
| 585 |
+
|
| 586 |
+
debug_str = (
|
| 587 |
+
f"\n---orig graph---\n{gm.graph}\n"
|
| 588 |
+
+ f"\n---split graph---\n{split_gm.graph}\n"
|
| 589 |
+
)
|
| 590 |
+
for name, module in split_gm.named_modules():
|
| 591 |
+
if "." not in name and len(name):
|
| 592 |
+
# only print the submod graphs, not their children
|
| 593 |
+
debug_str += f"\n---{name} graph---\n{module.graph}\n"
|
| 594 |
+
debug_str += "\n---------------\n"
|
| 595 |
+
ddp_graph_log.debug(debug_str)
|
| 596 |
+
|
| 597 |
+
trace_structured(
|
| 598 |
+
"optimize_ddp_split_graph",
|
| 599 |
+
payload_fn=lambda: split_gm.print_readable(print_output=False),
|
| 600 |
+
)
|
| 601 |
+
for name, module in split_gm.named_modules():
|
| 602 |
+
if "." not in name and len(name):
|
| 603 |
+
trace_structured(
|
| 604 |
+
"optimize_ddp_split_child",
|
| 605 |
+
lambda: {"name": name},
|
| 606 |
+
payload_fn=lambda: module.print_readable(print_output=False),
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
fake_mode = detect_fake_mode(example_inputs)
|
| 610 |
+
if fake_mode is None:
|
| 611 |
+
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
|
| 612 |
+
|
| 613 |
+
submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode)
|
| 614 |
+
with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
|
| 615 |
+
submod_compiler.run(*example_inputs)
|
| 616 |
+
split_gm.recompile()
|
| 617 |
+
|
| 618 |
+
ddp_graph_log.debug(
|
| 619 |
+
"\n---final graph---\n%s\n---------------\n", split_gm.graph
|
| 620 |
+
)
|
| 621 |
+
return split_gm
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/onnxrt.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This backend is maintained by ONNX team. To direct issues
|
| 2 |
+
# to the right people, please tag related GitHub issues with `module: onnx`.
|
| 3 |
+
#
|
| 4 |
+
# Maintainers' Github IDs: wschin, xadupre
|
| 5 |
+
# from torch.onnx._internal.onnxruntime import (
|
| 6 |
+
# is_onnxrt_backend_supported,
|
| 7 |
+
# torch_compile_backend,
|
| 8 |
+
# )
|
| 9 |
+
|
| 10 |
+
# from .registry import register_backend
|
| 11 |
+
|
| 12 |
+
"""
|
| 13 |
+
Placeholder for onnxruntime backend for dynamo
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
# def has_onnxruntime():
|
| 17 |
+
# # FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported()
|
| 18 |
+
# return is_onnxrt_backend_supported()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# if is_onnxrt_backend_supported():
|
| 22 |
+
# register_backend(name="onnxrt", compiler_fn=torch_compile_backend)
|
| 23 |
+
# else:
|
| 24 |
+
|
| 25 |
+
# def information_displaying_backend(*args, **kwargs):
|
| 26 |
+
# raise ImportError(
|
| 27 |
+
# "onnxrt is not registered as a backend. "
|
| 28 |
+
# "Please make sure all dependencies such as "
|
| 29 |
+
# "numpy, onnx, onnxscript, and onnxruntime-training are installed. "
|
| 30 |
+
# "Suggested procedure to fix dependency problem:\n"
|
| 31 |
+
# " (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n"
|
| 32 |
+
# " (2) Open a new python terminal.\n"
|
| 33 |
+
# " (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n"
|
| 34 |
+
# " (4) If it returns `True`, then you can use `onnxrt` backend.\n"
|
| 35 |
+
# " (5) If it returns `False`, please execute the package importing section in "
|
| 36 |
+
# "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails."
|
| 37 |
+
# )
|
| 38 |
+
|
| 39 |
+
# register_backend(name="onnxrt", compiler_fn=information_displaying_backend)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/registry.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module implements TorchDynamo's backend registry system for managing compiler backends.
|
| 3 |
+
|
| 4 |
+
The registry provides a centralized way to register, discover and manage different compiler
|
| 5 |
+
backends that can be used with torch.compile(). It handles:
|
| 6 |
+
|
| 7 |
+
- Backend registration and discovery through decorators and entry points
|
| 8 |
+
- Lazy loading of backend implementations
|
| 9 |
+
- Lookup and validation of backend names
|
| 10 |
+
- Categorization of backends using tags (debug, experimental, etc.)
|
| 11 |
+
|
| 12 |
+
Key components:
|
| 13 |
+
- CompilerFn: Type for backend compiler functions that transform FX graphs
|
| 14 |
+
- _BACKENDS: Registry mapping backend names to entry points
|
| 15 |
+
- _COMPILER_FNS: Registry mapping backend names to loaded compiler functions
|
| 16 |
+
|
| 17 |
+
Example usage:
|
| 18 |
+
@register_backend
|
| 19 |
+
def my_compiler(fx_graph, example_inputs):
|
| 20 |
+
# Transform FX graph into optimized implementation
|
| 21 |
+
return compiled_fn
|
| 22 |
+
|
| 23 |
+
# Use registered backend
|
| 24 |
+
torch.compile(model, backend="my_compiler")
|
| 25 |
+
|
| 26 |
+
The registry also supports discovering backends through setuptools entry points
|
| 27 |
+
in the "torch_dynamo_backends" group. Example:
|
| 28 |
+
```
|
| 29 |
+
setup.py
|
| 30 |
+
---
|
| 31 |
+
from setuptools import setup
|
| 32 |
+
|
| 33 |
+
setup(
|
| 34 |
+
name='my_torch_backend',
|
| 35 |
+
version='0.1',
|
| 36 |
+
packages=['my_torch_backend'],
|
| 37 |
+
entry_points={
|
| 38 |
+
'torch_dynamo_backends': [
|
| 39 |
+
# name = path to entry point of backend implementation
|
| 40 |
+
'my_compiler = my_torch_backend.compiler:my_compiler_function',
|
| 41 |
+
],
|
| 42 |
+
},
|
| 43 |
+
)
|
| 44 |
+
```
|
| 45 |
+
```
|
| 46 |
+
my_torch_backend/compiler.py
|
| 47 |
+
---
|
| 48 |
+
def my_compiler_function(fx_graph, example_inputs):
|
| 49 |
+
# Transform FX graph into optimized implementation
|
| 50 |
+
return compiled_fn
|
| 51 |
+
```
|
| 52 |
+
Using `my_compiler` backend:
|
| 53 |
+
```
|
| 54 |
+
import torch
|
| 55 |
+
|
| 56 |
+
model = ... # Your PyTorch model
|
| 57 |
+
optimized_model = torch.compile(model, backend="my_compiler")
|
| 58 |
+
```
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
import functools
|
| 62 |
+
import logging
|
| 63 |
+
from collections.abc import Callable, Sequence
|
| 64 |
+
from importlib.metadata import EntryPoint
|
| 65 |
+
from typing import Any, Optional, Protocol, Union
|
| 66 |
+
|
| 67 |
+
import torch
|
| 68 |
+
from torch import fx
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
log = logging.getLogger(__name__)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class CompiledFn(Protocol):
|
| 75 |
+
def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: ...
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
CompilerFn = Callable[[fx.GraphModule, list[torch.Tensor]], CompiledFn]
|
| 79 |
+
|
| 80 |
+
_BACKENDS: dict[str, Optional[EntryPoint]] = {}
|
| 81 |
+
_COMPILER_FNS: dict[str, CompilerFn] = {}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def register_backend(
|
| 85 |
+
compiler_fn: Optional[CompilerFn] = None,
|
| 86 |
+
name: Optional[str] = None,
|
| 87 |
+
tags: Sequence[str] = (),
|
| 88 |
+
) -> Callable[..., Any]:
|
| 89 |
+
"""
|
| 90 |
+
Decorator to add a given compiler to the registry to allow calling
|
| 91 |
+
`torch.compile` with string shorthand. Note: for projects not
|
| 92 |
+
imported by default, it might be easier to pass a function directly
|
| 93 |
+
as a backend and not use a string.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
compiler_fn: Callable taking a FX graph and fake tensor inputs
|
| 97 |
+
name: Optional name, defaults to `compiler_fn.__name__`
|
| 98 |
+
tags: Optional set of string tags to categorize backend with
|
| 99 |
+
"""
|
| 100 |
+
if compiler_fn is None:
|
| 101 |
+
# @register_backend(name="") syntax
|
| 102 |
+
return functools.partial(register_backend, name=name, tags=tags) # type: ignore[return-value]
|
| 103 |
+
assert callable(compiler_fn)
|
| 104 |
+
name = name or compiler_fn.__name__
|
| 105 |
+
assert name not in _COMPILER_FNS, f"duplicate name: {name}"
|
| 106 |
+
if compiler_fn not in _BACKENDS:
|
| 107 |
+
_BACKENDS[name] = None
|
| 108 |
+
_COMPILER_FNS[name] = compiler_fn
|
| 109 |
+
compiler_fn._tags = tuple(tags) # type: ignore[attr-defined]
|
| 110 |
+
return compiler_fn
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
register_debug_backend = functools.partial(register_backend, tags=("debug",))
|
| 114 |
+
register_experimental_backend = functools.partial(
|
| 115 |
+
register_backend, tags=("experimental",)
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def lookup_backend(compiler_fn: Union[str, CompilerFn]) -> CompilerFn:
|
| 120 |
+
"""Expand backend strings to functions"""
|
| 121 |
+
if isinstance(compiler_fn, str):
|
| 122 |
+
if compiler_fn not in _BACKENDS:
|
| 123 |
+
_lazy_import()
|
| 124 |
+
if compiler_fn not in _BACKENDS:
|
| 125 |
+
from ..exc import InvalidBackend
|
| 126 |
+
|
| 127 |
+
raise InvalidBackend(name=compiler_fn)
|
| 128 |
+
|
| 129 |
+
if compiler_fn not in _COMPILER_FNS:
|
| 130 |
+
entry_point = _BACKENDS[compiler_fn]
|
| 131 |
+
if entry_point is not None:
|
| 132 |
+
register_backend(compiler_fn=entry_point.load(), name=compiler_fn)
|
| 133 |
+
compiler_fn = _COMPILER_FNS[compiler_fn]
|
| 134 |
+
return compiler_fn
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# NOTE: can't type this due to public api mismatch; follow up with dev team
|
| 138 |
+
def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: # type: ignore[no-untyped-def]
|
| 139 |
+
"""
|
| 140 |
+
Return valid strings that can be passed to:
|
| 141 |
+
|
| 142 |
+
torch.compile(..., backend="name")
|
| 143 |
+
"""
|
| 144 |
+
_lazy_import()
|
| 145 |
+
exclude_tags_set = set(exclude_tags or ())
|
| 146 |
+
|
| 147 |
+
backends = [
|
| 148 |
+
name
|
| 149 |
+
for name in _BACKENDS
|
| 150 |
+
if name not in _COMPILER_FNS
|
| 151 |
+
or not exclude_tags_set.intersection(_COMPILER_FNS[name]._tags) # type: ignore[attr-defined]
|
| 152 |
+
]
|
| 153 |
+
return sorted(backends)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@functools.cache
|
| 157 |
+
def _lazy_import() -> None:
|
| 158 |
+
from .. import backends
|
| 159 |
+
from ..utils import import_submodule
|
| 160 |
+
|
| 161 |
+
import_submodule(backends)
|
| 162 |
+
|
| 163 |
+
from ..repro.after_dynamo import dynamo_minifier_backend
|
| 164 |
+
|
| 165 |
+
assert dynamo_minifier_backend is not None
|
| 166 |
+
|
| 167 |
+
_discover_entrypoint_backends()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@functools.cache
|
| 171 |
+
def _discover_entrypoint_backends() -> None:
|
| 172 |
+
# importing here so it will pick up the mocked version in test_backends.py
|
| 173 |
+
from importlib.metadata import entry_points
|
| 174 |
+
|
| 175 |
+
group_name = "torch_dynamo_backends"
|
| 176 |
+
eps = entry_points(group=group_name)
|
| 177 |
+
eps_dict = {name: eps[name] for name in eps.names}
|
| 178 |
+
for backend_name in eps_dict:
|
| 179 |
+
_BACKENDS[backend_name] = eps_dict[backend_name]
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/torchxla.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from collections.abc import Callable
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from functorch.compile import make_boxed_func
|
| 7 |
+
from torch import fx
|
| 8 |
+
|
| 9 |
+
from ..backends.common import aot_autograd
|
| 10 |
+
from .registry import CompiledFn, register_backend, register_experimental_backend
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
log = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@register_experimental_backend
|
| 17 |
+
def openxla_eval(
|
| 18 |
+
model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
|
| 19 |
+
) -> CompiledFn:
|
| 20 |
+
return xla_backend_helper(model, fake_tensor_inputs, boxed=False)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def openxla_eval_boxed(
|
| 24 |
+
model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor]
|
| 25 |
+
) -> Callable[..., Any]:
|
| 26 |
+
return xla_backend_helper(model, fake_tensor_inputs, boxed=True)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def xla_backend_helper(
|
| 30 |
+
model: fx.GraphModule, fake_tensor_inputs: list[torch.Tensor], boxed: bool = False
|
| 31 |
+
) -> Callable[..., Any]:
|
| 32 |
+
try:
|
| 33 |
+
import torch_xla.core.dynamo_bridge as bridge
|
| 34 |
+
except ImportError as e:
|
| 35 |
+
raise ImportError(
|
| 36 |
+
"Please follow the instruction in https://github.com/pytorch/xla#pytorchxla to install torch_xla"
|
| 37 |
+
) from e
|
| 38 |
+
|
| 39 |
+
compiled_graph = None
|
| 40 |
+
|
| 41 |
+
def fwd(*args: torch.Tensor) -> Any:
|
| 42 |
+
nonlocal model
|
| 43 |
+
nonlocal compiled_graph
|
| 44 |
+
if compiled_graph is None:
|
| 45 |
+
compiled_graph = bridge.extract_compiled_graph(model, args)
|
| 46 |
+
del model
|
| 47 |
+
return compiled_graph(*args)
|
| 48 |
+
|
| 49 |
+
return make_boxed_func(fwd) if boxed else fwd
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
openxla = aot_autograd(
|
| 53 |
+
fw_compiler=openxla_eval_boxed,
|
| 54 |
+
)
|
| 55 |
+
register_backend(name="openxla", compiler_fn=openxla)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (8.73 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (1.2 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/converter.cpython-312.pyc
ADDED
|
Binary file (78.2 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/error.cpython-312.pyc
ADDED
|
Binary file (2.42 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/non_strict_utils.cpython-312.pyc
ADDED
|
Binary file (47.7 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/pass_base.cpython-312.pyc
ADDED
|
Binary file (26 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/tools.cpython-312.pyc
ADDED
|
Binary file (6.17 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (73 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/verifier.cpython-312.pyc
ADDED
|
Binary file (27.2 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/__pycache__/wrappers.cpython-312.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/case.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import inspect
|
| 3 |
+
import re
|
| 4 |
+
import string
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Any, Optional
|
| 8 |
+
from types import ModuleType
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
_TAGS: dict[str, dict[str, Any]] = {
|
| 13 |
+
"torch": {
|
| 14 |
+
"cond": {},
|
| 15 |
+
"dynamic-shape": {},
|
| 16 |
+
"escape-hatch": {},
|
| 17 |
+
"map": {},
|
| 18 |
+
"dynamic-value": {},
|
| 19 |
+
"operator": {},
|
| 20 |
+
"mutation": {},
|
| 21 |
+
},
|
| 22 |
+
"python": {
|
| 23 |
+
"assert": {},
|
| 24 |
+
"builtin": {},
|
| 25 |
+
"closure": {},
|
| 26 |
+
"context-manager": {},
|
| 27 |
+
"control-flow": {},
|
| 28 |
+
"data-structure": {},
|
| 29 |
+
"standard-library": {},
|
| 30 |
+
"object-model": {},
|
| 31 |
+
},
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SupportLevel(Enum):
|
| 36 |
+
"""
|
| 37 |
+
Indicates at what stage the feature
|
| 38 |
+
used in the example is handled in export.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
SUPPORTED = 1
|
| 42 |
+
NOT_SUPPORTED_YET = 0
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
ArgsType = tuple[Any, ...]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def check_inputs_type(args, kwargs):
|
| 49 |
+
if not isinstance(args, tuple):
|
| 50 |
+
raise ValueError(
|
| 51 |
+
f"Expecting args type to be a tuple, got: {type(args)}"
|
| 52 |
+
)
|
| 53 |
+
if not isinstance(kwargs, dict):
|
| 54 |
+
raise ValueError(
|
| 55 |
+
f"Expecting kwargs type to be a dict, got: {type(kwargs)}"
|
| 56 |
+
)
|
| 57 |
+
for key in kwargs:
|
| 58 |
+
if not isinstance(key, str):
|
| 59 |
+
raise ValueError(
|
| 60 |
+
f"Expecting kwargs keys to be a string, got: {type(key)}"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def _validate_tag(tag: str):
|
| 64 |
+
parts = tag.split(".")
|
| 65 |
+
t = _TAGS
|
| 66 |
+
for part in parts:
|
| 67 |
+
assert set(part) <= set(
|
| 68 |
+
string.ascii_lowercase + "-"
|
| 69 |
+
), f"Tag contains invalid characters: {part}"
|
| 70 |
+
if part in t:
|
| 71 |
+
t = t[part]
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError(f"Tag {tag} is not found in registered tags.")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass(frozen=True)
|
| 77 |
+
class ExportCase:
|
| 78 |
+
example_args: ArgsType
|
| 79 |
+
description: str # A description of the use case.
|
| 80 |
+
model: torch.nn.Module
|
| 81 |
+
name: str
|
| 82 |
+
example_kwargs: dict[str, Any] = field(default_factory=dict)
|
| 83 |
+
extra_args: Optional[ArgsType] = None # For testing graph generalization.
|
| 84 |
+
# Tags associated with the use case. (e.g dynamic-shape, escape-hatch)
|
| 85 |
+
tags: set[str] = field(default_factory=set)
|
| 86 |
+
support_level: SupportLevel = SupportLevel.SUPPORTED
|
| 87 |
+
dynamic_shapes: Optional[dict[str, Any]] = None
|
| 88 |
+
|
| 89 |
+
def __post_init__(self):
|
| 90 |
+
check_inputs_type(self.example_args, self.example_kwargs)
|
| 91 |
+
if self.extra_args is not None:
|
| 92 |
+
check_inputs_type(self.extra_args, {})
|
| 93 |
+
|
| 94 |
+
for tag in self.tags:
|
| 95 |
+
_validate_tag(tag)
|
| 96 |
+
|
| 97 |
+
if not isinstance(self.description, str) or len(self.description) == 0:
|
| 98 |
+
raise ValueError(f'Invalid description: "{self.description}"')
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
_EXAMPLE_CASES: dict[str, ExportCase] = {}
|
| 102 |
+
_MODULES: set[ModuleType] = set()
|
| 103 |
+
_EXAMPLE_CONFLICT_CASES: dict[str, list[ExportCase]] = {}
|
| 104 |
+
_EXAMPLE_REWRITE_CASES: dict[str, list[ExportCase]] = {}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def register_db_case(case: ExportCase) -> None:
|
| 108 |
+
"""
|
| 109 |
+
Registers a user provided ExportCase into example bank.
|
| 110 |
+
"""
|
| 111 |
+
if case.name in _EXAMPLE_CASES:
|
| 112 |
+
if case.name not in _EXAMPLE_CONFLICT_CASES:
|
| 113 |
+
_EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]]
|
| 114 |
+
_EXAMPLE_CONFLICT_CASES[case.name].append(case)
|
| 115 |
+
return
|
| 116 |
+
|
| 117 |
+
_EXAMPLE_CASES[case.name] = case
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def to_snake_case(name):
|
| 121 |
+
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
|
| 122 |
+
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _make_export_case(m, name, configs):
|
| 126 |
+
if not isinstance(m, torch.nn.Module):
|
| 127 |
+
raise TypeError("Export case class should be a torch.nn.Module.")
|
| 128 |
+
|
| 129 |
+
if "description" not in configs:
|
| 130 |
+
# Fallback to docstring if description is missing.
|
| 131 |
+
assert (
|
| 132 |
+
m.__doc__ is not None
|
| 133 |
+
), f"Could not find description or docstring for export case: {m}"
|
| 134 |
+
configs = {**configs, "description": m.__doc__}
|
| 135 |
+
# pyrefly: ignore [bad-argument-type]
|
| 136 |
+
return ExportCase(**{**configs, "model": m, "name": name})
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def export_case(**kwargs):
|
| 140 |
+
"""
|
| 141 |
+
Decorator for registering a user provided case into example bank.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
def wrapper(m):
|
| 145 |
+
configs = kwargs
|
| 146 |
+
module = inspect.getmodule(m)
|
| 147 |
+
if module in _MODULES:
|
| 148 |
+
raise RuntimeError("export_case should only be used once per example file.")
|
| 149 |
+
|
| 150 |
+
assert module is not None
|
| 151 |
+
_MODULES.add(module)
|
| 152 |
+
module_name = module.__name__.split(".")[-1]
|
| 153 |
+
case = _make_export_case(m, module_name, configs)
|
| 154 |
+
register_db_case(case)
|
| 155 |
+
return case
|
| 156 |
+
|
| 157 |
+
return wrapper
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def export_rewrite_case(**kwargs):
|
| 161 |
+
def wrapper(m):
|
| 162 |
+
configs = kwargs
|
| 163 |
+
|
| 164 |
+
parent = configs.pop("parent")
|
| 165 |
+
assert isinstance(parent, ExportCase)
|
| 166 |
+
key = parent.name
|
| 167 |
+
if key not in _EXAMPLE_REWRITE_CASES:
|
| 168 |
+
_EXAMPLE_REWRITE_CASES[key] = []
|
| 169 |
+
|
| 170 |
+
configs["example_args"] = parent.example_args
|
| 171 |
+
case = _make_export_case(m, to_snake_case(m.__name__), configs)
|
| 172 |
+
_EXAMPLE_REWRITE_CASES[key].append(case)
|
| 173 |
+
return case
|
| 174 |
+
|
| 175 |
+
return wrapper
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/gen_example.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import torch._export.db.examples as examples
|
| 5 |
+
|
| 6 |
+
TEMPLATE = '''import torch
|
| 7 |
+
|
| 8 |
+
def {case_name}(x):
|
| 9 |
+
"""
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
return
|
| 13 |
+
'''
|
| 14 |
+
|
| 15 |
+
if __name__ == "__main__":
|
| 16 |
+
assert len(sys.argv) == 2
|
| 17 |
+
root_dir = examples.__name__.replace(".", "/")
|
| 18 |
+
assert os.path.exists(root_dir)
|
| 19 |
+
with open(os.path.join(root_dir, sys.argv[1] + ".py"), "w") as f:
|
| 20 |
+
print("Writing to", f.name, "...")
|
| 21 |
+
f.write(TEMPLATE.format(case_name=sys.argv[1]))
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/db/logging.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
def exportdb_error_message(case_name: str) -> str:
|
| 4 |
+
from .examples import all_examples
|
| 5 |
+
from torch._utils_internal import log_export_usage
|
| 6 |
+
|
| 7 |
+
ALL_EXAMPLES = all_examples()
|
| 8 |
+
# Detect whether case_name is really registered in exportdb.
|
| 9 |
+
if case_name in ALL_EXAMPLES:
|
| 10 |
+
url_case_name = case_name.replace("_", "-")
|
| 11 |
+
return f"See {case_name} in exportdb for unsupported case. \
|
| 12 |
+
https://pytorch.org/docs/main/generated/exportdb/index.html#{url_case_name}"
|
| 13 |
+
else:
|
| 14 |
+
log_export_usage(
|
| 15 |
+
event="export.error.casenotregistered",
|
| 16 |
+
message=case_name,
|
| 17 |
+
)
|
| 18 |
+
return f"{case_name} is unsupported."
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_class_if_classified_error(e: Exception) -> Optional[str]:
|
| 22 |
+
"""
|
| 23 |
+
Returns a string case name if the export error e is classified.
|
| 24 |
+
Returns None otherwise.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from torch._dynamo.exc import TorchRuntimeError, Unsupported, UserError
|
| 28 |
+
|
| 29 |
+
ALWAYS_CLASSIFIED = "always_classified"
|
| 30 |
+
DEFAULT_CLASS_SIGIL = "case_name"
|
| 31 |
+
|
| 32 |
+
# add error types that should be classified, along with any attribute name
|
| 33 |
+
# whose presence acts like a sigil to further distinguish which errors of
|
| 34 |
+
# that type should be classified. If the attribute name is None, then the
|
| 35 |
+
# error type is always classified.
|
| 36 |
+
_ALLOW_LIST = {
|
| 37 |
+
Unsupported: DEFAULT_CLASS_SIGIL,
|
| 38 |
+
UserError: DEFAULT_CLASS_SIGIL,
|
| 39 |
+
TorchRuntimeError: None,
|
| 40 |
+
}
|
| 41 |
+
if type(e) in _ALLOW_LIST:
|
| 42 |
+
# pyrefly: ignore [index-error]
|
| 43 |
+
attr_name = _ALLOW_LIST[type(e)]
|
| 44 |
+
if attr_name is None:
|
| 45 |
+
return ALWAYS_CLASSIFIED
|
| 46 |
+
return getattr(e, attr_name, None)
|
| 47 |
+
return None
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/__init__.py
ADDED
|
File without changes
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/node_metadata.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
NodeMetadataValue = Any
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
PROTECTED_KEYS: set[str] = {
|
| 8 |
+
"val",
|
| 9 |
+
"stack_trace",
|
| 10 |
+
"nn_module_stack",
|
| 11 |
+
"debug_handle",
|
| 12 |
+
"tensor_meta",
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class NodeMetadata:
|
| 17 |
+
def __init__(self, data: dict[str, Any]) -> None:
|
| 18 |
+
self.data: dict[str, Any] = data.copy()
|
| 19 |
+
|
| 20 |
+
def __getitem__(self, key: str) -> NodeMetadataValue:
|
| 21 |
+
return self.data[key]
|
| 22 |
+
|
| 23 |
+
def __setitem__(self, key: str, value: NodeMetadataValue) -> NodeMetadataValue:
|
| 24 |
+
if key in PROTECTED_KEYS:
|
| 25 |
+
raise RuntimeError(f"Could not override node key: {key}")
|
| 26 |
+
self.data[key] = value
|
| 27 |
+
|
| 28 |
+
def __contains__(self, key: str) -> bool:
|
| 29 |
+
return key in self.data
|
| 30 |
+
|
| 31 |
+
def copy(self) -> "NodeMetadata":
|
| 32 |
+
return NodeMetadata(self.data.copy())
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/pass_infra/proxy_value.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pyre-strict
|
| 2 |
+
from collections.abc import Iterable, Iterator
|
| 3 |
+
from typing import Generic, TypeVar, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
_T = TypeVar("_T")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ProxyValue(Generic[_T]):
|
| 12 |
+
# pyre-ignore
|
| 13 |
+
def __init__(self, data: Iterable[_T], proxy: Union[torch.fx.Proxy, torch.fx.Node]):
|
| 14 |
+
# pyre-ignore
|
| 15 |
+
self.data = data
|
| 16 |
+
self.proxy_or_node = proxy
|
| 17 |
+
|
| 18 |
+
@property
|
| 19 |
+
def node(self) -> torch.fx.Node:
|
| 20 |
+
if isinstance(self.proxy_or_node, torch.fx.Node):
|
| 21 |
+
return self.proxy_or_node
|
| 22 |
+
assert isinstance(self.proxy_or_node, torch.fx.Proxy)
|
| 23 |
+
return self.proxy_or_node.node
|
| 24 |
+
|
| 25 |
+
@property
|
| 26 |
+
def proxy(self) -> torch.fx.Proxy:
|
| 27 |
+
if not isinstance(self.proxy_or_node, torch.fx.Proxy):
|
| 28 |
+
raise RuntimeError(
|
| 29 |
+
f"ProxyValue doesn't have attached Proxy object. Node: {self.proxy_or_node.format_node()}"
|
| 30 |
+
)
|
| 31 |
+
return self.proxy_or_node
|
| 32 |
+
|
| 33 |
+
def to_tensor(self) -> torch.Tensor:
|
| 34 |
+
assert isinstance(self.data, torch.Tensor)
|
| 35 |
+
return self.data
|
| 36 |
+
|
| 37 |
+
def is_tensor(self) -> bool:
|
| 38 |
+
return isinstance(self.data, torch.Tensor)
|
| 39 |
+
|
| 40 |
+
# pyre-ignore
|
| 41 |
+
def __iter__(self) -> Iterator[_T]:
|
| 42 |
+
yield from self.data
|
| 43 |
+
|
| 44 |
+
def __bool__(self) -> bool:
|
| 45 |
+
return bool(self.data)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .replace_view_ops_with_view_copy_ops_pass import ReplaceViewOpsWithViewCopyOpsPass
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (327 Bytes). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-312.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-312.pyc
ADDED
|
Binary file (6.99 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/constant_folding.cpython-312.pyc
ADDED
|
Binary file (14.9 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-312.pyc
ADDED
|
Binary file (4.92 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-312.pyc
ADDED
|
Binary file (5.09 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-312.pyc
ADDED
|
Binary file (2.27 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-312.pyc
ADDED
|
Binary file (8.49 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-312.pyc
ADDED
|
Binary file (6.05 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-312.pyc
ADDED
|
Binary file (3.7 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-312.pyc
ADDED
|
Binary file (8.68 kB). View file
|
|
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/_node_metadata_hook.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import contextlib
|
| 3 |
+
from typing import Any, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.utils._pytree as pytree
|
| 7 |
+
from torch._dispatch.python import enable_python_dispatcher
|
| 8 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 9 |
+
from torch.fx.graph_module import GraphModule
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
_EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _node_metadata_hook(
|
| 16 |
+
node: torch.fx.Node,
|
| 17 |
+
metadata: Optional[dict[str, Any]] = None,
|
| 18 |
+
fake_mode: Optional[FakeTensorMode] = None,
|
| 19 |
+
) -> None:
|
| 20 |
+
"""
|
| 21 |
+
Hook for adding the appropriate metadata to nodes that are created during a
|
| 22 |
+
pass using graph.create_node. An example of how to use it:
|
| 23 |
+
|
| 24 |
+
```
|
| 25 |
+
with _set_node_metadata_hook(gm,
|
| 26 |
+
functools.partial(_node_metadata_hook, metadata={"stack_trace": "file"})
|
| 27 |
+
):
|
| 28 |
+
pass(gm)
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
This hook should not work for all generic cases -- specifically it assumes
|
| 32 |
+
that nodes being added are only call_function nodes, and copies over the
|
| 33 |
+
first argument node's nn_module_stack.
|
| 34 |
+
"""
|
| 35 |
+
# pyrefly: ignore [bad-assignment]
|
| 36 |
+
fake_mode = fake_mode or contextlib.nullcontext()
|
| 37 |
+
|
| 38 |
+
assert node.op == "call_function" and callable(node.target), (
|
| 39 |
+
f"node: {node}, target: {node.target}"
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
if (
|
| 43 |
+
isinstance(node.target, torch._ops.OpOverload)
|
| 44 |
+
and len(node.target._schema.returns) == 0
|
| 45 |
+
):
|
| 46 |
+
node.meta["val"] = None
|
| 47 |
+
else:
|
| 48 |
+
fake_args, fake_kwargs = pytree.tree_map_only(
|
| 49 |
+
torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs)
|
| 50 |
+
)
|
| 51 |
+
# pyrefly: ignore [bad-context-manager]
|
| 52 |
+
with fake_mode, enable_python_dispatcher():
|
| 53 |
+
fake_res = node.target(*fake_args, **fake_kwargs)
|
| 54 |
+
node.meta["val"] = fake_res
|
| 55 |
+
|
| 56 |
+
if metadata is not None:
|
| 57 |
+
for k, v in metadata.items():
|
| 58 |
+
node.meta[k] = v
|
| 59 |
+
|
| 60 |
+
# Copy over metadata from argument nodes
|
| 61 |
+
arg_meta = [
|
| 62 |
+
arg.meta
|
| 63 |
+
for arg in pytree.tree_flatten((node.args, node.kwargs))[0]
|
| 64 |
+
if isinstance(arg, torch.fx.Node)
|
| 65 |
+
]
|
| 66 |
+
if len(arg_meta) == 0:
|
| 67 |
+
return
|
| 68 |
+
arg_meta = arg_meta[0]
|
| 69 |
+
|
| 70 |
+
node.meta["nn_module_stack"] = node.meta.get(
|
| 71 |
+
"nn_module_stack",
|
| 72 |
+
arg_meta.get(
|
| 73 |
+
"nn_module_stack",
|
| 74 |
+
{
|
| 75 |
+
_EMPTY_NN_MODULE_STACK_KEY: (
|
| 76 |
+
_EMPTY_NN_MODULE_STACK_KEY,
|
| 77 |
+
_EMPTY_NN_MODULE_STACK_KEY,
|
| 78 |
+
)
|
| 79 |
+
},
|
| 80 |
+
),
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
node.meta["torch_fn"] = node.meta.get(
|
| 84 |
+
"torch_fn",
|
| 85 |
+
(
|
| 86 |
+
f"{node.target.__name__}_0",
|
| 87 |
+
# pyrefly: ignore [missing-attribute]
|
| 88 |
+
f"{node.target.__class__.__name__}.{node.target.__name__}",
|
| 89 |
+
),
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@contextlib.contextmanager
|
| 94 |
+
def _set_node_metadata_hook(gm: torch.fx.GraphModule, f):
|
| 95 |
+
"""
|
| 96 |
+
Takes a callable which will be called after we create a new node. The
|
| 97 |
+
callable takes the newly created node as input and returns None.
|
| 98 |
+
"""
|
| 99 |
+
assert callable(f), "node_metadata_hook must be a callable."
|
| 100 |
+
|
| 101 |
+
# Add the hook to all submodules
|
| 102 |
+
for m in gm.modules():
|
| 103 |
+
if isinstance(m, GraphModule):
|
| 104 |
+
m._register_create_node_hook(f)
|
| 105 |
+
try:
|
| 106 |
+
yield
|
| 107 |
+
finally:
|
| 108 |
+
# Restore hook for all submodules
|
| 109 |
+
for m in gm.modules():
|
| 110 |
+
if isinstance(m, GraphModule):
|
| 111 |
+
m._unregister_create_node_hook(f)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import math
|
| 3 |
+
import operator
|
| 4 |
+
import traceback
|
| 5 |
+
from functools import partial
|
| 6 |
+
from typing import NamedTuple, TYPE_CHECKING
|
| 7 |
+
|
| 8 |
+
import sympy
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.fx
|
| 12 |
+
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
| 13 |
+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
| 14 |
+
from torch.utils._sympy.numbers import int_oo
|
| 15 |
+
from torch.utils._sympy.value_ranges import ValueRanges
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if TYPE_CHECKING:
|
| 19 |
+
from collections.abc import Callable
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
__all__ = ["InputDim"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class InputDim(NamedTuple):
|
| 26 |
+
input_name: str
|
| 27 |
+
dim: int
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _convert_to_int(val):
|
| 31 |
+
# Convert simple sympy Integers into concrete int
|
| 32 |
+
if val in (sympy.oo, int_oo):
|
| 33 |
+
return math.inf
|
| 34 |
+
if val in (-sympy.oo, -int_oo):
|
| 35 |
+
return -math.inf
|
| 36 |
+
if isinstance(val, sympy.Integer):
|
| 37 |
+
return int(val)
|
| 38 |
+
raise RuntimeError("Export constraints cannot be non-integer expressions")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _convert_range_to_int(range: ValueRanges):
|
| 42 |
+
assert isinstance(range, ValueRanges)
|
| 43 |
+
min_val = _convert_to_int(range.lower)
|
| 44 |
+
max_val = _convert_to_int(range.upper)
|
| 45 |
+
return min_val, max_val
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
range_constraints: dict[sympy.Symbol, ValueRanges],
|
| 52 |
+
):
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints
|
| 55 |
+
self._asserts_generated_unbacked_symbols: set[sympy.Symbol] = set()
|
| 56 |
+
self.counter = 0
|
| 57 |
+
|
| 58 |
+
def _assert_range_constraint(self, node, lower, upper, assert_msg):
|
| 59 |
+
last_node = node
|
| 60 |
+
if lower > -math.inf:
|
| 61 |
+
last_node = self._insert_assert_async(
|
| 62 |
+
last_node, operator.ge, node, lower, assert_msg
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
if upper < math.inf:
|
| 66 |
+
last_node = self._insert_assert_async(
|
| 67 |
+
last_node, operator.le, node, upper, assert_msg
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def _insert_assert_async(self, last_node, op, lower, upper, assert_msg):
|
| 71 |
+
"""
|
| 72 |
+
Inserts assert_async call_function nodes in the graph. This function is
|
| 73 |
+
called **during** the interpreter-based pass.
|
| 74 |
+
"""
|
| 75 |
+
self.counter += 1
|
| 76 |
+
graph = last_node.graph
|
| 77 |
+
with graph.inserting_after(last_node):
|
| 78 |
+
cmp = graph.call_function(op, (lower, upper), {})
|
| 79 |
+
with graph.inserting_after(cmp):
|
| 80 |
+
cmp_tensor = graph.call_function(
|
| 81 |
+
torch.ops.aten.scalar_tensor.default, (cmp,), {}
|
| 82 |
+
)
|
| 83 |
+
with graph.inserting_after(cmp_tensor):
|
| 84 |
+
assert_async = graph.call_function(
|
| 85 |
+
torch.ops.aten._assert_async.msg,
|
| 86 |
+
(cmp_tensor, assert_msg),
|
| 87 |
+
{},
|
| 88 |
+
)
|
| 89 |
+
return assert_async
|
| 90 |
+
|
| 91 |
+
def call(self, graph_module) -> PassResult:
|
| 92 |
+
self.existing_inline_assertions = _get_existing_inline_assertions(
|
| 93 |
+
graph_module, self.range_constraints
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
for module in graph_module.modules():
|
| 97 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 98 |
+
continue
|
| 99 |
+
for node in module.graph.nodes:
|
| 100 |
+
if node.op != "call_function":
|
| 101 |
+
continue
|
| 102 |
+
if "val" not in node.meta:
|
| 103 |
+
continue
|
| 104 |
+
|
| 105 |
+
val = node.meta["val"]
|
| 106 |
+
# In general, we may have to deal the case such as: ret[1].shape[0].
|
| 107 |
+
# We need first find out what symbols require assertion, then we need to follow the path
|
| 108 |
+
# from ret to the symbol, construct the proxies along the way and construct the messages
|
| 109 |
+
# piece-wise at the same time.
|
| 110 |
+
#
|
| 111 |
+
# We use post-order traversal to collect all the proxies callbacks needed, construct
|
| 112 |
+
# the error message callbacks, and at the top-level traversal tree we execute all the callbacks.
|
| 113 |
+
# We need the callbacks because, in order to call the function to create a proxy for shape[0], we
|
| 114 |
+
# need the proxy for shape, which further requires the proxy for ret[1], etc.
|
| 115 |
+
|
| 116 |
+
def add_assertions(val):
|
| 117 |
+
call_backs: list[Callable] = []
|
| 118 |
+
messages: list[str] = []
|
| 119 |
+
if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
|
| 120 |
+
symbol = val.node.expr
|
| 121 |
+
if symbol in self.existing_inline_assertions:
|
| 122 |
+
return call_backs, messages
|
| 123 |
+
if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(
|
| 124 |
+
symbol
|
| 125 |
+
):
|
| 126 |
+
if symbol in self._asserts_generated_unbacked_symbols:
|
| 127 |
+
return call_backs, messages
|
| 128 |
+
# We only care about unbacked symints for these inline
|
| 129 |
+
# constraints, which are prefixed with 'u'
|
| 130 |
+
constraint = self.range_constraints[symbol]
|
| 131 |
+
min_val, max_val = _convert_range_to_int(constraint)
|
| 132 |
+
assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
|
| 133 |
+
call_backs.append(
|
| 134 |
+
partial(
|
| 135 |
+
self._assert_range_constraint,
|
| 136 |
+
lower=min_val,
|
| 137 |
+
upper=max_val,
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
messages.append(assert_msg)
|
| 141 |
+
self._asserts_generated_unbacked_symbols.add(symbol)
|
| 142 |
+
|
| 143 |
+
elif isinstance(val, torch.Tensor):
|
| 144 |
+
for i, sym in enumerate(val.shape):
|
| 145 |
+
cbs, msgs = add_assertions(sym)
|
| 146 |
+
for cb, msg in zip(cbs, msgs):
|
| 147 |
+
|
| 148 |
+
def sym_size_cb(node, assert_msg, dim):
|
| 149 |
+
with node.graph.inserting_after(node):
|
| 150 |
+
dim_node = module.graph.call_function(
|
| 151 |
+
torch.ops.aten.sym_size.int,
|
| 152 |
+
(node, dim),
|
| 153 |
+
{},
|
| 154 |
+
)
|
| 155 |
+
cb(node=dim_node, assert_msg=assert_msg)
|
| 156 |
+
|
| 157 |
+
call_backs.append(partial(sym_size_cb, dim=i))
|
| 158 |
+
messages.append(f".shape[{i}]" + msg)
|
| 159 |
+
return call_backs, messages
|
| 160 |
+
|
| 161 |
+
callbacks, messages = add_assertions(val)
|
| 162 |
+
for cb, msg in zip(callbacks, messages):
|
| 163 |
+
cb(node=node, assert_msg=f"{node}" + msg)
|
| 164 |
+
|
| 165 |
+
module.recompile()
|
| 166 |
+
|
| 167 |
+
# Sometimes this pass would return a wrong graph where we have mismatched
|
| 168 |
+
# node names in signature. Before we fix it, let's just skip it.
|
| 169 |
+
if (
|
| 170 |
+
self.counter == 0
|
| 171 |
+
and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass
|
| 172 |
+
):
|
| 173 |
+
return PassResult(graph_module, False)
|
| 174 |
+
|
| 175 |
+
# Populate the stack trace with dummy vals to respect IR
|
| 176 |
+
for node in graph_module.graph.nodes:
|
| 177 |
+
if not node.meta.get("stack_trace", None) and node.op not in [
|
| 178 |
+
"placeholder",
|
| 179 |
+
"output",
|
| 180 |
+
]:
|
| 181 |
+
node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))
|
| 182 |
+
return PassResult(graph_module, True)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _get_existing_inline_assertions(
|
| 186 |
+
graph_module: torch.fx.GraphModule,
|
| 187 |
+
range_constraints: dict[sympy.Symbol, ValueRanges],
|
| 188 |
+
) -> dict[sympy.Symbol, ValueRanges]:
|
| 189 |
+
existing_inline_assertions: dict[sympy.Symbol, ValueRanges] = {}
|
| 190 |
+
|
| 191 |
+
for module in graph_module.modules():
|
| 192 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 193 |
+
continue
|
| 194 |
+
|
| 195 |
+
# Find all the existing inline assertions. They will look something like:
|
| 196 |
+
# %_local_scalar_dense = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%arg1_1,), kwargs = {})
|
| 197 |
+
# %ge = call_function[target=operator.ge](args = (%_local_scalar_dense, 0), kwargs = {})
|
| 198 |
+
# %_assert_scalar = call_function[target=torch.ops.aten._assert_scalar.default](args = (%scalar_tensor, "..."), kwargs = {})
|
| 199 |
+
for node in module.graph.nodes:
|
| 200 |
+
if node.target != torch.ops.aten._assert_scalar.default:
|
| 201 |
+
continue
|
| 202 |
+
|
| 203 |
+
compare_arg = node.args[0]
|
| 204 |
+
if not (
|
| 205 |
+
isinstance(compare_arg, torch.fx.Node)
|
| 206 |
+
and compare_arg.op == "call_function"
|
| 207 |
+
and compare_arg.target in (operator.le, operator.ge)
|
| 208 |
+
and len(compare_arg.args) == 2
|
| 209 |
+
):
|
| 210 |
+
continue
|
| 211 |
+
|
| 212 |
+
compare_op = compare_arg.target
|
| 213 |
+
lhs, rhs = compare_arg.args
|
| 214 |
+
|
| 215 |
+
def maybe_get_symint(x):
|
| 216 |
+
if (
|
| 217 |
+
isinstance(x, torch.fx.Node)
|
| 218 |
+
and "val" in x.meta
|
| 219 |
+
and isinstance(x.meta["val"], torch.SymInt)
|
| 220 |
+
):
|
| 221 |
+
return x.meta["val"].node.expr
|
| 222 |
+
return x
|
| 223 |
+
|
| 224 |
+
lhs = maybe_get_symint(lhs)
|
| 225 |
+
rhs = maybe_get_symint(rhs)
|
| 226 |
+
|
| 227 |
+
if compare_op is operator.ge:
|
| 228 |
+
lhs, rhs = rhs, lhs
|
| 229 |
+
|
| 230 |
+
if isinstance(lhs, sympy.Symbol) and isinstance(rhs, int):
|
| 231 |
+
symint = lhs
|
| 232 |
+
scalar = rhs
|
| 233 |
+
elif isinstance(rhs, sympy.Symbol) and isinstance(lhs, int):
|
| 234 |
+
symint = rhs
|
| 235 |
+
scalar = lhs
|
| 236 |
+
else:
|
| 237 |
+
continue
|
| 238 |
+
|
| 239 |
+
if symint not in range_constraints:
|
| 240 |
+
raise RuntimeError(
|
| 241 |
+
f"Unable to find symint {symint} in {range_constraints}"
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
previous_range = existing_inline_assertions.get(
|
| 245 |
+
symint, ValueRanges(-math.inf, math.inf)
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
if symint is lhs:
|
| 249 |
+
bounds = ValueRanges(-math.inf, scalar)
|
| 250 |
+
else:
|
| 251 |
+
bounds = ValueRanges(scalar, math.inf)
|
| 252 |
+
existing_inline_assertions[symint] = previous_range & bounds
|
| 253 |
+
|
| 254 |
+
return existing_inline_assertions
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/collect_tracepoints_pass.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import operator
|
| 5 |
+
from typing import TYPE_CHECKING
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.export.exported_program import ConstantArgument, TensorArgument
|
| 9 |
+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
if TYPE_CHECKING:
|
| 13 |
+
from torch.export.exported_program import ModuleCallSignature
|
| 14 |
+
from torch.export.graph_signature import ExportGraphSignature
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
__all__ = ["CollectTracepointsPass"]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CollectTracepointsPass(PassBase):
|
| 21 |
+
"""
|
| 22 |
+
Performs constant folding and constant propagation.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self, specs: dict[str, ModuleCallSignature], sig: ExportGraphSignature
|
| 27 |
+
) -> None:
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.specs = specs
|
| 30 |
+
self.sig = sig
|
| 31 |
+
|
| 32 |
+
def call(self, gm: torch.fx.GraphModule) -> PassResult | None:
|
| 33 |
+
def get_arg_spec(arg) -> TensorArgument | ConstantArgument:
|
| 34 |
+
if isinstance(arg, torch.fx.Node):
|
| 35 |
+
if isinstance(arg.meta.get("val"), torch.Tensor):
|
| 36 |
+
return TensorArgument(name=arg.name)
|
| 37 |
+
else:
|
| 38 |
+
raise AssertionError(
|
| 39 |
+
"Symint input is not implemented yet for submodule call signature."
|
| 40 |
+
)
|
| 41 |
+
else:
|
| 42 |
+
return ConstantArgument(name="", value=arg)
|
| 43 |
+
|
| 44 |
+
for module in gm.modules():
|
| 45 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 46 |
+
continue
|
| 47 |
+
nn_module_stack = None
|
| 48 |
+
for node in module.graph.nodes:
|
| 49 |
+
if node.op != "call_function":
|
| 50 |
+
continue
|
| 51 |
+
if node.target is torch.ops.higher_order._export_tracepoint:
|
| 52 |
+
kind = node.kwargs["kind"]
|
| 53 |
+
if kind == "module_call_outputs":
|
| 54 |
+
nn_module_stack = node.meta["nn_module_stack"]
|
| 55 |
+
elif kind == "module_call_inputs":
|
| 56 |
+
nn_module_stack = None
|
| 57 |
+
else:
|
| 58 |
+
raise AssertionError(f"Unknown tracepoint kind: {kind}")
|
| 59 |
+
elif node.meta["nn_module_stack"] == nn_module_stack:
|
| 60 |
+
node.meta["nn_module_stack"].popitem()
|
| 61 |
+
else:
|
| 62 |
+
nn_module_stack = None
|
| 63 |
+
nn_module_stack = None
|
| 64 |
+
for node in reversed(module.graph.nodes):
|
| 65 |
+
if node.op != "call_function":
|
| 66 |
+
continue
|
| 67 |
+
if node.target is torch.ops.higher_order._export_tracepoint:
|
| 68 |
+
kind = node.kwargs["kind"]
|
| 69 |
+
if kind == "module_call_inputs":
|
| 70 |
+
nn_module_stack = node.meta["nn_module_stack"]
|
| 71 |
+
elif kind == "module_call_outputs":
|
| 72 |
+
nn_module_stack = None
|
| 73 |
+
else:
|
| 74 |
+
raise AssertionError(f"Unknown tracepoint kind: {kind}")
|
| 75 |
+
elif node.meta["nn_module_stack"] == nn_module_stack:
|
| 76 |
+
node.meta["nn_module_stack"].popitem()
|
| 77 |
+
else:
|
| 78 |
+
nn_module_stack = None
|
| 79 |
+
|
| 80 |
+
def copy_sig(sig) -> ModuleCallSignature:
|
| 81 |
+
from torch.export.exported_program import ModuleCallSignature
|
| 82 |
+
|
| 83 |
+
return ModuleCallSignature(
|
| 84 |
+
inputs=[],
|
| 85 |
+
outputs=[],
|
| 86 |
+
in_spec=sig.in_spec,
|
| 87 |
+
out_spec=sig.out_spec,
|
| 88 |
+
forward_arg_names=None,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
for module in gm.modules():
|
| 92 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 93 |
+
continue
|
| 94 |
+
for node in module.graph.nodes:
|
| 95 |
+
if node.op != "call_function":
|
| 96 |
+
continue
|
| 97 |
+
if node.target is torch.ops.higher_order._export_tracepoint:
|
| 98 |
+
# There's some subtlety worth noting. Here fqn corresponds to
|
| 99 |
+
# the call name, whereas path corresponds to the module name.
|
| 100 |
+
# They are not necessarily the same! When a submodule is shared
|
| 101 |
+
# through different aliases, there are as many _export_tracepoint
|
| 102 |
+
# markers as there are aliases, since the shared submodule is
|
| 103 |
+
# wrapped once for each alias.
|
| 104 |
+
path = node.kwargs["path"]
|
| 105 |
+
fqn, _ = next(reversed(node.meta["nn_module_stack"].values()))
|
| 106 |
+
|
| 107 |
+
module_key = next(reversed(node.meta["nn_module_stack"]))
|
| 108 |
+
if "@" in module_key:
|
| 109 |
+
suffix = module_key.split("@")[-1]
|
| 110 |
+
path = f"{path}@{suffix}"
|
| 111 |
+
|
| 112 |
+
call_fqn = f"{fqn}@{suffix}"
|
| 113 |
+
if call_fqn not in self.specs:
|
| 114 |
+
self.specs[call_fqn] = copy_sig(self.specs[fqn])
|
| 115 |
+
fqn = call_fqn
|
| 116 |
+
|
| 117 |
+
kind = node.kwargs["kind"]
|
| 118 |
+
for i, arg in enumerate(node.args):
|
| 119 |
+
# We only update the signature of the alias used to call
|
| 120 |
+
# the submodule. Otherwise the signatures of all aliases
|
| 121 |
+
# would get conflated; the inputs/outputs of every call
|
| 122 |
+
# would be recorded in every other call as well.
|
| 123 |
+
if fqn == path:
|
| 124 |
+
if kind == "module_call_inputs":
|
| 125 |
+
self.specs[path].inputs.append(get_arg_spec(arg))
|
| 126 |
+
elif kind == "module_call_outputs":
|
| 127 |
+
self.specs[path].outputs.append(get_arg_spec(arg))
|
| 128 |
+
else:
|
| 129 |
+
raise AssertionError(f"Unknown tracepoint kind: {kind}")
|
| 130 |
+
if isinstance(arg, torch.fx.Node):
|
| 131 |
+
for user in node.users:
|
| 132 |
+
assert user.op == "call_function"
|
| 133 |
+
assert user.target is operator.getitem
|
| 134 |
+
assert isinstance(user.args[1], int)
|
| 135 |
+
if user.args[1] == i:
|
| 136 |
+
user.replace_all_uses_with(arg)
|
| 137 |
+
self.sig.replace_all_uses(user.name, arg.name)
|
| 138 |
+
break
|
| 139 |
+
users = list(node.users)
|
| 140 |
+
for user in users:
|
| 141 |
+
assert len(user.users) == 0
|
| 142 |
+
gm.graph.erase_node(user)
|
| 143 |
+
gm.graph.erase_node(node)
|
| 144 |
+
return PassResult(gm, True)
|
| 145 |
+
|
| 146 |
+
return None
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/constant_folding.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import collections
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from collections.abc import Callable
|
| 5 |
+
from typing import Any, Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.utils._pytree as pytree
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
aten = torch.ops.aten
|
| 12 |
+
|
| 13 |
+
# We would like to split modules into two subgraphs for runtime weight updates to work correctly.
|
| 14 |
+
# The use case and more information could be found at:
|
| 15 |
+
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
|
| 16 |
+
META_TAG = "MODULE_TYPE"
|
| 17 |
+
MODULE_TAG = "_MAIN_MODULE"
|
| 18 |
+
CONST_MODULE_TAG = "_CONST_MODULE"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def replace_node_with_constant(gm, node, constant, name=None):
|
| 22 |
+
g = gm.graph
|
| 23 |
+
|
| 24 |
+
if name:
|
| 25 |
+
qualname = name
|
| 26 |
+
else:
|
| 27 |
+
if not hasattr(gm, "_frozen_param_count"):
|
| 28 |
+
gm._frozen_param_count = 0
|
| 29 |
+
i = gm._frozen_param_count
|
| 30 |
+
|
| 31 |
+
while True:
|
| 32 |
+
qualname = f"_frozen_param{i}"
|
| 33 |
+
if not hasattr(gm, qualname):
|
| 34 |
+
break
|
| 35 |
+
i += 1
|
| 36 |
+
|
| 37 |
+
gm._frozen_param_count = i + 1
|
| 38 |
+
|
| 39 |
+
with g.inserting_before(node):
|
| 40 |
+
new_input_node = g.create_node("get_attr", qualname, (), {})
|
| 41 |
+
node.replace_all_uses_with(new_input_node)
|
| 42 |
+
new_input_node.meta.update(node.meta)
|
| 43 |
+
g.erase_node(node)
|
| 44 |
+
|
| 45 |
+
# needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
|
| 46 |
+
gm.register_buffer(qualname, constant)
|
| 47 |
+
setattr(gm, qualname, constant)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ConstantFolder(torch.fx.Interpreter):
|
| 51 |
+
def __init__(
|
| 52 |
+
self,
|
| 53 |
+
gm: torch.fx.GraphModule,
|
| 54 |
+
skip_constructors: bool = False,
|
| 55 |
+
):
|
| 56 |
+
super().__init__(gm)
|
| 57 |
+
self.node_replacements: dict[torch.fx.Node, Any] = {}
|
| 58 |
+
self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter()
|
| 59 |
+
self.unknown_value = object()
|
| 60 |
+
self.skip_constructors: bool = skip_constructors
|
| 61 |
+
|
| 62 |
+
# overwrite this to deallocate env values if their only remaining use
|
| 63 |
+
# is the output
|
| 64 |
+
self.user_to_last_uses = self.node_to_last_non_output_use()
|
| 65 |
+
|
| 66 |
+
def is_impure(self, node: torch.fx.Node) -> bool:
|
| 67 |
+
if (
|
| 68 |
+
node.target is torch.ops.prims.convert_element_type.default
|
| 69 |
+
and node.args[0].op == "get_attr" # type: ignore[union-attr]
|
| 70 |
+
and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr]
|
| 71 |
+
and node.args[1] == torch.bfloat16
|
| 72 |
+
):
|
| 73 |
+
# For int8_weight -> dq -> bf16_weight
|
| 74 |
+
return True
|
| 75 |
+
if node.target in [
|
| 76 |
+
torch.ops.quantized_decomposed.dequantize_per_channel.default,
|
| 77 |
+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
|
| 78 |
+
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
|
| 79 |
+
torch.ops.pt2e_quant.dequantize_affine,
|
| 80 |
+
]:
|
| 81 |
+
# For the pattern fp32_weight -> q -> dq
|
| 82 |
+
# We only folding fp32_weight -> q
|
| 83 |
+
# int8_weight and leave dq in graph to be fused
|
| 84 |
+
return True
|
| 85 |
+
return False
|
| 86 |
+
|
| 87 |
+
def node_to_last_non_output_use(self):
|
| 88 |
+
last_non_output_use = collections.defaultdict(list)
|
| 89 |
+
seen_uses = set()
|
| 90 |
+
output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr]
|
| 91 |
+
|
| 92 |
+
for node in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr]
|
| 93 |
+
if node.target == "output":
|
| 94 |
+
continue
|
| 95 |
+
|
| 96 |
+
def add_use(inp):
|
| 97 |
+
if inp in seen_uses:
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
seen_uses.add(inp)
|
| 101 |
+
last_non_output_use[node].append(inp)
|
| 102 |
+
|
| 103 |
+
# In-place is fine since we don't mutate
|
| 104 |
+
pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))
|
| 105 |
+
|
| 106 |
+
# if this node is only used in output, we want to gc it right away
|
| 107 |
+
if len(node.users) == 1 and output_node in node.users:
|
| 108 |
+
last_non_output_use[node].append(node)
|
| 109 |
+
|
| 110 |
+
return last_non_output_use
|
| 111 |
+
|
| 112 |
+
def run_node(self, node):
|
| 113 |
+
if node.target == "output":
|
| 114 |
+
# because we remove nodes from env on last non output use,
|
| 115 |
+
# re-define them now or we'll get error in interpreter
|
| 116 |
+
def set_env(arg):
|
| 117 |
+
self.env[arg] = self.unknown_value
|
| 118 |
+
|
| 119 |
+
# In-place is fine since we don't mutate
|
| 120 |
+
pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
|
| 121 |
+
return super().run_node(node)
|
| 122 |
+
|
| 123 |
+
args, kwargs = self.fetch_args_kwargs_from_env(node)
|
| 124 |
+
flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)
|
| 125 |
+
|
| 126 |
+
# We need to do this weird thing because in cases where flattened_inputs
|
| 127 |
+
# contains a ScriptObject, equality checking results in a type error if
|
| 128 |
+
# the types are different.
|
| 129 |
+
if any(
|
| 130 |
+
type(self.unknown_value) is type(input_) and self.unknown_value == input_
|
| 131 |
+
for input_ in flattened_inputs
|
| 132 |
+
):
|
| 133 |
+
return self.unknown_value
|
| 134 |
+
|
| 135 |
+
# TODO - fix errors with this
|
| 136 |
+
if (
|
| 137 |
+
node.op == "call_function"
|
| 138 |
+
and node.target is aten._efficientzerotensor.default
|
| 139 |
+
):
|
| 140 |
+
return self.unknown_value
|
| 141 |
+
|
| 142 |
+
# TODO - constant folding triton kernel returns the inputs -- fix this
|
| 143 |
+
if (
|
| 144 |
+
node.op == "call_function"
|
| 145 |
+
and node.name == "triton_kernel_wrapper_functional_proxy"
|
| 146 |
+
):
|
| 147 |
+
return self.unknown_value
|
| 148 |
+
|
| 149 |
+
# skip constructors, since inductor generates optimal code for them already
|
| 150 |
+
# and turning into tensor would result in an additional global memory read
|
| 151 |
+
# TODO - more complicated strategy
|
| 152 |
+
if (
|
| 153 |
+
self.skip_constructors
|
| 154 |
+
and node.op != "get_attr"
|
| 155 |
+
and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
|
| 156 |
+
):
|
| 157 |
+
return self.unknown_value
|
| 158 |
+
|
| 159 |
+
# All mutations should either be removed or on inputs which we did not make constant
|
| 160 |
+
if (
|
| 161 |
+
isinstance(node.target, torch._ops.OpOverload)
|
| 162 |
+
and torch.Tag.nondeterministic_seeded in node.target.tags
|
| 163 |
+
):
|
| 164 |
+
return self.unknown_value
|
| 165 |
+
|
| 166 |
+
out = super().run_node(node)
|
| 167 |
+
|
| 168 |
+
if node.op != "get_attr" and isinstance(out, torch.Tensor):
|
| 169 |
+
if out.device.type == "meta":
|
| 170 |
+
return out
|
| 171 |
+
|
| 172 |
+
if not self.insertable_tensor_check(out):
|
| 173 |
+
return out
|
| 174 |
+
|
| 175 |
+
if self.is_impure(node):
|
| 176 |
+
return self.unknown_value
|
| 177 |
+
|
| 178 |
+
self.add_node_replacement(node, out)
|
| 179 |
+
|
| 180 |
+
flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
|
| 181 |
+
|
| 182 |
+
for n in flattened_node_inps:
|
| 183 |
+
if not isinstance(n, torch.fx.Node):
|
| 184 |
+
continue
|
| 185 |
+
|
| 186 |
+
self.replaced_uses[n] += 1
|
| 187 |
+
|
| 188 |
+
for to_delete in self.user_to_last_uses.get(node, []):
|
| 189 |
+
if self.replaced_uses[to_delete] == len(to_delete.users):
|
| 190 |
+
self.node_replacements.pop(to_delete, None)
|
| 191 |
+
|
| 192 |
+
return out
|
| 193 |
+
|
| 194 |
+
def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
|
| 195 |
+
return True
|
| 196 |
+
|
| 197 |
+
def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
|
| 198 |
+
self.node_replacements[node] = tensor
|
| 199 |
+
|
| 200 |
+
def run(self): # type: ignore[override]
|
| 201 |
+
env = {}
|
| 202 |
+
for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
|
| 203 |
+
env[n] = self.unknown_value
|
| 204 |
+
return super().run(initial_env=env)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def constant_fold(
|
| 208 |
+
gm: torch.fx.GraphModule,
|
| 209 |
+
constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
|
| 210 |
+
):
|
| 211 |
+
with torch.utils._python_dispatch._disable_current_modes():
|
| 212 |
+
cf = ConstantFolder(gm, skip_constructors=True)
|
| 213 |
+
cf.run()
|
| 214 |
+
|
| 215 |
+
for node, constant in cf.node_replacements.items():
|
| 216 |
+
if constraint_fn is not None and not constraint_fn(node):
|
| 217 |
+
continue
|
| 218 |
+
replace_node_with_constant(gm, node, constant)
|
| 219 |
+
|
| 220 |
+
erased_params = []
|
| 221 |
+
# Get all attr users by looking up the graph instead from node.users, because in this case
|
| 222 |
+
# _tensor_constant0 and _tensor_constant0_1 are actually refereing to the same tensor.
|
| 223 |
+
|
| 224 |
+
# opcode name target args kwargs
|
| 225 |
+
# ------------- ------------------- ---------------- --------------------------- --------
|
| 226 |
+
# placeholder arg0_1 arg0 () {}
|
| 227 |
+
# get_attr _tensor_constant0 state () {}
|
| 228 |
+
# call_function add aten.add.Tensor (arg0_1, _tensor_constant0) {}
|
| 229 |
+
# get_attr _tensor_constant0_1 state () {}
|
| 230 |
+
# call_function add_ aten.add_.Tensor (_tensor_constant0_1, 1) {}
|
| 231 |
+
# output output output ([add],) {}
|
| 232 |
+
|
| 233 |
+
get_attr_node_users = defaultdict(list)
|
| 234 |
+
for node in gm.graph.nodes:
|
| 235 |
+
if node.op == "get_attr":
|
| 236 |
+
get_attr_node_users[node.target].extend(node.users.keys())
|
| 237 |
+
for node in gm.graph.find_nodes(op="get_attr"):
|
| 238 |
+
if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0:
|
| 239 |
+
if hasattr(gm, node.target):
|
| 240 |
+
delattr(gm, node.target)
|
| 241 |
+
erased_params.append(node)
|
| 242 |
+
for node in erased_params:
|
| 243 |
+
gm.graph.erase_node(node)
|
| 244 |
+
|
| 245 |
+
gm.graph.eliminate_dead_code()
|
| 246 |
+
gm.graph.lint()
|
| 247 |
+
gm.recompile()
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def constant_graph_tag(gm: torch.fx.GraphModule) -> None:
|
| 251 |
+
with torch.utils._python_dispatch._disable_current_modes():
|
| 252 |
+
cf = ConstantFolder(gm, skip_constructors=True)
|
| 253 |
+
cf.run()
|
| 254 |
+
|
| 255 |
+
for node in gm.graph.nodes:
|
| 256 |
+
if (
|
| 257 |
+
node.op == "get_attr"
|
| 258 |
+
or node in cf.node_replacements
|
| 259 |
+
or node in cf.replaced_uses
|
| 260 |
+
):
|
| 261 |
+
node.meta[META_TAG] = CONST_MODULE_TAG
|
| 262 |
+
else:
|
| 263 |
+
node.meta[META_TAG] = MODULE_TAG
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
| 267 |
+
"""
|
| 268 |
+
Construct a GraphModule which corresponds to the part which could be
|
| 269 |
+
constant folded in provided gm.
|
| 270 |
+
"""
|
| 271 |
+
|
| 272 |
+
constant_graph_tag(gm)
|
| 273 |
+
# We rewrite the tags, if it's a constant being directly consumed, without
|
| 274 |
+
# any folding opportunity, we keep it in main gm.
|
| 275 |
+
for node in gm.graph.find_nodes(op="get_attr"):
|
| 276 |
+
used_to_fold = False
|
| 277 |
+
for u in node.users:
|
| 278 |
+
if u.meta[META_TAG] == CONST_MODULE_TAG:
|
| 279 |
+
used_to_fold = True
|
| 280 |
+
break
|
| 281 |
+
if not used_to_fold:
|
| 282 |
+
node.meta[META_TAG] = MODULE_TAG
|
| 283 |
+
|
| 284 |
+
new_graph = torch.fx.Graph()
|
| 285 |
+
|
| 286 |
+
node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
|
| 287 |
+
output_nodes = []
|
| 288 |
+
for node in gm.graph.nodes:
|
| 289 |
+
if node.meta[META_TAG] == MODULE_TAG:
|
| 290 |
+
continue
|
| 291 |
+
|
| 292 |
+
new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
|
| 293 |
+
node_remapping[node] = new_node
|
| 294 |
+
|
| 295 |
+
for user in node.users:
|
| 296 |
+
if user.meta[META_TAG] == MODULE_TAG:
|
| 297 |
+
output_nodes.append(new_node)
|
| 298 |
+
break
|
| 299 |
+
|
| 300 |
+
new_graph.output(tuple(output_nodes))
|
| 301 |
+
new_graph.lint()
|
| 302 |
+
new_gm = torch.fx.GraphModule(gm, new_graph)
|
| 303 |
+
|
| 304 |
+
return new_gm
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/functionalize_side_effectful_ops_pass.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._export.pass_base import (
|
| 6 |
+
_ExportPassBaseDeprecatedDoNotUse,
|
| 7 |
+
Argument,
|
| 8 |
+
PassResult,
|
| 9 |
+
)
|
| 10 |
+
from torch._export.pass_infra.node_metadata import NodeMetadata
|
| 11 |
+
from torch._export.pass_infra.proxy_value import ProxyValue
|
| 12 |
+
from torch._ops import OpOverload
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
aten = torch.ops.aten
|
| 16 |
+
|
| 17 |
+
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {
|
| 18 |
+
aten.sym_constrain_range.default: aten._functional_sym_constrain_range.default,
|
| 19 |
+
aten._assert_async.msg: aten._functional_assert_async.msg,
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class _FunctionalizeSideEffectfulOpsPass(_ExportPassBaseDeprecatedDoNotUse):
|
| 24 |
+
"""
|
| 25 |
+
Functionalize ops with side effect in graph module by replacing the op with
|
| 26 |
+
functional version of it. A new dependency token (`dep_token`) will be
|
| 27 |
+
created and propagated through functional ops to output.
|
| 28 |
+
For example:
|
| 29 |
+
```
|
| 30 |
+
def f(x):
|
| 31 |
+
sym_constrain_range(x.shape[0], min=1, max=3)
|
| 32 |
+
return x.add(3)
|
| 33 |
+
```
|
| 34 |
+
Will be transformed to:
|
| 35 |
+
```
|
| 36 |
+
def f(x):
|
| 37 |
+
dep_token0 = _make_dep_token()
|
| 38 |
+
dep_token1 = _functional_sym_constrain_range(
|
| 39 |
+
x.shape[0], min=1, max=3, dep_token=dep_token0
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
return x.add(3), dep_token1
|
| 43 |
+
```
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(self) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
self._dep_token: Optional[ProxyValue] = None
|
| 49 |
+
self._next_dep_token_index: Optional[int] = None
|
| 50 |
+
|
| 51 |
+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
|
| 52 |
+
# Early return if no non-functional assertions.
|
| 53 |
+
if not any(
|
| 54 |
+
n.target in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS
|
| 55 |
+
for n in graph_module.graph.nodes
|
| 56 |
+
):
|
| 57 |
+
return PassResult(graph_module=graph_module, modified=False)
|
| 58 |
+
|
| 59 |
+
gm = copy.deepcopy(graph_module)
|
| 60 |
+
self._dep_token = None
|
| 61 |
+
self._next_dep_token_index = None
|
| 62 |
+
return super().call(gm)
|
| 63 |
+
|
| 64 |
+
def call_operator(
|
| 65 |
+
self,
|
| 66 |
+
op: OpOverload,
|
| 67 |
+
args: tuple[Argument, ...],
|
| 68 |
+
kwargs: dict[str, Argument],
|
| 69 |
+
meta: NodeMetadata,
|
| 70 |
+
) -> ProxyValue:
|
| 71 |
+
if op not in _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS:
|
| 72 |
+
return super().call_operator(op, args, kwargs, meta)
|
| 73 |
+
|
| 74 |
+
if self._dep_token is None:
|
| 75 |
+
self._dep_token = super().call_operator(
|
| 76 |
+
aten._make_dep_token,
|
| 77 |
+
args=(),
|
| 78 |
+
kwargs={},
|
| 79 |
+
meta=self._create_dummy_node_metadata(),
|
| 80 |
+
)
|
| 81 |
+
self._dep_token.node.name = "dep_token0"
|
| 82 |
+
self._next_dep_token_index = 1
|
| 83 |
+
|
| 84 |
+
self._dep_token = super().call_operator(
|
| 85 |
+
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS[op],
|
| 86 |
+
args=args,
|
| 87 |
+
kwargs={**kwargs, "dep_token": self._dep_token},
|
| 88 |
+
meta=meta,
|
| 89 |
+
)
|
| 90 |
+
assert self._next_dep_token_index is not None
|
| 91 |
+
self._dep_token.node.name = f"dep_token{self._next_dep_token_index}"
|
| 92 |
+
self._next_dep_token_index += 1
|
| 93 |
+
|
| 94 |
+
return self._dep_token
|
| 95 |
+
|
| 96 |
+
def output(self, results: list[Argument], meta: NodeMetadata) -> ProxyValue:
|
| 97 |
+
assert self._dep_token is not None
|
| 98 |
+
|
| 99 |
+
return super().output(results=(*results, self._dep_token), meta=meta) # type: ignore[arg-type]
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/insert_custom_op_guards.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._export.passes._node_metadata_hook import (
|
| 6 |
+
_node_metadata_hook,
|
| 7 |
+
_set_node_metadata_hook,
|
| 8 |
+
)
|
| 9 |
+
from torch._library.fake_profile import OpProfile, TensorMetadata
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def insert_custom_op_guards(gm: torch.fx.GraphModule, ops_to_guard: set[str]) -> None:
|
| 13 |
+
"""
|
| 14 |
+
This is used by draft_export to insert guards in front of calls to custom
|
| 15 |
+
operators which have a generated fake kernel.
|
| 16 |
+
"""
|
| 17 |
+
for node in gm.graph.nodes:
|
| 18 |
+
if node.op == "call_function" and str(node.target) in ops_to_guard:
|
| 19 |
+
with (
|
| 20 |
+
_set_node_metadata_hook(
|
| 21 |
+
gm,
|
| 22 |
+
functools.partial(
|
| 23 |
+
_node_metadata_hook,
|
| 24 |
+
metadata={"stack_trace": node.meta.get("stack_trace")},
|
| 25 |
+
),
|
| 26 |
+
),
|
| 27 |
+
gm.graph.inserting_before(node),
|
| 28 |
+
):
|
| 29 |
+
for arg in (*node.args, *node.kwargs.values()):
|
| 30 |
+
if isinstance(arg, torch.fx.Node) and isinstance(
|
| 31 |
+
arg.meta.get("val"), torch.Tensor
|
| 32 |
+
):
|
| 33 |
+
val = arg.meta["val"]
|
| 34 |
+
gm.graph.call_function(
|
| 35 |
+
torch.ops.aten._assert_tensor_metadata.default,
|
| 36 |
+
args=(arg,),
|
| 37 |
+
kwargs={
|
| 38 |
+
"dtype": val.dtype,
|
| 39 |
+
"device": val.device,
|
| 40 |
+
"layout": val.layout,
|
| 41 |
+
},
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
gm.recompile()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_op_profiles(
|
| 48 |
+
gm: torch.fx.GraphModule, ops_to_guard: set[str]
|
| 49 |
+
) -> dict[str, set[OpProfile]]:
|
| 50 |
+
"""
|
| 51 |
+
This is used by draft_export to get a list of custom operator profiles so
|
| 52 |
+
that we can generate fake kernels.
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
def _get_op_profile(node: torch.fx.Node) -> OpProfile:
|
| 56 |
+
args_profile = tuple(
|
| 57 |
+
TensorMetadata.maybe_from_tensor(arg.meta.get("val"))
|
| 58 |
+
if isinstance(arg, torch.fx.Node)
|
| 59 |
+
else None
|
| 60 |
+
for arg in (*node.args, *node.kwargs.values())
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
out_profile = None
|
| 64 |
+
meta = node.meta.get("val")
|
| 65 |
+
assert meta is not None
|
| 66 |
+
if isinstance(meta, torch.Tensor):
|
| 67 |
+
out_profile = TensorMetadata.maybe_from_tensor(meta)
|
| 68 |
+
elif isinstance(meta, (list, tuple)):
|
| 69 |
+
out_profile = tuple(TensorMetadata.maybe_from_tensor(m) for m in meta) # type: ignore[assignment]
|
| 70 |
+
assert out_profile is not None
|
| 71 |
+
|
| 72 |
+
return OpProfile(args_profile, out_profile) # type: ignore[arg-type]
|
| 73 |
+
|
| 74 |
+
op_profiles: dict[str, set[OpProfile]] = defaultdict(set)
|
| 75 |
+
|
| 76 |
+
for node in gm.graph.nodes:
|
| 77 |
+
if node.op == "call_function" and str(node.target) in ops_to_guard:
|
| 78 |
+
op_profiles[str(node.target)].add(_get_op_profile(node))
|
| 79 |
+
|
| 80 |
+
return op_profiles
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/lift_constants_pass.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import collections
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Any, Optional, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._export.verifier import SpecViolationError
|
| 8 |
+
from torch._guards import detect_fake_mode
|
| 9 |
+
from torch._library.fake_class_registry import FakeScriptObject
|
| 10 |
+
from torch._library.opaque_object import is_opaque_reference_type
|
| 11 |
+
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
| 12 |
+
from torch.export.exported_program import (
|
| 13 |
+
ArgumentSpec,
|
| 14 |
+
CustomObjArgument,
|
| 15 |
+
ExportGraphSignature,
|
| 16 |
+
InputKind,
|
| 17 |
+
InputSpec,
|
| 18 |
+
TensorArgument,
|
| 19 |
+
)
|
| 20 |
+
from torch.fx._symbolic_trace import _ConstantAttributeType
|
| 21 |
+
from torch.fx.graph_module import _get_attr
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
log = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ConstantAttrMap(collections.abc.MutableMapping):
|
| 28 |
+
"""A mapping class that understands how to use module constants (tensors,
|
| 29 |
+
ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally,
|
| 30 |
+
but ScriptObjects are stored by hash, because different torch.ScriptObjects can point to
|
| 31 |
+
the same underlying value (but we guarantee that they will `hash()` to the same value
|
| 32 |
+
if that's the case).
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self) -> None:
|
| 36 |
+
# Underlying dict that we use to implement this mapping.
|
| 37 |
+
self._constant_attrs: dict[
|
| 38 |
+
Union[int, torch.Tensor, FakeScriptObject, torch.utils._pytree.TreeSpec],
|
| 39 |
+
list[Any],
|
| 40 |
+
] = {}
|
| 41 |
+
# Map from the hash(ScriptObject) to the ScriptObject itself. Used for
|
| 42 |
+
# APIs like `__iter__` that should look like they're returning the
|
| 43 |
+
# original ScriptObjects.
|
| 44 |
+
self._script_object_map: dict[int, torch.ScriptObject] = {}
|
| 45 |
+
|
| 46 |
+
def __getitem__(self, key: _ConstantAttributeType) -> Any:
|
| 47 |
+
real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
|
| 48 |
+
assert isinstance(real_key, (int, torch.Tensor, FakeScriptObject))
|
| 49 |
+
return self._constant_attrs[real_key]
|
| 50 |
+
|
| 51 |
+
def __setitem__(self, key: _ConstantAttributeType, value):
|
| 52 |
+
# we shouldn't actually call this, should go to add() instead to handle aliasing
|
| 53 |
+
raise NotImplementedError(
|
| 54 |
+
"""Directly setting values for ConstantAttrMap is not supported, please use add(key, value) instead.
|
| 55 |
+
The same key can be mapped to multiple values, for handling constant aliasing."""
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def add(self, key: _ConstantAttributeType, value: Any) -> None:
|
| 59 |
+
if isinstance(key, torch.ScriptObject):
|
| 60 |
+
if hash(key) not in self._constant_attrs:
|
| 61 |
+
self._constant_attrs[hash(key)] = []
|
| 62 |
+
self._constant_attrs[hash(key)].append(value)
|
| 63 |
+
self._script_object_map[hash(key)] = key
|
| 64 |
+
elif isinstance(key, (torch.Tensor, FakeScriptObject)):
|
| 65 |
+
if key not in self._constant_attrs:
|
| 66 |
+
self._constant_attrs[key] = []
|
| 67 |
+
self._constant_attrs[key].append(value)
|
| 68 |
+
else:
|
| 69 |
+
raise TypeError(
|
| 70 |
+
f"Expected key to be a tensor or ScriptObject, got {type(key)}"
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def __delitem__(self, key: _ConstantAttributeType):
|
| 74 |
+
real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
|
| 75 |
+
|
| 76 |
+
del self._constant_attrs[real_key]
|
| 77 |
+
|
| 78 |
+
def __iter__(self):
|
| 79 |
+
for key in self._constant_attrs:
|
| 80 |
+
if isinstance(key, int):
|
| 81 |
+
yield self._script_object_map[key]
|
| 82 |
+
else:
|
| 83 |
+
yield key
|
| 84 |
+
|
| 85 |
+
def __len__(self):
|
| 86 |
+
return len(self._constant_attrs)
|
| 87 |
+
|
| 88 |
+
def __contains__(self, key: object) -> bool:
|
| 89 |
+
real_key = hash(key) if isinstance(key, torch.ScriptObject) else key
|
| 90 |
+
return real_key in self._constant_attrs
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str:
|
| 94 |
+
# The FQN of the constant tensor in the state dict should
|
| 95 |
+
# correspond to the module where the constant tensor was
|
| 96 |
+
# originally used.
|
| 97 |
+
if len(node.meta["nn_module_stack"]) == 0:
|
| 98 |
+
return constant_name
|
| 99 |
+
parent_fqn = list(node.meta["nn_module_stack"].values())[-1][0]
|
| 100 |
+
if len(parent_fqn) > 0:
|
| 101 |
+
return f"{parent_fqn}.{constant_name}"
|
| 102 |
+
else:
|
| 103 |
+
return constant_name
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _get_first_fqn(
|
| 107 |
+
const_attrs: ConstantAttrMap,
|
| 108 |
+
key: _ConstantAttributeType,
|
| 109 |
+
) -> Any:
|
| 110 |
+
fqns = const_attrs.get(key)
|
| 111 |
+
return fqns[0] if fqns else None
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _unused_constant(node: torch.fx.Node) -> Optional[list[torch.fx.Node]]:
|
| 115 |
+
"""
|
| 116 |
+
If there is a tensor constant created while tracing, here is how the graph
|
| 117 |
+
looks like:
|
| 118 |
+
|
| 119 |
+
%_tensor_constant0 : [num_users=1] = get_attr[target=_tensor_constant0]
|
| 120 |
+
%lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant0,))
|
| 121 |
+
%detach_ : [num_users=?] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,))
|
| 122 |
+
|
| 123 |
+
To check to see if the tensor constant is being used, we want to traverse to
|
| 124 |
+
the detach node to see if it's actually being used.
|
| 125 |
+
|
| 126 |
+
This function returns None if this constant is being used, otherwise it returns the
|
| 127 |
+
lift_fresh and detach node to be removed later.
|
| 128 |
+
""" # noqa: B950
|
| 129 |
+
if len(node.users) > 1:
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
lift_fresh_node = next(iter(node.users.keys()))
|
| 133 |
+
if not (
|
| 134 |
+
lift_fresh_node.op == "call_function"
|
| 135 |
+
and lift_fresh_node.target
|
| 136 |
+
in (
|
| 137 |
+
torch.ops.aten.lift_fresh.default,
|
| 138 |
+
torch.ops.aten.lift_fresh_copy.default,
|
| 139 |
+
)
|
| 140 |
+
):
|
| 141 |
+
return None
|
| 142 |
+
|
| 143 |
+
if len(lift_fresh_node.users) > 1:
|
| 144 |
+
return None
|
| 145 |
+
|
| 146 |
+
# Case 1: lift node is not used anywhere
|
| 147 |
+
if len(lift_fresh_node.users) == 0:
|
| 148 |
+
return [lift_fresh_node, node]
|
| 149 |
+
|
| 150 |
+
detach_node = next(iter(lift_fresh_node.users.keys()))
|
| 151 |
+
if not (
|
| 152 |
+
detach_node.op == "call_function"
|
| 153 |
+
and detach_node.target
|
| 154 |
+
in (
|
| 155 |
+
torch.ops.aten.detach_.default,
|
| 156 |
+
torch.ops.aten.detach.default,
|
| 157 |
+
)
|
| 158 |
+
):
|
| 159 |
+
return None
|
| 160 |
+
|
| 161 |
+
if len(detach_node.users) > 0:
|
| 162 |
+
return None
|
| 163 |
+
else:
|
| 164 |
+
# Case 2: Lift node's child is not used anywhere
|
| 165 |
+
return [detach_node, lift_fresh_node, node]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def lift_constants_pass(
|
| 169 |
+
gm: torch.fx.GraphModule,
|
| 170 |
+
graph_signature: ExportGraphSignature,
|
| 171 |
+
constant_attrs: ConstantAttrMap,
|
| 172 |
+
) -> dict[str, _ConstantAttributeType]:
|
| 173 |
+
"""
|
| 174 |
+
Takes a graph module, graph signature, and modifies them inplace to lift any
|
| 175 |
+
constants (tensors or custom classes) as inputs to the graph. Returns a
|
| 176 |
+
dictionary of names to constants.
|
| 177 |
+
|
| 178 |
+
Arguments:
|
| 179 |
+
gm (torch.fx.GraphModule): The graph module containing the graph and constants to lift.
|
| 180 |
+
graph_signature (ExportGraphSignature): This graph signature will be
|
| 181 |
+
mutated to add additional CONSTANT_TENSOR and CUSTOM_OBJ inputs.
|
| 182 |
+
constant_attrs (ConstantAttr): A mapping from a constant value to its
|
| 183 |
+
fully-qualified path in `gm`. This is used to maintain consistent
|
| 184 |
+
location of constants between the original module and the exported
|
| 185 |
+
version.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
A dictionary of fqn => constant value.
|
| 189 |
+
"""
|
| 190 |
+
all_constants: dict[str, _ConstantAttributeType] = {}
|
| 191 |
+
|
| 192 |
+
input_specs = graph_signature.input_specs
|
| 193 |
+
num_custom_obj = sum(
|
| 194 |
+
input_spec.kind == InputKind.CUSTOM_OBJ for input_spec in input_specs
|
| 195 |
+
)
|
| 196 |
+
num_tensor_constants = sum(
|
| 197 |
+
input_spec.kind == InputKind.CONSTANT_TENSOR for input_spec in input_specs
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
fake_mode = detect_fake_mode(
|
| 201 |
+
tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
first_user_input_loc, first_user_input = 0, next(iter(gm.graph.nodes))
|
| 205 |
+
used_target_names = set()
|
| 206 |
+
|
| 207 |
+
input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
|
| 208 |
+
assert len(input_nodes) == len(input_specs)
|
| 209 |
+
for i, (node, input_spec) in enumerate(zip(input_nodes, input_specs)):
|
| 210 |
+
used_target_names.add(input_spec.target)
|
| 211 |
+
if input_spec.kind == InputKind.USER_INPUT:
|
| 212 |
+
first_user_input = node
|
| 213 |
+
first_user_input_loc = i
|
| 214 |
+
break
|
| 215 |
+
|
| 216 |
+
lifted_objs = ConstantAttrMap()
|
| 217 |
+
renamed_targets = {}
|
| 218 |
+
for node in list(gm.graph.nodes):
|
| 219 |
+
if node.op == "get_attr":
|
| 220 |
+
if nodes_to_remove := _unused_constant(node):
|
| 221 |
+
# Remove the node if it's not being used
|
| 222 |
+
for node_rm in nodes_to_remove:
|
| 223 |
+
gm.graph.erase_node(node_rm)
|
| 224 |
+
continue
|
| 225 |
+
|
| 226 |
+
constant_val = _get_attr(gm, node.target)
|
| 227 |
+
# These are not hashable and not gonna be lifted
|
| 228 |
+
# so we can skip them earlier
|
| 229 |
+
if isinstance(constant_val, torch.fx.GraphModule):
|
| 230 |
+
continue
|
| 231 |
+
if "LoweredBackendModule" in type(constant_val).__name__:
|
| 232 |
+
continue
|
| 233 |
+
if "AOTInductorRunnerWrapper" in type(constant_val).__name__:
|
| 234 |
+
continue
|
| 235 |
+
if isinstance(constant_val, torch.utils._pytree.TreeSpec):
|
| 236 |
+
continue
|
| 237 |
+
|
| 238 |
+
if constant_val in lifted_objs:
|
| 239 |
+
# We already lifted this constant elsewhere. Just rewrite uses
|
| 240 |
+
# of this get_attr to point to the already-existing placeholder
|
| 241 |
+
# node.
|
| 242 |
+
const_placeholder_node = _get_first_fqn(lifted_objs, constant_val)
|
| 243 |
+
node.replace_all_uses_with(const_placeholder_node)
|
| 244 |
+
gm.graph.erase_node(node)
|
| 245 |
+
renamed_targets[node.name] = const_placeholder_node.name
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
# For ScriptObject, Tensor and FakeScriptObject constants:
|
| 249 |
+
# First check if the constant was an attribute on some module by
|
| 250 |
+
# consulting `constant_attrs` map. If it is, use the fqn that keeps
|
| 251 |
+
# its location consistent with the eager module.
|
| 252 |
+
#
|
| 253 |
+
# If it's not in the `constant_attrs` map, that means it's an inline
|
| 254 |
+
# constant (e.g. x + torch.tensor(0)), and thus did not have a
|
| 255 |
+
# specific location in the eager module. In that case, just generate
|
| 256 |
+
# some name and attach it to the module in which it was used.
|
| 257 |
+
if isinstance(
|
| 258 |
+
constant_val, (torch.ScriptObject, FakeScriptObject)
|
| 259 |
+
) or is_opaque_reference_type(type(constant_val)):
|
| 260 |
+
constant_kind = InputKind.CUSTOM_OBJ
|
| 261 |
+
constant_fqn = _get_first_fqn(constant_attrs, constant_val)
|
| 262 |
+
if constant_fqn is not None:
|
| 263 |
+
constant_name = constant_fqn.replace(".", "_")
|
| 264 |
+
else:
|
| 265 |
+
constant_name = f"lifted_custom_{num_custom_obj}"
|
| 266 |
+
constant_fqn = get_constant_fqn(node, constant_name)
|
| 267 |
+
while constant_fqn in used_target_names:
|
| 268 |
+
num_custom_obj += 1
|
| 269 |
+
constant_name = f"lifted_custom_{num_custom_obj}"
|
| 270 |
+
constant_fqn = get_constant_fqn(node, constant_name)
|
| 271 |
+
num_custom_obj += 1
|
| 272 |
+
elif isinstance(constant_val, torch.Tensor):
|
| 273 |
+
# Remove the parameterness of constant_val
|
| 274 |
+
if isinstance(constant_val, torch.nn.Parameter):
|
| 275 |
+
log.debug(
|
| 276 |
+
"%s created when tracing %s is a parameter. But "
|
| 277 |
+
"it's not registered with register_parameter(). export will treat it as a constant tensor",
|
| 278 |
+
str(node.target),
|
| 279 |
+
str(node.meta.get("stack_trace", "<unknown stack>")),
|
| 280 |
+
)
|
| 281 |
+
# We get the real data out of the parameter by disabling the surrounding fake mode.
|
| 282 |
+
with unset_fake_temporarily():
|
| 283 |
+
constant_val = constant_val.data
|
| 284 |
+
constant_kind = InputKind.CONSTANT_TENSOR
|
| 285 |
+
constant_fqn = _get_first_fqn(constant_attrs, constant_val)
|
| 286 |
+
if constant_fqn is not None:
|
| 287 |
+
constant_name = constant_fqn.replace(".", "_")
|
| 288 |
+
else:
|
| 289 |
+
constant_name = f"lifted_tensor_{num_tensor_constants}"
|
| 290 |
+
constant_fqn = get_constant_fqn(node, constant_name)
|
| 291 |
+
while constant_fqn in used_target_names:
|
| 292 |
+
num_tensor_constants += 1
|
| 293 |
+
constant_name = f"lifted_tensor_{num_tensor_constants}"
|
| 294 |
+
constant_fqn = get_constant_fqn(node, constant_name)
|
| 295 |
+
num_tensor_constants += 1
|
| 296 |
+
else:
|
| 297 |
+
raise SpecViolationError(
|
| 298 |
+
f"getattr node {node} referencing unsupported type {type(constant_val)}"
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
with gm.graph.inserting_before(first_user_input):
|
| 302 |
+
# Insert the constant node before the first user input
|
| 303 |
+
const_placeholder_node = gm.graph.placeholder(constant_name)
|
| 304 |
+
# match target name with its node name in case there is name collision
|
| 305 |
+
# and suffix is added to node name in fx
|
| 306 |
+
const_placeholder_node.target = const_placeholder_node.name
|
| 307 |
+
|
| 308 |
+
for k, v in node.meta.items():
|
| 309 |
+
const_placeholder_node.meta[k] = v
|
| 310 |
+
|
| 311 |
+
# Once the FQN has been used, remove nn_module_stack, stack_trace
|
| 312 |
+
const_placeholder_node.meta.pop("nn_module_stack")
|
| 313 |
+
const_placeholder_node.meta.pop("stack_trace", None)
|
| 314 |
+
|
| 315 |
+
input_spec_arg: ArgumentSpec
|
| 316 |
+
if isinstance(constant_val, torch.Tensor):
|
| 317 |
+
if fake_mode is not None:
|
| 318 |
+
const_placeholder_node.meta["val"] = fake_mode.from_tensor(
|
| 319 |
+
constant_val, static_shapes=True
|
| 320 |
+
)
|
| 321 |
+
const_placeholder_node.meta["val"].constant = constant_val
|
| 322 |
+
else:
|
| 323 |
+
const_placeholder_node.meta["val"] = constant_val
|
| 324 |
+
input_spec_arg = TensorArgument(name=const_placeholder_node.name)
|
| 325 |
+
elif isinstance(constant_val, torch._C.ScriptObject):
|
| 326 |
+
class_fqn = constant_val._type().qualified_name() # type: ignore[attr-defined]
|
| 327 |
+
const_placeholder_node.meta["val"] = CustomObjArgument(
|
| 328 |
+
constant_fqn, class_fqn
|
| 329 |
+
)
|
| 330 |
+
input_spec_arg = CustomObjArgument(
|
| 331 |
+
name=const_placeholder_node.name, class_fqn=class_fqn
|
| 332 |
+
)
|
| 333 |
+
elif isinstance(constant_val, FakeScriptObject):
|
| 334 |
+
class_fqn = constant_val.script_class_name
|
| 335 |
+
const_placeholder_node.meta["val"] = CustomObjArgument(
|
| 336 |
+
constant_fqn, class_fqn, constant_val
|
| 337 |
+
)
|
| 338 |
+
input_spec_arg = CustomObjArgument(
|
| 339 |
+
name=const_placeholder_node.name,
|
| 340 |
+
class_fqn=class_fqn,
|
| 341 |
+
fake_val=constant_val,
|
| 342 |
+
)
|
| 343 |
+
else:
|
| 344 |
+
raise SpecViolationError(
|
| 345 |
+
f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}"
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
lifted_objs.add(constant_val, const_placeholder_node)
|
| 349 |
+
node.replace_all_uses_with(const_placeholder_node)
|
| 350 |
+
gm.graph.erase_node(node)
|
| 351 |
+
|
| 352 |
+
renamed_targets[node.name] = const_placeholder_node.name
|
| 353 |
+
|
| 354 |
+
# Add the constant as a buffer to the graph signature
|
| 355 |
+
graph_signature.input_specs.insert(
|
| 356 |
+
first_user_input_loc,
|
| 357 |
+
InputSpec(
|
| 358 |
+
kind=constant_kind,
|
| 359 |
+
arg=input_spec_arg,
|
| 360 |
+
target=constant_fqn,
|
| 361 |
+
),
|
| 362 |
+
)
|
| 363 |
+
if constant_val in constant_attrs:
|
| 364 |
+
for fqn in constant_attrs[constant_val]:
|
| 365 |
+
all_constants[fqn] = constant_val
|
| 366 |
+
else:
|
| 367 |
+
all_constants[constant_fqn] = constant_val
|
| 368 |
+
first_user_input_loc += 1
|
| 369 |
+
|
| 370 |
+
for spec in graph_signature.output_specs:
|
| 371 |
+
if spec.arg.name in renamed_targets:
|
| 372 |
+
spec.arg.name = renamed_targets[spec.arg.name]
|
| 373 |
+
|
| 374 |
+
return all_constants
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def rewrite_script_object_meta(
|
| 378 |
+
gm: torch.fx.GraphModule,
|
| 379 |
+
) -> dict[str, _ConstantAttributeType]:
|
| 380 |
+
"""When tracing, we produce a graph with FakeScriptObject in the
|
| 381 |
+
meta["val"].
|
| 382 |
+
|
| 383 |
+
For now, we rewrie meta["val"] to be a placeholder CustomObjArgument
|
| 384 |
+
"""
|
| 385 |
+
constants: dict[
|
| 386 |
+
str,
|
| 387 |
+
_ConstantAttributeType,
|
| 388 |
+
] = {}
|
| 389 |
+
for node in gm.graph.nodes:
|
| 390 |
+
if "val" not in node.meta:
|
| 391 |
+
continue
|
| 392 |
+
|
| 393 |
+
old_meta = node.meta["val"]
|
| 394 |
+
|
| 395 |
+
if isinstance(old_meta, torch.ScriptObject):
|
| 396 |
+
class_fqn = old_meta._type().qualified_name() # type: ignore[attr-defined]
|
| 397 |
+
new_meta = CustomObjArgument(node.name, class_fqn)
|
| 398 |
+
constants[node.name] = old_meta
|
| 399 |
+
node.meta["val"] = new_meta
|
| 400 |
+
|
| 401 |
+
elif isinstance(old_meta, FakeScriptObject):
|
| 402 |
+
class_fqn = old_meta.script_class_name # type: ignore[attr-defined]
|
| 403 |
+
new_meta = CustomObjArgument(node.name, class_fqn, old_meta)
|
| 404 |
+
constants[node.name] = old_meta
|
| 405 |
+
node.meta["val"] = new_meta
|
| 406 |
+
|
| 407 |
+
return constants
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def _materialize_and_lift_constants(
|
| 411 |
+
gm: torch.fx.GraphModule,
|
| 412 |
+
export_graph_signature: ExportGraphSignature,
|
| 413 |
+
constant_attrs: ConstantAttrMap,
|
| 414 |
+
) -> dict[str, _ConstantAttributeType]:
|
| 415 |
+
constants = rewrite_script_object_meta(gm)
|
| 416 |
+
constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))
|
| 417 |
+
return constants
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/remove_runtime_assertions.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class _RemoveRuntimeAssertionsPass(PassBase):
|
| 6 |
+
"""
|
| 7 |
+
Remove runtime assertions inserted by the
|
| 8 |
+
_AddRuntimeAssertionsForInlineConstraintsPass.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
|
| 12 |
+
modified = False
|
| 13 |
+
for module in graph_module.modules():
|
| 14 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 15 |
+
continue
|
| 16 |
+
for node in module.graph.nodes:
|
| 17 |
+
if node.target in [
|
| 18 |
+
torch.ops.aten._assert_async.msg,
|
| 19 |
+
torch.ops.aten._assert_scalar.default,
|
| 20 |
+
torch.ops.aten.sym_constrain_range_for_size.default,
|
| 21 |
+
torch.ops.aten.sym_constrain_range.default,
|
| 22 |
+
torch.ops.aten._assert_tensor_metadata.default,
|
| 23 |
+
]:
|
| 24 |
+
assert_async_node = node
|
| 25 |
+
if len(assert_async_node.users) > 0:
|
| 26 |
+
continue
|
| 27 |
+
module.graph.erase_node(assert_async_node)
|
| 28 |
+
# the upstream scalar_tensor <- {le, ge} <- sym_size
|
| 29 |
+
# linear chain of nodes of nodes is removed by the
|
| 30 |
+
# downstream dead code elimination
|
| 31 |
+
modified = True
|
| 32 |
+
|
| 33 |
+
# We don't necessarily want to run DCE here because it could affect
|
| 34 |
+
# nodes that are in the module_call_graph attribute of the exported
|
| 35 |
+
# program. We will leave it to the pass caller to call DCE.
|
| 36 |
+
return PassResult(graph_module, modified)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_autocast_with_hop_pass.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._higher_order_ops.wrap import wrap_with_autocast
|
| 8 |
+
|
| 9 |
+
from ..utils import node_inline_, nodes_filter, nodes_first, sequential_split
|
| 10 |
+
from .replace_with_hop_pass_util import (
|
| 11 |
+
_replace_with_hop_helper,
|
| 12 |
+
_replace_with_hop_pass_helper,
|
| 13 |
+
_sequential_split_and_maybe_inline_subgraphs_helper,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from torch.export.graph_signature import ExportGraphSignature
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _is_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool:
|
| 22 |
+
return (
|
| 23 |
+
node
|
| 24 |
+
and node.op == "call_function"
|
| 25 |
+
and node.target
|
| 26 |
+
in [
|
| 27 |
+
torch.amp.autocast_mode._enter_autocast,
|
| 28 |
+
torch.amp.autocast_mode._exit_autocast,
|
| 29 |
+
]
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _is_enter_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool:
|
| 34 |
+
return (
|
| 35 |
+
node
|
| 36 |
+
and node.op == "call_function"
|
| 37 |
+
and node.target is torch.amp.autocast_mode._enter_autocast
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _is_exit_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool:
|
| 42 |
+
return (
|
| 43 |
+
node
|
| 44 |
+
and node.op == "call_function"
|
| 45 |
+
and node.target is torch.amp.autocast_mode._exit_autocast
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _is_autocast_sub_mod(node: torch.fx.Node) -> bool:
|
| 50 |
+
"""
|
| 51 |
+
Check if the first non-placeholder node is `torch.amp.autocast_mode._enter_autocast`.
|
| 52 |
+
"""
|
| 53 |
+
if node.op == "call_module":
|
| 54 |
+
assert isinstance(node.target, str)
|
| 55 |
+
subgm = getattr(node.graph.owning_module, node.target)
|
| 56 |
+
first_non_ph = nodes_first(
|
| 57 |
+
subgm.graph.nodes, lambda node: node.op != "placeholder"
|
| 58 |
+
)
|
| 59 |
+
if (
|
| 60 |
+
first_non_ph
|
| 61 |
+
and first_non_ph.op == "call_function"
|
| 62 |
+
and first_non_ph.target is torch.amp.autocast_mode._enter_autocast
|
| 63 |
+
):
|
| 64 |
+
# TODO: check if current auto-cast type is the same as the args of
|
| 65 |
+
# _enter_autocast. If so, return False, i.e. do not create a submodule.
|
| 66 |
+
return True
|
| 67 |
+
return False
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _check_valid_autocast_block(
|
| 71 |
+
enter_autocast_node: torch.fx.Node, exit_autocast_node: torch.fx.Node
|
| 72 |
+
) -> None:
|
| 73 |
+
assert _is_enter_autocast_node(enter_autocast_node)
|
| 74 |
+
assert _is_exit_autocast_node(exit_autocast_node)
|
| 75 |
+
assert exit_autocast_node.args[0] == enter_autocast_node
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _replace_with_hop(node: torch.fx.Node) -> None:
|
| 79 |
+
assert node.op == "call_module"
|
| 80 |
+
graph: torch.fx.Graph = node.graph
|
| 81 |
+
assert graph.owning_module is not None
|
| 82 |
+
gm: torch.fx.GraphModule = graph.owning_module
|
| 83 |
+
assert isinstance(node.target, str)
|
| 84 |
+
sub_gm = getattr(gm, node.target)
|
| 85 |
+
sub_graph = sub_gm.graph
|
| 86 |
+
autocast_nodes = nodes_filter(sub_graph.nodes, _is_autocast_node)
|
| 87 |
+
if len(autocast_nodes) > 0:
|
| 88 |
+
assert len(autocast_nodes) > 1 # need at least an enter node and an exist node
|
| 89 |
+
enter_autocast_node = autocast_nodes[0]
|
| 90 |
+
exit_autocast_node = autocast_nodes[-1]
|
| 91 |
+
_check_valid_autocast_block(enter_autocast_node, exit_autocast_node)
|
| 92 |
+
|
| 93 |
+
_replace_with_hop_helper(node, enter_autocast_node, wrap_with_autocast)
|
| 94 |
+
sub_graph.erase_node(exit_autocast_node)
|
| 95 |
+
sub_graph.erase_node(enter_autocast_node)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
|
| 99 |
+
"""
|
| 100 |
+
split_autocast creates a new graph module that splits the input graph module into multiple submodules
|
| 101 |
+
based on the `_enter_autocast` and `_exit_autocast` nodes. It doesn't mutate the input graph module.
|
| 102 |
+
|
| 103 |
+
Nodes between the **outer-most** `_enter_autocast` and `_exit_autocast(_enter_autocast)` are split
|
| 104 |
+
into a submodule. Nested autocast regions are not split.
|
| 105 |
+
`_enter_autocast` and `_exit_autocast(_enter_autocast)` nodes are in the submodule as well.
|
| 106 |
+
|
| 107 |
+
Below is an example of splitting. A, B, C, D, E are blocks of non-autocast nodes in the original graph
|
| 108 |
+
module. Nodes marked with the same number are grouped into the same submodule.
|
| 109 |
+
A # 0
|
| 110 |
+
enter_autocast # 1
|
| 111 |
+
B # 1
|
| 112 |
+
exit_autocast # 1
|
| 113 |
+
C # 2
|
| 114 |
+
enter_autocast # 3
|
| 115 |
+
D # 3
|
| 116 |
+
exit_autocast # 3
|
| 117 |
+
E # 4
|
| 118 |
+
"""
|
| 119 |
+
enter_autocast_node_stack: list[torch.fx.Node] = []
|
| 120 |
+
first_node_after_outer_most_exit: bool = False
|
| 121 |
+
|
| 122 |
+
def node_call_back(node: torch.fx.Node) -> bool:
|
| 123 |
+
nonlocal enter_autocast_node_stack, first_node_after_outer_most_exit
|
| 124 |
+
increment_id = False
|
| 125 |
+
if first_node_after_outer_most_exit or (
|
| 126 |
+
len(enter_autocast_node_stack) == 0 and _is_enter_autocast_node(node)
|
| 127 |
+
):
|
| 128 |
+
assert len(enter_autocast_node_stack) == 0
|
| 129 |
+
first_node_after_outer_most_exit = False
|
| 130 |
+
increment_id = True
|
| 131 |
+
if _is_enter_autocast_node(node):
|
| 132 |
+
enter_autocast_node_stack.append(node)
|
| 133 |
+
elif _is_exit_autocast_node(node):
|
| 134 |
+
assert len(enter_autocast_node_stack) > 0
|
| 135 |
+
last_enter_autocast_node = enter_autocast_node_stack.pop()
|
| 136 |
+
assert node.args[0] == last_enter_autocast_node
|
| 137 |
+
if len(enter_autocast_node_stack) == 0:
|
| 138 |
+
# next node should be in the next submodule since
|
| 139 |
+
# autocast block ends
|
| 140 |
+
first_node_after_outer_most_exit = True
|
| 141 |
+
return increment_id
|
| 142 |
+
|
| 143 |
+
return sequential_split(gm, node_call_back)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _sequential_split_and_maybe_inline_subgraphs(
|
| 147 |
+
gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None
|
| 148 |
+
) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
|
| 149 |
+
"""
|
| 150 |
+
Helper function for replace_autocast_with_hop_pass().
|
| 151 |
+
Split the graph module into multiple subgraphs based on the autocast nodes.
|
| 152 |
+
For each subgraph, decides whether to construct a HOO subgraph, or inline the calls
|
| 153 |
+
back into the parent graph module.
|
| 154 |
+
Nodes between `_enter_autocast` and `_exit_autocast(_enter_autocast)` are considered
|
| 155 |
+
as a subgraph.
|
| 156 |
+
"""
|
| 157 |
+
need_replacing = any(_is_autocast_node(node) for node in gm.graph.nodes)
|
| 158 |
+
if not need_replacing:
|
| 159 |
+
return gm, graph_signature
|
| 160 |
+
|
| 161 |
+
# split_autocast returns a new graph module that could have different output
|
| 162 |
+
# args names. We need to fix the graph signature in `_sequential_split_and_maybe_inline_subgraphs_helper`.
|
| 163 |
+
new_gm = _split_autocast(gm)
|
| 164 |
+
|
| 165 |
+
def _maybe_inline_or_replace_with_hop(node: torch.fx.Node) -> None:
|
| 166 |
+
if _is_autocast_sub_mod(node):
|
| 167 |
+
_replace_with_hop(node)
|
| 168 |
+
else:
|
| 169 |
+
assert node.op == "call_module"
|
| 170 |
+
assert isinstance(node.target, str)
|
| 171 |
+
node_inline_(node)
|
| 172 |
+
|
| 173 |
+
return _sequential_split_and_maybe_inline_subgraphs_helper(
|
| 174 |
+
new_gm, graph_signature, _maybe_inline_or_replace_with_hop
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def replace_autocast_with_hop_pass(
|
| 179 |
+
gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None
|
| 180 |
+
) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
|
| 181 |
+
"""
|
| 182 |
+
Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
|
| 183 |
+
then recursively call itself on each of the submodules.
|
| 184 |
+
"""
|
| 185 |
+
return _replace_with_hop_pass_helper(
|
| 186 |
+
gm,
|
| 187 |
+
graph_signature,
|
| 188 |
+
_sequential_split_and_maybe_inline_subgraphs,
|
| 189 |
+
)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py
ADDED
|
@@ -0,0 +1,676 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import logging
|
| 3 |
+
import operator
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.export._trace
|
| 8 |
+
from torch._ops import OpOverload
|
| 9 |
+
from torch.ao.quantization.fx._decomposed import (
|
| 10 |
+
dequantize_per_channel,
|
| 11 |
+
dequantize_per_tensor,
|
| 12 |
+
quantize_per_tensor,
|
| 13 |
+
)
|
| 14 |
+
from torch.ao.quantization.utils import calculate_qmin_qmax
|
| 15 |
+
from torch.fx.graph_module import _assign_attr
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
log = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Those values will need to be carried over multiple operators.
|
| 21 |
+
_INPUT_Q_DTYPE: Optional[Union[torch.dtype, torch.fx.Node]] = None
|
| 22 |
+
_SCALE: Optional[Union[float, torch.fx.Node]] = None
|
| 23 |
+
_ZERO_POINT: Optional[Union[float, torch.fx.Node]] = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def int_to_valid_dtype(val: int) -> torch.dtype:
|
| 27 |
+
from torch._export.converter import _TORCH_ENUM_TO_DTYPE # No circular import.
|
| 28 |
+
|
| 29 |
+
if isinstance(val, torch.dtype):
|
| 30 |
+
return val
|
| 31 |
+
dtype = _TORCH_ENUM_TO_DTYPE[val]
|
| 32 |
+
if dtype == torch.quint8:
|
| 33 |
+
return torch.uint8
|
| 34 |
+
elif dtype == torch.qint8:
|
| 35 |
+
return torch.int8
|
| 36 |
+
return dtype
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def fx_enum_to_dtype(gm: torch.fx.GraphModule, val: int) -> torch.fx.Node:
|
| 40 |
+
return gm.graph.call_function(int_to_valid_dtype, (val,))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def insert_quantized_node(
|
| 44 |
+
gm: torch.fx.GraphModule,
|
| 45 |
+
val_node: torch.fx.Node,
|
| 46 |
+
scale_node: Union[float, torch.fx.Node],
|
| 47 |
+
zero_point_node: Union[float, torch.fx.Node],
|
| 48 |
+
qmin_node: Union[float, int, torch.fx.Node],
|
| 49 |
+
qmax_node: Union[float, int, torch.fx.Node],
|
| 50 |
+
dtype_node: Union[torch.dtype, torch.fx.Node],
|
| 51 |
+
qscheme: Optional[torch.qscheme],
|
| 52 |
+
) -> torch.fx.Node:
|
| 53 |
+
return gm.graph.call_function(
|
| 54 |
+
quantize_per_tensor,
|
| 55 |
+
(
|
| 56 |
+
val_node,
|
| 57 |
+
scale_node,
|
| 58 |
+
zero_point_node,
|
| 59 |
+
qmin_node,
|
| 60 |
+
qmax_node,
|
| 61 |
+
dtype_node,
|
| 62 |
+
),
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_dequantized(
|
| 67 |
+
val: torch.Tensor,
|
| 68 |
+
scale: Union[float, torch.Tensor],
|
| 69 |
+
zero_point: Union[float, torch.Tensor],
|
| 70 |
+
qmin: Union[float, int],
|
| 71 |
+
qmax: Union[float, int],
|
| 72 |
+
dtype: torch.dtype,
|
| 73 |
+
axis: Optional[int],
|
| 74 |
+
qscheme: Optional[torch.qscheme],
|
| 75 |
+
) -> torch.Tensor:
|
| 76 |
+
if qscheme is torch.per_tensor_affine:
|
| 77 |
+
return dequantize_per_tensor(
|
| 78 |
+
val,
|
| 79 |
+
scale, # type: ignore[arg-type]
|
| 80 |
+
zero_point, # type: ignore[arg-type]
|
| 81 |
+
qmin, # type: ignore[arg-type]
|
| 82 |
+
qmax, # type: ignore[arg-type]
|
| 83 |
+
dtype,
|
| 84 |
+
)
|
| 85 |
+
elif qscheme is torch.per_channel_affine:
|
| 86 |
+
return dequantize_per_channel(
|
| 87 |
+
val,
|
| 88 |
+
scale, # type: ignore[arg-type]
|
| 89 |
+
zero_point, # type: ignore[arg-type]
|
| 90 |
+
axis, # type: ignore[arg-type]
|
| 91 |
+
qmin, # type: ignore[arg-type]
|
| 92 |
+
qmax, # type: ignore[arg-type]
|
| 93 |
+
dtype,
|
| 94 |
+
)
|
| 95 |
+
else:
|
| 96 |
+
raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def insert_dequantized_node(
|
| 100 |
+
gm: torch.fx.GraphModule,
|
| 101 |
+
val_node: torch.fx.Node,
|
| 102 |
+
scale_node: Union[float, torch.fx.Node],
|
| 103 |
+
zero_point_node: Union[float, torch.fx.Node],
|
| 104 |
+
qmin_node: Union[float, int, torch.fx.Node],
|
| 105 |
+
qmax_node: Union[float, int, torch.fx.Node],
|
| 106 |
+
dtype_node: Union[torch.dtype, torch.fx.Node],
|
| 107 |
+
axis_node: Optional[Union[int, torch.fx.Node]],
|
| 108 |
+
qscheme: Optional[torch.qscheme],
|
| 109 |
+
) -> torch.fx.Node:
|
| 110 |
+
if qscheme is torch.per_tensor_affine:
|
| 111 |
+
return gm.graph.call_function(
|
| 112 |
+
dequantize_per_tensor,
|
| 113 |
+
(
|
| 114 |
+
val_node,
|
| 115 |
+
scale_node,
|
| 116 |
+
zero_point_node,
|
| 117 |
+
qmin_node,
|
| 118 |
+
qmax_node,
|
| 119 |
+
dtype_node,
|
| 120 |
+
),
|
| 121 |
+
)
|
| 122 |
+
elif qscheme is torch.per_channel_affine:
|
| 123 |
+
return gm.graph.call_function(
|
| 124 |
+
dequantize_per_channel,
|
| 125 |
+
(
|
| 126 |
+
val_node,
|
| 127 |
+
scale_node,
|
| 128 |
+
zero_point_node,
|
| 129 |
+
axis_node,
|
| 130 |
+
qmin_node,
|
| 131 |
+
qmax_node,
|
| 132 |
+
dtype_node,
|
| 133 |
+
),
|
| 134 |
+
)
|
| 135 |
+
else:
|
| 136 |
+
raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_qmin_qmax(dtype: torch.dtype) -> tuple[Union[int, float], Union[int, float]]:
|
| 140 |
+
return calculate_qmin_qmax(None, None, False, dtype, False) # type: ignore[arg-type]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def insert_qmin_qmax_node(
|
| 144 |
+
gm: torch.fx.GraphModule, dtype_node: Union[torch.dtype, torch.fx.Node]
|
| 145 |
+
) -> tuple[torch.fx.Node, torch.fx.Node]:
|
| 146 |
+
q_min_max_node = gm.graph.call_function(
|
| 147 |
+
calculate_qmin_qmax, (None, None, False, dtype_node, False)
|
| 148 |
+
)
|
| 149 |
+
qmin_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 0))
|
| 150 |
+
qmax_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 1))
|
| 151 |
+
return qmin_node, qmax_node
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def get_script_object(
|
| 155 |
+
gm: torch.nn.Module, node: torch.fx.Node
|
| 156 |
+
) -> torch._C.ScriptObject:
|
| 157 |
+
assert isinstance(node, torch.fx.Node)
|
| 158 |
+
assert node.op == "get_attr"
|
| 159 |
+
attr_name = node.target
|
| 160 |
+
assert isinstance(attr_name, str)
|
| 161 |
+
|
| 162 |
+
mod = gm
|
| 163 |
+
for attr in attr_name.split("."):
|
| 164 |
+
mod = getattr(mod, attr)
|
| 165 |
+
assert isinstance(mod, torch._C.ScriptObject)
|
| 166 |
+
return mod
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
|
| 170 |
+
gm: torch.fx.GraphModule,
|
| 171 |
+
param_node: torch.fx.Node,
|
| 172 |
+
) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]:
|
| 173 |
+
"""Directly inline tensor from a get_attr fx node."""
|
| 174 |
+
mod = get_script_object(gm, param_node)
|
| 175 |
+
w_qtensor, b_qtensor = mod.unpack() # type: ignore[attr-defined]
|
| 176 |
+
w_attr_name, b_attr_name = (
|
| 177 |
+
f"dequantized_{param_node.target}_w",
|
| 178 |
+
f"dequantized_{param_node.target}_b",
|
| 179 |
+
)
|
| 180 |
+
return insert_weight_and_bias_get_attr_node(
|
| 181 |
+
gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
|
| 186 |
+
gm: torch.fx.GraphModule,
|
| 187 |
+
get_attr_to_weight_node: torch.fx.Node,
|
| 188 |
+
get_attr_to_bias_node: Optional[torch.fx.Node],
|
| 189 |
+
) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]:
|
| 190 |
+
assert isinstance(get_attr_to_weight_node.target, str)
|
| 191 |
+
w_qtensor = getattr(gm, get_attr_to_weight_node.target)
|
| 192 |
+
w_attr_name = f"dequantized_{get_attr_to_weight_node.target}_w"
|
| 193 |
+
|
| 194 |
+
if get_attr_to_bias_node is not None:
|
| 195 |
+
assert isinstance(get_attr_to_bias_node.target, str)
|
| 196 |
+
b_qtensor = getattr(gm, get_attr_to_bias_node.target)
|
| 197 |
+
b_attr_name = f"dequantized_{get_attr_to_bias_node.target}_b"
|
| 198 |
+
else:
|
| 199 |
+
b_qtensor, b_attr_name = None, ""
|
| 200 |
+
|
| 201 |
+
return insert_weight_and_bias_get_attr_node(
|
| 202 |
+
gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def insert_weight_and_bias_get_attr_node(
|
| 207 |
+
gm: torch.fx.GraphModule,
|
| 208 |
+
w_qtensor: torch.Tensor,
|
| 209 |
+
b_qtensor: Optional[torch.Tensor],
|
| 210 |
+
w_attr_name: str,
|
| 211 |
+
b_attr_name: str,
|
| 212 |
+
) -> tuple[torch.fx.Node, Optional[torch.fx.Node]]:
|
| 213 |
+
w_tensor = get_tensor_from_qtensor(w_qtensor)
|
| 214 |
+
_assign_attr(w_tensor, gm, w_attr_name)
|
| 215 |
+
w_tensor_attr = gm.graph.get_attr(w_attr_name)
|
| 216 |
+
|
| 217 |
+
if b_qtensor is not None:
|
| 218 |
+
b_tensor = get_tensor_from_qtensor(b_qtensor, dequant=False)
|
| 219 |
+
_assign_attr(b_tensor, gm, b_attr_name)
|
| 220 |
+
b_tensor_attr = gm.graph.get_attr(b_attr_name)
|
| 221 |
+
else:
|
| 222 |
+
b_tensor_attr = None
|
| 223 |
+
|
| 224 |
+
return w_tensor_attr, b_tensor_attr
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def get_tensor_from_qtensor(
|
| 228 |
+
qtensor: torch.Tensor, dequant: bool = True
|
| 229 |
+
) -> torch.Tensor:
|
| 230 |
+
# Manual conversion because qint8 is not used anymore.
|
| 231 |
+
if qtensor.dtype in [torch.qint8, torch.quint8]:
|
| 232 |
+
tensor = qtensor.int_repr()
|
| 233 |
+
else:
|
| 234 |
+
tensor = qtensor
|
| 235 |
+
|
| 236 |
+
# Weights need dequantization with scaling and zero_point adjustment, but
|
| 237 |
+
# bias does not need that.
|
| 238 |
+
if dequant:
|
| 239 |
+
qscheme = qtensor.qscheme()
|
| 240 |
+
if qscheme == torch.per_channel_affine:
|
| 241 |
+
scale, zero_point, axis = (
|
| 242 |
+
qtensor.q_per_channel_scales(),
|
| 243 |
+
qtensor.q_per_channel_zero_points(),
|
| 244 |
+
qtensor.q_per_channel_axis(),
|
| 245 |
+
)
|
| 246 |
+
else:
|
| 247 |
+
scale, zero_point, axis = (
|
| 248 |
+
qtensor.q_scale(), # type: ignore[assignment]
|
| 249 |
+
qtensor.q_zero_point(), # type: ignore[assignment]
|
| 250 |
+
None,
|
| 251 |
+
)
|
| 252 |
+
dtype = tensor.dtype
|
| 253 |
+
qmin, qmax = get_qmin_qmax(dtype)
|
| 254 |
+
return get_dequantized(
|
| 255 |
+
tensor, scale, zero_point, qmin, qmax, dtype, axis, qscheme
|
| 256 |
+
)
|
| 257 |
+
return tensor
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def insert_fused_activation_node(
|
| 261 |
+
gm: torch.fx.GraphModule, opname: str, fx_node: torch.fx.Node
|
| 262 |
+
) -> torch.fx.Node:
|
| 263 |
+
if opname in ["conv1d_relu", "conv2d_relu", "linear_relu", "add_relu", "mul_relu"]:
|
| 264 |
+
fx_node = gm.graph.call_function(torch.ops.aten.relu, (fx_node,))
|
| 265 |
+
return fx_node
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _conv1d_op_with_squeeze(
|
| 269 |
+
inp: torch.Tensor,
|
| 270 |
+
weight: torch.Tensor,
|
| 271 |
+
bias: Optional[torch.Tensor],
|
| 272 |
+
stride: list[int],
|
| 273 |
+
padding: list[int],
|
| 274 |
+
dilation: list[int],
|
| 275 |
+
groups: int,
|
| 276 |
+
) -> torch.Tensor:
|
| 277 |
+
# In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze
|
| 278 |
+
# operations before and after the conv2d operation to match the dimension of weights.
|
| 279 |
+
# Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/quantized/cpu/qconv.cpp#L1827 # noqa: B950
|
| 280 |
+
s_inp = torch.ops.aten.unsqueeze(inp, 2)
|
| 281 |
+
conv1d_res = torch.ops.aten.conv2d(
|
| 282 |
+
s_inp,
|
| 283 |
+
weight,
|
| 284 |
+
bias,
|
| 285 |
+
stride,
|
| 286 |
+
padding,
|
| 287 |
+
dilation,
|
| 288 |
+
groups,
|
| 289 |
+
)
|
| 290 |
+
uns_conv1d_res = torch.ops.aten.squeeze(conv1d_res, 2)
|
| 291 |
+
return uns_conv1d_res
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node):
|
| 295 |
+
"""Conv specific transformation function."""
|
| 296 |
+
assert isinstance(node.target, torch._ops.OpOverload)
|
| 297 |
+
opname = node.target._opname
|
| 298 |
+
scale_node, zero_point_node = node.args[2], node.args[3]
|
| 299 |
+
|
| 300 |
+
op_f = (
|
| 301 |
+
torch.ops.aten.conv2d
|
| 302 |
+
if opname in ["conv2d", "conv2d_relu"]
|
| 303 |
+
else _conv1d_op_with_squeeze
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
inp_node, param_node = node.args[0], node.args[1]
|
| 307 |
+
assert isinstance(inp_node, torch.fx.Node)
|
| 308 |
+
assert isinstance(param_node, torch.fx.Node)
|
| 309 |
+
|
| 310 |
+
if param_node.op == "call_function":
|
| 311 |
+
# Using Conv2dPrepackParam from conv_prepack.
|
| 312 |
+
# We directly skip the packing call and inline weights and bias.
|
| 313 |
+
w_node, b_node = param_node.args[0], param_node.args[1]
|
| 314 |
+
assert isinstance(w_node, torch.fx.Node)
|
| 315 |
+
assert b_node is None or isinstance(b_node, torch.fx.Node)
|
| 316 |
+
(
|
| 317 |
+
param_0,
|
| 318 |
+
param_1,
|
| 319 |
+
) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
|
| 320 |
+
gm, w_node, b_node
|
| 321 |
+
)
|
| 322 |
+
op_res_node = gm.graph.call_function(
|
| 323 |
+
op_f, (inp_node, param_0, param_1, *param_node.args[2:])
|
| 324 |
+
)
|
| 325 |
+
else:
|
| 326 |
+
# Using ConvPrepackedParam.
|
| 327 |
+
param = get_script_object(gm, param_node)
|
| 328 |
+
(
|
| 329 |
+
param_0,
|
| 330 |
+
param_1,
|
| 331 |
+
) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
|
| 332 |
+
gm, param_node
|
| 333 |
+
) # type: ignore[assignment]
|
| 334 |
+
op_res_node = gm.graph.call_function(
|
| 335 |
+
op_f,
|
| 336 |
+
(
|
| 337 |
+
inp_node,
|
| 338 |
+
param_0,
|
| 339 |
+
param_1,
|
| 340 |
+
param.stride(), # type: ignore[attr-defined]
|
| 341 |
+
param.padding(), # type: ignore[attr-defined]
|
| 342 |
+
param.dilation(), # type: ignore[attr-defined]
|
| 343 |
+
param.groups(), # type: ignore[attr-defined]
|
| 344 |
+
),
|
| 345 |
+
)
|
| 346 |
+
return op_res_node, scale_node, zero_point_node
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def _transform_linear_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node):
|
| 350 |
+
"""Linear specific transformation function."""
|
| 351 |
+
scale_node, zero_point_node = node.args[2], node.args[3]
|
| 352 |
+
|
| 353 |
+
inp_node, param_node = node.args[0], node.args[1]
|
| 354 |
+
assert isinstance(inp_node, torch.fx.Node)
|
| 355 |
+
assert isinstance(param_node, torch.fx.Node)
|
| 356 |
+
|
| 357 |
+
if param_node.op == "call_function":
|
| 358 |
+
# Using LinearPrepackParam from linear_prepack.
|
| 359 |
+
# We directly skip the packing call and inline weights and bias.
|
| 360 |
+
w_node, b_node = param_node.args[0], param_node.args[1]
|
| 361 |
+
assert isinstance(w_node, torch.fx.Node)
|
| 362 |
+
assert b_node is None or isinstance(b_node, torch.fx.Node)
|
| 363 |
+
(
|
| 364 |
+
param_0,
|
| 365 |
+
param_1,
|
| 366 |
+
) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
|
| 367 |
+
gm, w_node, b_node
|
| 368 |
+
)
|
| 369 |
+
op_res_node = gm.graph.call_function(
|
| 370 |
+
torch.ops.aten.linear, (inp_node, param_0, param_1, *param_node.args[2:])
|
| 371 |
+
)
|
| 372 |
+
else:
|
| 373 |
+
# Using LinearPackedParams.
|
| 374 |
+
(
|
| 375 |
+
param_0,
|
| 376 |
+
param_1,
|
| 377 |
+
) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
|
| 378 |
+
gm, param_node
|
| 379 |
+
) # type: ignore[assignment]
|
| 380 |
+
op_res_node = gm.graph.call_function(
|
| 381 |
+
torch.ops.aten.linear, (inp_node, param_0, param_1)
|
| 382 |
+
)
|
| 383 |
+
return op_res_node, scale_node, zero_point_node
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _transform_op_where_last_two_arguments_are_scale_and_zero_point(
|
| 387 |
+
gm: torch.fx.GraphModule, node: torch.fx.Node
|
| 388 |
+
):
|
| 389 |
+
"""
|
| 390 |
+
This transformation function can be used for function where the last two
|
| 391 |
+
parameters are scale and zero point. Additionally, the function's parameters
|
| 392 |
+
do not need any unpacking.
|
| 393 |
+
"""
|
| 394 |
+
to_standard_op = {
|
| 395 |
+
"mul": torch.ops.aten.mul,
|
| 396 |
+
"mul_relu": torch.ops.aten.mul,
|
| 397 |
+
"add": torch.ops.aten.add,
|
| 398 |
+
"add_relu": torch.ops.aten.add,
|
| 399 |
+
"softmax": torch.ops.aten.softmax,
|
| 400 |
+
"cat": torch.ops.aten.cat,
|
| 401 |
+
"hardswish": torch.ops.aten.hardswish,
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
assert isinstance(node.target, torch._ops.OpOverload)
|
| 405 |
+
opname, args = node.target._opname, node.args
|
| 406 |
+
scale_node, zero_point_node = args[-2], args[-1]
|
| 407 |
+
op_res_node = gm.graph.call_function(to_standard_op[opname], tuple(args[:-2]))
|
| 408 |
+
return op_res_node, scale_node, zero_point_node
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def _transform_scalar_arithmetic(gm: torch.fx.GraphModule, node: torch.fx.Node):
|
| 412 |
+
"""Transform scalar overload for basic arithmetic."""
|
| 413 |
+
to_standard_op = {
|
| 414 |
+
"mul": torch.ops.aten.mul.Scalar,
|
| 415 |
+
"add": torch.ops.aten.add.Scalar,
|
| 416 |
+
}
|
| 417 |
+
assert isinstance(node.target, torch._ops.OpOverload)
|
| 418 |
+
opname, args = node.target._opname, node.args
|
| 419 |
+
op_res_node = gm.graph.call_function(to_standard_op[opname], args)
|
| 420 |
+
return op_res_node, _SCALE, _ZERO_POINT
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def _transform_prepacked_op(gm: torch.fx.GraphModule, node: torch.fx.Node):
|
| 424 |
+
"""
|
| 425 |
+
Transformation for functions under prepacked namespace, where they share
|
| 426 |
+
the same handling logic that [...]OpContext contains all parameters.
|
| 427 |
+
"""
|
| 428 |
+
assert isinstance(node.target, torch._ops.OpOverload)
|
| 429 |
+
opname, args = node.target._opname, node.args
|
| 430 |
+
op_f = None
|
| 431 |
+
if opname == "conv2d_clamp_run":
|
| 432 |
+
op_f = torch.ops.aten.conv2d
|
| 433 |
+
elif opname == "linear_clamp_run":
|
| 434 |
+
op_f = torch.ops.aten.linear
|
| 435 |
+
else:
|
| 436 |
+
raise RuntimeError(f"Invalid operator {opname}")
|
| 437 |
+
|
| 438 |
+
assert isinstance(args[1], torch.fx.Node)
|
| 439 |
+
so = get_script_object(gm, args[1])
|
| 440 |
+
|
| 441 |
+
func_args = []
|
| 442 |
+
func_args += [args[0]]
|
| 443 |
+
func_args += so.unpack()[:2] # type: ignore[attr-defined]
|
| 444 |
+
if opname == "conv2d_clamp_run":
|
| 445 |
+
func_args += torch.ops.prepacked.unpack_prepacked_sizes_conv2d(so)[2:]
|
| 446 |
+
|
| 447 |
+
op_res_node = gm.graph.call_function(op_f, tuple(func_args))
|
| 448 |
+
return op_res_node
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def _transform_batch_norm(gm: torch.fx.GraphModule, node: torch.fx.Node):
|
| 452 |
+
args = node.args
|
| 453 |
+
scale_node, zero_point_node = args[-2], args[-1]
|
| 454 |
+
op_res_node = gm.graph.call_function(
|
| 455 |
+
torch.ops.aten.native_batch_norm, (*args[:-3], False, 0.1, args[-3])
|
| 456 |
+
)
|
| 457 |
+
op_res_node = gm.graph.call_function(operator.getitem, (op_res_node, 0))
|
| 458 |
+
return op_res_node, scale_node, zero_point_node
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def fx_transform_quantized_op_to_standard_op(
|
| 462 |
+
gm: torch.fx.GraphModule, node: torch.fx.Node
|
| 463 |
+
) -> torch.fx.Node:
|
| 464 |
+
global _SCALE, _ZERO_POINT, _INPUT_Q_DTYPE
|
| 465 |
+
|
| 466 |
+
assert isinstance(node.target, torch._ops.OpOverload)
|
| 467 |
+
opname, overload = node.target._opname, node.target._overloadname
|
| 468 |
+
|
| 469 |
+
key = f"{opname}.{overload}"
|
| 470 |
+
opname_to_transform_f = {
|
| 471 |
+
"conv1d.new": _transform_conv_with_packedparam,
|
| 472 |
+
"conv1d_relu.new": _transform_conv_with_packedparam,
|
| 473 |
+
"conv1d.default": _transform_conv_with_packedparam,
|
| 474 |
+
"conv1d_relu.default": _transform_conv_with_packedparam,
|
| 475 |
+
"conv2d.new": _transform_conv_with_packedparam,
|
| 476 |
+
"conv2d_relu.new": _transform_conv_with_packedparam,
|
| 477 |
+
"conv2d.default": _transform_conv_with_packedparam,
|
| 478 |
+
"conv2d_relu.default": _transform_conv_with_packedparam,
|
| 479 |
+
"linear.default": _transform_linear_with_packedparam,
|
| 480 |
+
"linear_relu.default": _transform_linear_with_packedparam,
|
| 481 |
+
"add.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
|
| 482 |
+
"add_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
|
| 483 |
+
"mul.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
|
| 484 |
+
"mul_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
|
| 485 |
+
"softmax.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
|
| 486 |
+
"cat.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
|
| 487 |
+
"hardswish.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
|
| 488 |
+
"batch_norm2d.default": _transform_batch_norm,
|
| 489 |
+
"mul.Scalar": _transform_scalar_arithmetic,
|
| 490 |
+
"add.Scalar": _transform_scalar_arithmetic,
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
if f"{key}" not in opname_to_transform_f:
|
| 494 |
+
raise RuntimeError(f"Unsupported quantized op during transformation: {key}")
|
| 495 |
+
|
| 496 |
+
op_res_node, scale_node, zero_point_node = opname_to_transform_f[f"{key}"](gm, node)
|
| 497 |
+
|
| 498 |
+
# Add fused activation layer.
|
| 499 |
+
op_res_node = insert_fused_activation_node(gm, opname, op_res_node)
|
| 500 |
+
_SCALE, _ZERO_POINT = scale_node, zero_point_node
|
| 501 |
+
|
| 502 |
+
assert _INPUT_Q_DTYPE is not None
|
| 503 |
+
qmin_node, qmax_node = insert_qmin_qmax_node(gm, _INPUT_Q_DTYPE)
|
| 504 |
+
q_fx_node = insert_quantized_node(
|
| 505 |
+
gm,
|
| 506 |
+
op_res_node,
|
| 507 |
+
scale_node,
|
| 508 |
+
zero_point_node,
|
| 509 |
+
qmin_node,
|
| 510 |
+
qmax_node,
|
| 511 |
+
_INPUT_Q_DTYPE,
|
| 512 |
+
torch.per_tensor_affine,
|
| 513 |
+
)
|
| 514 |
+
dq_fx_node = insert_dequantized_node(
|
| 515 |
+
gm,
|
| 516 |
+
q_fx_node,
|
| 517 |
+
scale_node,
|
| 518 |
+
zero_point_node,
|
| 519 |
+
qmin_node,
|
| 520 |
+
qmax_node,
|
| 521 |
+
_INPUT_Q_DTYPE,
|
| 522 |
+
None,
|
| 523 |
+
torch.per_tensor_affine,
|
| 524 |
+
)
|
| 525 |
+
return dq_fx_node
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
|
| 529 |
+
"""
|
| 530 |
+
Replace legacy quantized ops (aten.quantize_per_tensor, quantized.conv) with
|
| 531 |
+
PT2 ops (quantize_decomposed.quantize_per_tensor, aten.conv).
|
| 532 |
+
|
| 533 |
+
Before: x || -> aten.q || -> quantized.conv2d || -> quantized.linear || -> aten.dq || -> y
|
| 534 |
+
|
| 535 |
+
After: x || -> qd.q -> qd.dq || -> aten.conv2d -> qd.q -> qd.dq || aten.linear -> qd.q -> qd.dq || -> y
|
| 536 |
+
|
| 537 |
+
(qd == quantized_decomposed library, q = quantize, dq = dequantize)
|
| 538 |
+
^
|
| 539 |
+
|
|
| 540 |
+
getattr(w), getattr(b) from Conv2dParamPrepack
|
| 541 |
+
|
| 542 |
+
During each iteration, the transformation spits out the transformed operator, its quantized output,
|
| 543 |
+
and its dequantized value together. We did this because dequantization need to use the
|
| 544 |
+
scale and zero point parameters from the quantization to recover the approximate original value. After each
|
| 545 |
+
iteration, the new dequantization node will be used as the input to the next node (e.g., dq2 -> linear).
|
| 546 |
+
|
| 547 |
+
For operators like conv2d and linear, their weights and bias are packed in a quantized format in the ScriptObject.
|
| 548 |
+
During the transformation, we unpack those objects, get their dequantized tensor, populate those
|
| 549 |
+
as attributes to the module, and use getattr to access them.
|
| 550 |
+
|
| 551 |
+
One exception in the transformation is conv_prepack and linear_prepack. Those calls pack
|
| 552 |
+
weight and bias constant tensors into ScriptObject, which are then used by subsequent conv2d or linear calls.
|
| 553 |
+
During transformation, we directly skip transforming conv_prepack or linear_prepack. We check whether ScriptObject to the
|
| 554 |
+
quantized::conv2d or linear is from conv_prepack or linear_prepack. If it is, we then inline those parameters
|
| 555 |
+
to the operator by converting them to a getattr fx.node.
|
| 556 |
+
|
| 557 |
+
For prepacked::conv2d_clamp_run and prepacked::linear_clamp_run, we directly convert them to aten.conv2d and aten.linear
|
| 558 |
+
without the need of doing de/quantization.
|
| 559 |
+
|
| 560 |
+
Three global variables defined are _INPUT_Q_DTYPE, _SCALE, _ZERO_POINT. _INPUT_Q_DTYPE determines the de/quantization
|
| 561 |
+
data type, which is the same across the entire program, but it only shows up in the very first quantization
|
| 562 |
+
call. _SCALE and _ZERO_POINT are used only when operators do not have those specified. E.g., mul.Scalar.
|
| 563 |
+
"""
|
| 564 |
+
|
| 565 |
+
global _INPUT_Q_DTYPE
|
| 566 |
+
|
| 567 |
+
quantized = False
|
| 568 |
+
|
| 569 |
+
last_quantized_node = None
|
| 570 |
+
# pyrefly: ignore [bad-assignment]
|
| 571 |
+
for node in gm.graph.nodes:
|
| 572 |
+
if isinstance(node.target, OpOverload):
|
| 573 |
+
with gm.graph.inserting_before(node):
|
| 574 |
+
namespace, opname = node.target.namespace, node.target._opname
|
| 575 |
+
if namespace == "quantized" and opname not in [
|
| 576 |
+
"conv_prepack",
|
| 577 |
+
"linear_prepack",
|
| 578 |
+
]:
|
| 579 |
+
quantized = True
|
| 580 |
+
fx_node = fx_transform_quantized_op_to_standard_op(gm, node)
|
| 581 |
+
node.replace_all_uses_with(fx_node)
|
| 582 |
+
last_quantized_node = fx_node
|
| 583 |
+
elif namespace == "prepacked":
|
| 584 |
+
quantized = True
|
| 585 |
+
fx_node = _transform_prepacked_op(gm, node)
|
| 586 |
+
node.replace_all_uses_with(fx_node)
|
| 587 |
+
last_quantized_node = fx_node
|
| 588 |
+
elif namespace == "aten" and opname == "quantize_per_tensor":
|
| 589 |
+
inp_node, scale_node, zero_point_node, dtype_node = node.args
|
| 590 |
+
dtype_node = fx_enum_to_dtype(gm, dtype_node)
|
| 591 |
+
_INPUT_Q_DTYPE = dtype_node
|
| 592 |
+
qmin_node, qmax_node = insert_qmin_qmax_node(gm, dtype_node)
|
| 593 |
+
q_fx_node = insert_quantized_node(
|
| 594 |
+
gm,
|
| 595 |
+
inp_node,
|
| 596 |
+
scale_node,
|
| 597 |
+
zero_point_node,
|
| 598 |
+
qmin_node,
|
| 599 |
+
qmax_node,
|
| 600 |
+
dtype_node,
|
| 601 |
+
torch.per_tensor_affine,
|
| 602 |
+
)
|
| 603 |
+
dq_fx_node = insert_dequantized_node(
|
| 604 |
+
gm,
|
| 605 |
+
q_fx_node,
|
| 606 |
+
scale_node,
|
| 607 |
+
zero_point_node,
|
| 608 |
+
qmin_node,
|
| 609 |
+
qmax_node,
|
| 610 |
+
dtype_node,
|
| 611 |
+
None,
|
| 612 |
+
torch.per_tensor_affine,
|
| 613 |
+
)
|
| 614 |
+
node.replace_all_uses_with(dq_fx_node)
|
| 615 |
+
last_quantized_node = dq_fx_node
|
| 616 |
+
elif namespace == "aten" and opname == "dequantize":
|
| 617 |
+
assert last_quantized_node is not None
|
| 618 |
+
node.replace_all_uses_with(last_quantized_node)
|
| 619 |
+
else:
|
| 620 |
+
last_quantized_node = node
|
| 621 |
+
|
| 622 |
+
# Post-processing again to remove legacy ScriptObjects and quantizated tensors
|
| 623 |
+
# stored as attributes or in the buffer. This is used to clean up the GraphModule
|
| 624 |
+
# to not trigger tracing errors like missing __obj_flatten__ functions.
|
| 625 |
+
def _clean_attr(mod: torch.nn.Module):
|
| 626 |
+
for submod in mod.modules():
|
| 627 |
+
attr_names_to_clean = set()
|
| 628 |
+
for k, v in submod.__dict__.items():
|
| 629 |
+
if isinstance(v, torch.ScriptObject):
|
| 630 |
+
attr_names_to_clean.add(k)
|
| 631 |
+
if k == "_buffers":
|
| 632 |
+
buffer_name_to_clean = set()
|
| 633 |
+
# pyrefly: ignore [missing-attribute]
|
| 634 |
+
for b_name, b_value in v.items():
|
| 635 |
+
if isinstance(b_value, torch.Tensor) and b_value.dtype in [
|
| 636 |
+
torch.qint8,
|
| 637 |
+
torch.quint8,
|
| 638 |
+
]:
|
| 639 |
+
buffer_name_to_clean.add(b_name)
|
| 640 |
+
for b_name in buffer_name_to_clean:
|
| 641 |
+
# pyrefly: ignore [missing-attribute]
|
| 642 |
+
v.pop(b_name, None)
|
| 643 |
+
for attr_name in attr_names_to_clean:
|
| 644 |
+
delattr(submod, attr_name)
|
| 645 |
+
|
| 646 |
+
if quantized:
|
| 647 |
+
"""
|
| 648 |
+
TODO: SetAttr + quantized ops will result incorrect program. This flag is used to temporarily
|
| 649 |
+
bypass test cases.
|
| 650 |
+
|
| 651 |
+
The deadcode elimination pass is needed to remove legacy quantized ops. Otherwise, retracing
|
| 652 |
+
will throw errors. However, the current way of SetAttr does inplace update to attributes, so
|
| 653 |
+
this pass regard them as dead code and remove them. Below is an example of GraphModule before
|
| 654 |
+
and after the dead code elimination pass.
|
| 655 |
+
|
| 656 |
+
class GraphModule(torch.nn.Module):
|
| 657 |
+
def forward(self, x_1):
|
| 658 |
+
# No stacktrace found for following nodes
|
| 659 |
+
data = self.data; data = None
|
| 660 |
+
data_1 = self.data
|
| 661 |
+
add_tensor = torch.ops.aten.add.Tensor(data_1, x_1, alpha = 1); data_1 = None
|
| 662 |
+
data_2 = self.data
|
| 663 |
+
copy_ = torch_Tensor_copy_(data_2, add_tensor); data_2 = add_tensor = copy_ = None
|
| 664 |
+
data_3 = self.data
|
| 665 |
+
add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None
|
| 666 |
+
return add_tensor_1
|
| 667 |
+
|
| 668 |
+
class GraphModule(torch.nn.Module):
|
| 669 |
+
def forward(self, x_1):
|
| 670 |
+
# No stacktrace found for following nodes
|
| 671 |
+
data_3 = self.data
|
| 672 |
+
add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None
|
| 673 |
+
return add_tensor_1
|
| 674 |
+
"""
|
| 675 |
+
gm.graph.eliminate_dead_code()
|
| 676 |
+
_clean_attr(gm)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_set_grad_with_hop_pass.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import TYPE_CHECKING
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled
|
| 8 |
+
|
| 9 |
+
from ..utils import node_inline_, nodes_filter, nodes_first, nodes_map, sequential_split
|
| 10 |
+
from .replace_with_hop_pass_util import (
|
| 11 |
+
_replace_with_hop_helper,
|
| 12 |
+
_replace_with_hop_pass_helper,
|
| 13 |
+
_sequential_split_and_maybe_inline_subgraphs_helper,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from torch.export.graph_signature import ExportGraphSignature
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _is_set_grad_enabled_node(node: torch.fx.Node) -> torch.fx.Node | bool:
|
| 22 |
+
return (
|
| 23 |
+
node
|
| 24 |
+
and node.op == "call_function"
|
| 25 |
+
and node.target is torch._C._set_grad_enabled
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _is_set_grad_enabled_sub_mod(
|
| 30 |
+
node: torch.fx.Node, omit_if_same_with_ambient: bool = False
|
| 31 |
+
) -> bool | torch.Tensor:
|
| 32 |
+
if node.op == "call_module":
|
| 33 |
+
assert isinstance(node.target, str)
|
| 34 |
+
subgm = getattr(node.graph.owning_module, node.target)
|
| 35 |
+
first_non_ph = nodes_first(
|
| 36 |
+
subgm.graph.nodes, lambda node: node.op != "placeholder"
|
| 37 |
+
)
|
| 38 |
+
if (
|
| 39 |
+
first_non_ph
|
| 40 |
+
and first_non_ph.op == "call_function"
|
| 41 |
+
and first_non_ph.target is torch._C._set_grad_enabled
|
| 42 |
+
):
|
| 43 |
+
return (
|
| 44 |
+
first_non_ph.args[0] != torch.is_grad_enabled()
|
| 45 |
+
if omit_if_same_with_ambient
|
| 46 |
+
else True
|
| 47 |
+
)
|
| 48 |
+
return False
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _replace_with_hop(node: torch.fx.Node) -> None:
|
| 52 |
+
assert node.op == "call_module"
|
| 53 |
+
graph: torch.fx.Graph = node.graph
|
| 54 |
+
assert graph.owning_module is not None
|
| 55 |
+
gm: torch.fx.GraphModule = graph.owning_module
|
| 56 |
+
assert isinstance(node.target, str)
|
| 57 |
+
sub_gm = getattr(gm, node.target)
|
| 58 |
+
sub_graph = sub_gm.graph
|
| 59 |
+
set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node)
|
| 60 |
+
if len(set_grad_nodes) > 0:
|
| 61 |
+
assert len(set_grad_nodes) == 1
|
| 62 |
+
set_grad_node = set_grad_nodes[0]
|
| 63 |
+
_replace_with_hop_helper(node, set_grad_node, wrap_with_set_grad_enabled)
|
| 64 |
+
sub_graph.erase_node(set_grad_node)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _remove_set_grad_and_inline(node: torch.fx.Node) -> None:
|
| 68 |
+
assert node.op == "call_module"
|
| 69 |
+
graph: torch.fx.Graph = node.graph
|
| 70 |
+
assert graph.owning_module is not None
|
| 71 |
+
gm: torch.fx.GraphModule = graph.owning_module
|
| 72 |
+
assert isinstance(node.target, str)
|
| 73 |
+
sub_gm = getattr(gm, node.target)
|
| 74 |
+
sub_graph = sub_gm.graph
|
| 75 |
+
nodes_map(
|
| 76 |
+
sub_graph.nodes,
|
| 77 |
+
lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n,
|
| 78 |
+
)
|
| 79 |
+
node_inline_(node)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _sequential_split_and_maybe_inline_subgraphs(
|
| 83 |
+
gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None
|
| 84 |
+
) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
|
| 85 |
+
"""
|
| 86 |
+
Helper function for replace_set_grad_with_hop_pass().
|
| 87 |
+
Split the graph module into multiple subgraphs based on the set_grad_enabled nodes.
|
| 88 |
+
For each subgraph, decides whether to construct a HOO subgraph, or inline the calls
|
| 89 |
+
back into the parent graph module.
|
| 90 |
+
"""
|
| 91 |
+
need_replacing = any(_is_set_grad_enabled_node(node) for node in gm.graph.nodes)
|
| 92 |
+
if not need_replacing:
|
| 93 |
+
return gm, graph_signature
|
| 94 |
+
|
| 95 |
+
# sequential_split returns a new graph module that could have different output
|
| 96 |
+
# args names. We need to fix the graph signature.
|
| 97 |
+
new_gm = sequential_split(gm, _is_set_grad_enabled_node)
|
| 98 |
+
|
| 99 |
+
def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
|
| 100 |
+
if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
|
| 101 |
+
_replace_with_hop(node)
|
| 102 |
+
else:
|
| 103 |
+
_remove_set_grad_and_inline(node)
|
| 104 |
+
|
| 105 |
+
return _sequential_split_and_maybe_inline_subgraphs_helper(
|
| 106 |
+
new_gm, graph_signature, _maybe_inline_or_replace_with_hop
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def replace_set_grad_with_hop_pass(
|
| 111 |
+
gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None
|
| 112 |
+
) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
|
| 113 |
+
"""
|
| 114 |
+
Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
|
| 115 |
+
then recursively call itself on each of the submodules.
|
| 116 |
+
"""
|
| 117 |
+
return _replace_with_hop_pass_helper(
|
| 118 |
+
gm,
|
| 119 |
+
graph_signature,
|
| 120 |
+
_sequential_split_and_maybe_inline_subgraphs,
|
| 121 |
+
)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._export.error import InternalError
|
| 6 |
+
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
|
| 7 |
+
from torch._ops import HigherOrderOperator, OpOverload
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
__all__ = ["ReplaceViewOpsWithViewCopyOpsPass"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS: dict[OpOverload, OpOverload] = {
|
| 14 |
+
torch.ops.aten._unsafe_view.default: torch.ops.aten.view_copy.default,
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def is_view_op(schema: torch._C.FunctionSchema) -> bool:
|
| 19 |
+
if len(schema.arguments) == 0:
|
| 20 |
+
return False
|
| 21 |
+
alias_info = schema.arguments[0].alias_info
|
| 22 |
+
return (alias_info is not None) and (not alias_info.is_write)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_view_copy_of_view_op(schema: torch._C.FunctionSchema) -> Optional[OpOverload]:
|
| 26 |
+
if is_view_op(schema) and schema.name.startswith("aten::"):
|
| 27 |
+
view_op_name = schema.name.split("::")[1]
|
| 28 |
+
view_op_overload = (
|
| 29 |
+
schema.overload_name if schema.overload_name != "" else "default"
|
| 30 |
+
)
|
| 31 |
+
view_copy_op_name = view_op_name + "_copy"
|
| 32 |
+
if not hasattr(torch.ops.aten, view_copy_op_name):
|
| 33 |
+
raise InternalError(f"{schema.name} is missing a view_copy variant")
|
| 34 |
+
|
| 35 |
+
view_copy_op_overload_packet = getattr(torch.ops.aten, view_copy_op_name)
|
| 36 |
+
|
| 37 |
+
if not hasattr(view_copy_op_overload_packet, view_op_overload):
|
| 38 |
+
raise InternalError(f"{schema.name} is missing a view_copy variant")
|
| 39 |
+
|
| 40 |
+
return getattr(view_copy_op_overload_packet, view_op_overload)
|
| 41 |
+
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class ReplaceViewOpsWithViewCopyOpsPass(_ExportPassBaseDeprecatedDoNotUse):
|
| 46 |
+
"""
|
| 47 |
+
Our backend expects pure functional operators. For efficiency
|
| 48 |
+
purposes, we keep view ops around while functionalizing the exported
|
| 49 |
+
program. This pass replaces view ops with view copy ops for backends that
|
| 50 |
+
need AOT memory planning.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def call_operator(self, op, args, kwargs, meta):
|
| 54 |
+
if op in _NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS:
|
| 55 |
+
return super().call_operator(
|
| 56 |
+
(_NON_FUNCTIONAL_OPS_TO_FUNCTIONAL_OPS[op]), args, kwargs, meta
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
if isinstance(op, HigherOrderOperator):
|
| 60 |
+
return super().call_operator(op, args, kwargs, meta)
|
| 61 |
+
|
| 62 |
+
if view_copy_op := get_view_copy_of_view_op(op._schema):
|
| 63 |
+
return super().call_operator(view_copy_op, args, kwargs, meta)
|
| 64 |
+
|
| 65 |
+
return super().call_operator(op, args, kwargs, meta)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/passes/replace_with_hop_pass_util.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import contextlib
|
| 5 |
+
import copy
|
| 6 |
+
import operator
|
| 7 |
+
from typing import TYPE_CHECKING
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from ..utils import node_replace_, nodes_map
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from collections.abc import Callable
|
| 16 |
+
|
| 17 |
+
from torch._ops import HigherOrderOperator
|
| 18 |
+
from torch.export.graph_signature import ExportGraphSignature
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _replace_with_hop_helper(
|
| 22 |
+
node: torch.fx.Node,
|
| 23 |
+
enter_block_node: torch.fx.Node,
|
| 24 |
+
wrap_hoo: HigherOrderOperator,
|
| 25 |
+
) -> None:
|
| 26 |
+
graph: torch.fx.Graph = node.graph
|
| 27 |
+
assert graph.owning_module is not None
|
| 28 |
+
gm: torch.fx.GraphModule = graph.owning_module
|
| 29 |
+
assert isinstance(node.target, str)
|
| 30 |
+
sub_gm = getattr(gm, node.target)
|
| 31 |
+
|
| 32 |
+
def set_hoo_node_meta(call_func_node):
|
| 33 |
+
call_func_node.meta["nn_module_stack"] = copy.copy(
|
| 34 |
+
enter_block_node.meta.get("nn_module_stack", {})
|
| 35 |
+
)
|
| 36 |
+
call_func_node.meta["torch_fn"] = (
|
| 37 |
+
f"{wrap_hoo.__name__}",
|
| 38 |
+
# pyrefly: ignore [missing-attribute]
|
| 39 |
+
f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}",
|
| 40 |
+
)
|
| 41 |
+
if isinstance(output_args, (tuple, list)):
|
| 42 |
+
call_func_node.meta["val"] = tuple(arg.meta["val"] for arg in output_args)
|
| 43 |
+
elif isinstance(output_args, torch.fx.Node):
|
| 44 |
+
call_func_node.meta["val"] = (output_args.meta["val"],)
|
| 45 |
+
|
| 46 |
+
with graph.inserting_before(node):
|
| 47 |
+
get_attr_node = graph.get_attr(node.target)
|
| 48 |
+
get_attr_node.meta["nn_module_stack"] = copy.copy(
|
| 49 |
+
enter_block_node.meta.get("nn_module_stack", {})
|
| 50 |
+
)
|
| 51 |
+
output_node = next(iter(reversed(sub_gm.graph.nodes)), None)
|
| 52 |
+
# Split_module pass intentionally doesn't add output node
|
| 53 |
+
# if the graph doesn't return anything.
|
| 54 |
+
# TODO (tmanlaibaatar) Figure out if this is right behaviour
|
| 55 |
+
# for split_module
|
| 56 |
+
if isinstance(output_node, torch.fx.Node) and output_node.op != "output":
|
| 57 |
+
output_node = None
|
| 58 |
+
if output_node is not None:
|
| 59 |
+
assert len(output_node.args) == 1
|
| 60 |
+
output_args = output_node.args[0]
|
| 61 |
+
enter_block_node_args = enter_block_node.args
|
| 62 |
+
if isinstance(output_args, (tuple, list)):
|
| 63 |
+
call_func_node = graph.call_function(
|
| 64 |
+
wrap_hoo,
|
| 65 |
+
(*enter_block_node_args, get_attr_node, *node.args),
|
| 66 |
+
{},
|
| 67 |
+
)
|
| 68 |
+
# Create the metadata
|
| 69 |
+
set_hoo_node_meta(call_func_node)
|
| 70 |
+
node_replace_(node, call_func_node)
|
| 71 |
+
|
| 72 |
+
# Rename the name of getitem nodes to the actual name of its contents
|
| 73 |
+
# for passing verifier and better readability, also propagate metadata
|
| 74 |
+
for get_item_node in call_func_node.users:
|
| 75 |
+
idx: int = get_item_node.args[1] # type: ignore[assignment]
|
| 76 |
+
output_node = output_args[idx]
|
| 77 |
+
get_item_node._rename(output_node.name)
|
| 78 |
+
get_item_node.meta = output_node.meta
|
| 79 |
+
|
| 80 |
+
elif isinstance(output_args, torch.fx.Node):
|
| 81 |
+
call_func_node = graph.create_node(
|
| 82 |
+
"call_function",
|
| 83 |
+
wrap_hoo,
|
| 84 |
+
(*enter_block_node_args, get_attr_node, *node.args),
|
| 85 |
+
{},
|
| 86 |
+
output_args.name,
|
| 87 |
+
)
|
| 88 |
+
# Modify the subgraph to output a singleton list.
|
| 89 |
+
output_node.args = ((output_args,),)
|
| 90 |
+
# Add in an extra `getitem(wrap_hoo, 0)` node to the toplevel graph.
|
| 91 |
+
get_item_node = graph.create_node(
|
| 92 |
+
"call_function",
|
| 93 |
+
operator.getitem,
|
| 94 |
+
(call_func_node, 0),
|
| 95 |
+
{},
|
| 96 |
+
)
|
| 97 |
+
# Create the metadata
|
| 98 |
+
get_item_node.meta = output_args.meta
|
| 99 |
+
set_hoo_node_meta(call_func_node)
|
| 100 |
+
node_replace_(node, get_item_node)
|
| 101 |
+
else:
|
| 102 |
+
raise NotImplementedError(
|
| 103 |
+
f"replace_with_hop_pass doesn't support output type {type(output_args)}"
|
| 104 |
+
)
|
| 105 |
+
else:
|
| 106 |
+
# TODO (shangdiy): remove this line, since the export graph can be non-functional
|
| 107 |
+
node.graph.erase_node(node)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _sequential_split_and_maybe_inline_subgraphs_helper(
|
| 111 |
+
new_gm: torch.fx.GraphModule,
|
| 112 |
+
graph_signature: ExportGraphSignature | None,
|
| 113 |
+
maybe_inline_or_replace_with_hop: Callable[[torch.fx.Node], None],
|
| 114 |
+
) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
|
| 115 |
+
"""
|
| 116 |
+
Helper function for replacing graph nodse with higher order nodes.
|
| 117 |
+
For each subgraph in `new_gm`, decides whether to construct a HOO subgraph, or inline the calls
|
| 118 |
+
back into the parent graph module, depending on `maybe_inline_or_replace_with_hop`.
|
| 119 |
+
"""
|
| 120 |
+
# new_gm is a new graph module that could have different output args names.
|
| 121 |
+
# We need to fix the graph signature.
|
| 122 |
+
replace_ctx = contextlib.nullcontext()
|
| 123 |
+
new_signature = None
|
| 124 |
+
if graph_signature is not None:
|
| 125 |
+
# Cannot deep copy a real ScriptObject, which is referenced
|
| 126 |
+
# in the FakeScriptObject. Copy should be good enough to guard
|
| 127 |
+
# against accidental mutation to original graph_signature.
|
| 128 |
+
new_signature = copy.copy(graph_signature)
|
| 129 |
+
new_gm_out_node = next(reversed(new_gm.graph.find_nodes(op="output")))
|
| 130 |
+
assert new_gm_out_node.op == "output" and len(new_gm_out_node.args[0]) == len(
|
| 131 |
+
new_signature.output_specs
|
| 132 |
+
)
|
| 133 |
+
for arg_node, out_spec in zip(
|
| 134 |
+
new_gm_out_node.args[0], new_signature.output_specs
|
| 135 |
+
):
|
| 136 |
+
if arg_node is None:
|
| 137 |
+
assert out_spec.arg.value is None # type: ignore[union-attr]
|
| 138 |
+
elif (
|
| 139 |
+
isinstance(arg_node, torch.fx.Node)
|
| 140 |
+
and out_spec.arg.name != arg_node.name
|
| 141 |
+
):
|
| 142 |
+
out_spec.arg.name = arg_node.name
|
| 143 |
+
|
| 144 |
+
replace_ctx = new_gm._set_replace_hook(new_signature.get_replace_hook()) # type: ignore[assignment]
|
| 145 |
+
|
| 146 |
+
with replace_ctx:
|
| 147 |
+
nodes_map(
|
| 148 |
+
list(new_gm.graph.nodes),
|
| 149 |
+
lambda node: (
|
| 150 |
+
maybe_inline_or_replace_with_hop(node)
|
| 151 |
+
if node.op == "call_module"
|
| 152 |
+
else node
|
| 153 |
+
),
|
| 154 |
+
)
|
| 155 |
+
new_gm.recompile()
|
| 156 |
+
new_gm.graph.lint()
|
| 157 |
+
return new_gm, new_signature
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _replace_with_hop_pass_helper(
|
| 161 |
+
gm: torch.fx.GraphModule,
|
| 162 |
+
graph_signature: ExportGraphSignature | None,
|
| 163 |
+
sequential_split_and_maybe_inline_subgraphs: Callable[
|
| 164 |
+
[torch.fx.GraphModule, ExportGraphSignature | None],
|
| 165 |
+
tuple[torch.fx.GraphModule, ExportGraphSignature | None],
|
| 166 |
+
],
|
| 167 |
+
) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]:
|
| 168 |
+
"""
|
| 169 |
+
Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
|
| 170 |
+
then recursively call itself on each of the submodules.
|
| 171 |
+
"""
|
| 172 |
+
new_gm, new_signature = sequential_split_and_maybe_inline_subgraphs(
|
| 173 |
+
gm, graph_signature
|
| 174 |
+
)
|
| 175 |
+
# recursively call
|
| 176 |
+
for node in new_gm.graph.nodes:
|
| 177 |
+
if node.op == "get_attr":
|
| 178 |
+
subgm = getattr(new_gm, node.target)
|
| 179 |
+
if not isinstance(subgm, torch.fx.GraphModule):
|
| 180 |
+
continue
|
| 181 |
+
new_subgm, _ = _replace_with_hop_pass_helper(
|
| 182 |
+
subgm,
|
| 183 |
+
None,
|
| 184 |
+
sequential_split_and_maybe_inline_subgraphs,
|
| 185 |
+
)
|
| 186 |
+
setattr(new_gm, node.target, new_subgm)
|
| 187 |
+
|
| 188 |
+
new_gm.recompile()
|
| 189 |
+
new_gm.graph.lint()
|
| 190 |
+
return new_gm, new_signature
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/__init__.py
ADDED
|
File without changes
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/dynamic_shapes.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
from typing import Any, Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._dynamo.exc import UserError, UserErrorType
|
| 6 |
+
from torch.export.dynamic_shapes import (
|
| 7 |
+
_check_dynamic_shapes,
|
| 8 |
+
_DerivedDim,
|
| 9 |
+
_DimHint,
|
| 10 |
+
_tree_map_with_path,
|
| 11 |
+
Dim,
|
| 12 |
+
)
|
| 13 |
+
from torch.utils._pytree import tree_map
|
| 14 |
+
|
| 15 |
+
from .serialize import _dataclass_to_dict
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclasses.dataclass
|
| 19 |
+
class RootDim:
|
| 20 |
+
"""
|
| 21 |
+
This represents a Dim object.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
min: int
|
| 25 |
+
max: Union[int, None]
|
| 26 |
+
derived: list[str]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclasses.dataclass
|
| 30 |
+
class DynamicShapesSpec:
|
| 31 |
+
"""
|
| 32 |
+
This stores a dynamic_shapes spec for de/serialization.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None]
|
| 36 |
+
dims: dict[str, RootDim]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _postprocess_serialized_shapes(
|
| 40 |
+
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
|
| 41 |
+
dims: dict[str, dict[str, Union[int, list[str], None]]],
|
| 42 |
+
to_dict: Optional[bool] = False,
|
| 43 |
+
) -> Union[DynamicShapesSpec, dict[str, Any]]:
|
| 44 |
+
"""
|
| 45 |
+
Sorts dims and dumps to dictionary format.
|
| 46 |
+
"""
|
| 47 |
+
from torch.utils._sympy.numbers import int_oo
|
| 48 |
+
|
| 49 |
+
dims = {
|
| 50 |
+
k: RootDim(
|
| 51 |
+
min=v["min"], # type: ignore[arg-type]
|
| 52 |
+
max=None if v["max"] is int_oo else v["max"], # type: ignore[arg-type]
|
| 53 |
+
derived=sorted(v["derived"]), # type: ignore[arg-type]
|
| 54 |
+
)
|
| 55 |
+
for k, v in sorted(dims.items())
|
| 56 |
+
}
|
| 57 |
+
# pyrefly: ignore [bad-argument-type]
|
| 58 |
+
spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims)
|
| 59 |
+
if to_dict:
|
| 60 |
+
return _dataclass_to_dict(spec)
|
| 61 |
+
else:
|
| 62 |
+
return spec
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _dump_dynamic_shapes(
|
| 66 |
+
dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
|
| 67 |
+
args: tuple[Any],
|
| 68 |
+
kwargs: Optional[dict[str, Any]] = None,
|
| 69 |
+
to_dict: Optional[bool] = False,
|
| 70 |
+
) -> Union[DynamicShapesSpec, dict[str, Any]]:
|
| 71 |
+
"""
|
| 72 |
+
Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec.
|
| 73 |
+
Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims".
|
| 74 |
+
Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones).
|
| 75 |
+
|
| 76 |
+
dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export():
|
| 77 |
+
- Each tensor input is represented with a list of values, non-tensor inputs with None.
|
| 78 |
+
- dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings.
|
| 79 |
+
- static dimensions are represented with ints.
|
| 80 |
+
|
| 81 |
+
dims: A dictionary mapping each symbol name to the min/max range and derived dim names.
|
| 82 |
+
|
| 83 |
+
For example:
|
| 84 |
+
```
|
| 85 |
+
dx = Dim("dx", min=4, max=16)
|
| 86 |
+
dy = dx + 1
|
| 87 |
+
|
| 88 |
+
inputs = (
|
| 89 |
+
[
|
| 90 |
+
torch.randn(4, 4),
|
| 91 |
+
torch.randn(5, 4),
|
| 92 |
+
],
|
| 93 |
+
torch.randn(4),
|
| 94 |
+
torch.randn(4, 4),
|
| 95 |
+
"hello",
|
| 96 |
+
)
|
| 97 |
+
dynamic_shapes = {
|
| 98 |
+
"a": [
|
| 99 |
+
(dx, 4),
|
| 100 |
+
(dy, 4),
|
| 101 |
+
],
|
| 102 |
+
"b": (Dim.STATIC,),
|
| 103 |
+
"c": None,
|
| 104 |
+
"d": None,
|
| 105 |
+
}
|
| 106 |
+
out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True)
|
| 107 |
+
```
|
| 108 |
+
would generate the following output:
|
| 109 |
+
```
|
| 110 |
+
{
|
| 111 |
+
"dynamic_shapes": (
|
| 112 |
+
[
|
| 113 |
+
["dx", 4],
|
| 114 |
+
["dx + 1", 4],
|
| 115 |
+
],
|
| 116 |
+
["_DimHint.STATIC"],
|
| 117 |
+
["_DimHint.STATIC", "_DimHint.STATIC"],
|
| 118 |
+
None,
|
| 119 |
+
),
|
| 120 |
+
"dims": {
|
| 121 |
+
"dx": {
|
| 122 |
+
"min": 4,
|
| 123 |
+
"max": 16,
|
| 124 |
+
"derived": ["dx + 1"],
|
| 125 |
+
},
|
| 126 |
+
},
|
| 127 |
+
}
|
| 128 |
+
```
|
| 129 |
+
"""
|
| 130 |
+
dims: dict[str, dict[str, Any]] = {}
|
| 131 |
+
|
| 132 |
+
def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def]
|
| 133 |
+
"""
|
| 134 |
+
Helps standardize the dynamic_shapes tree structure we serialize,
|
| 135 |
+
returning lists for each tensor shape, handling tensor-level Nones.
|
| 136 |
+
"""
|
| 137 |
+
if not isinstance(tensor, torch.Tensor):
|
| 138 |
+
return None
|
| 139 |
+
if shape is None:
|
| 140 |
+
return [Dim.STATIC] * len(tensor.shape)
|
| 141 |
+
|
| 142 |
+
out = []
|
| 143 |
+
if isinstance(shape, dict):
|
| 144 |
+
for i, s in enumerate(tensor.shape):
|
| 145 |
+
out.append(s if shape.get(i) is None else shape.get(i))
|
| 146 |
+
else:
|
| 147 |
+
assert isinstance(shape, (tuple, list))
|
| 148 |
+
for i, s in enumerate(tensor.shape):
|
| 149 |
+
out.append(s if shape[i] is None else shape[i])
|
| 150 |
+
return out
|
| 151 |
+
|
| 152 |
+
def _track_dim_from_dims(
|
| 153 |
+
val: Union[None, int, _DimHint, Dim],
|
| 154 |
+
) -> Union[None, int, str]:
|
| 155 |
+
"""
|
| 156 |
+
Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
|
| 157 |
+
"""
|
| 158 |
+
if val is None or isinstance(val, int): # non-tensor input or static
|
| 159 |
+
return val
|
| 160 |
+
if isinstance(val, _DimHint): # store enum as string
|
| 161 |
+
return val.__class__.__name__ + "." + val.type.name
|
| 162 |
+
|
| 163 |
+
assert isinstance(val, Dim)
|
| 164 |
+
|
| 165 |
+
# track root dim
|
| 166 |
+
root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined]
|
| 167 |
+
if root.__name__ not in dims:
|
| 168 |
+
dims[root.__name__] = {
|
| 169 |
+
"min": root.min, # type: ignore[attr-defined,union-attr]
|
| 170 |
+
"max": root.max, # type: ignore[attr-defined,union-attr]
|
| 171 |
+
"derived": set(),
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
# track derived dims
|
| 175 |
+
if isinstance(val, _DerivedDim):
|
| 176 |
+
dims[root.__name__]["derived"].add(val.__name__)
|
| 177 |
+
|
| 178 |
+
return val.__name__
|
| 179 |
+
|
| 180 |
+
if dynamic_shapes is None:
|
| 181 |
+
return {"dynamic_shapes": None, "dims": {}}
|
| 182 |
+
|
| 183 |
+
# convert to tuple of specs, for each arg/kwarg
|
| 184 |
+
kwargs = kwargs or {}
|
| 185 |
+
if isinstance(dynamic_shapes, dict):
|
| 186 |
+
dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment]
|
| 187 |
+
# pyrefly: ignore [bad-assignment, bad-argument-type]
|
| 188 |
+
dynamic_shapes = tuple(dynamic_shapes)
|
| 189 |
+
combined_args = tuple(args) + tuple(kwargs.values())
|
| 190 |
+
|
| 191 |
+
# run same check when we're processing shapes for export - is this too lazy?
|
| 192 |
+
_check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes) # type: ignore[arg-type]
|
| 193 |
+
|
| 194 |
+
tree_shapes = _tree_map_with_path(
|
| 195 |
+
_standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs"
|
| 196 |
+
)
|
| 197 |
+
serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes)
|
| 198 |
+
return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _load_dynamic_shapes(
|
| 202 |
+
spec: Union[DynamicShapesSpec, dict[str, Any]],
|
| 203 |
+
from_dict: Optional[bool] = False,
|
| 204 |
+
) -> Union[dict[str, Any], tuple[Any], list[Any], None]:
|
| 205 |
+
"""
|
| 206 |
+
Utility function for dynamic shapes serialization.
|
| 207 |
+
Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export().
|
| 208 |
+
"""
|
| 209 |
+
import sympy
|
| 210 |
+
|
| 211 |
+
from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence
|
| 212 |
+
|
| 213 |
+
if from_dict:
|
| 214 |
+
if not isinstance(spec, dict):
|
| 215 |
+
raise UserError(
|
| 216 |
+
UserErrorType.INVALID_INPUT,
|
| 217 |
+
f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}",
|
| 218 |
+
)
|
| 219 |
+
if sorted(spec.keys()) != ["dims", "dynamic_shapes"]:
|
| 220 |
+
raise UserError(
|
| 221 |
+
UserErrorType.INVALID_INPUT,
|
| 222 |
+
"With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, "
|
| 223 |
+
f"instead found {spec.keys()}",
|
| 224 |
+
)
|
| 225 |
+
dims = {}
|
| 226 |
+
for k, v in spec["dims"].items():
|
| 227 |
+
if not isinstance(k, str):
|
| 228 |
+
raise UserError(
|
| 229 |
+
UserErrorType.INVALID_INPUT,
|
| 230 |
+
f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}",
|
| 231 |
+
)
|
| 232 |
+
if sorted(v.keys()) != ["derived", "max", "min"]:
|
| 233 |
+
raise UserError(
|
| 234 |
+
UserErrorType.INVALID_INPUT,
|
| 235 |
+
f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, "
|
| 236 |
+
f"instead found {v.keys()}",
|
| 237 |
+
)
|
| 238 |
+
if not isinstance(v["min"], int):
|
| 239 |
+
raise UserError(
|
| 240 |
+
UserErrorType.INVALID_INPUT,
|
| 241 |
+
f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}",
|
| 242 |
+
)
|
| 243 |
+
if not isinstance(v["max"], int) or v["max"] is None:
|
| 244 |
+
raise UserError(
|
| 245 |
+
UserErrorType.INVALID_INPUT,
|
| 246 |
+
f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}",
|
| 247 |
+
)
|
| 248 |
+
if not isinstance(v["derived"], list) or any(
|
| 249 |
+
not isinstance(d, str) for d in v["derived"]
|
| 250 |
+
):
|
| 251 |
+
raise UserError(
|
| 252 |
+
UserErrorType.INVALID_INPUT,
|
| 253 |
+
"Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, "
|
| 254 |
+
f"got {k}: {v['derived']}",
|
| 255 |
+
)
|
| 256 |
+
dims[k] = RootDim(**v)
|
| 257 |
+
dynamic_shapes = spec["dynamic_shapes"]
|
| 258 |
+
else:
|
| 259 |
+
if not isinstance(spec, DynamicShapesSpec):
|
| 260 |
+
raise UserError(
|
| 261 |
+
UserErrorType.INVALID_INPUT,
|
| 262 |
+
f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}",
|
| 263 |
+
)
|
| 264 |
+
dims = spec.dims
|
| 265 |
+
dynamic_shapes = spec.dynamic_shapes
|
| 266 |
+
|
| 267 |
+
if dynamic_shapes is None:
|
| 268 |
+
return None
|
| 269 |
+
|
| 270 |
+
dim_cache = {}
|
| 271 |
+
for name, info in dims.items():
|
| 272 |
+
symbol = sympy.sympify(name)
|
| 273 |
+
if not isinstance(symbol, sympy.Symbol):
|
| 274 |
+
raise UserError(
|
| 275 |
+
UserErrorType.INVALID_INPUT,
|
| 276 |
+
f"Expected `spec['dims']` keys to be symbols, got {name}",
|
| 277 |
+
)
|
| 278 |
+
dim_cache[name] = Dim(name, min=info.min, max=info.max) # cache root dim
|
| 279 |
+
for _expr in info.derived:
|
| 280 |
+
expr = sympy.sympify(_expr)
|
| 281 |
+
if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols:
|
| 282 |
+
raise UserError(
|
| 283 |
+
UserErrorType.INVALID_INPUT,
|
| 284 |
+
f"Expected derived expressions in to have {name} as the only free symbol, got {expr}",
|
| 285 |
+
)
|
| 286 |
+
if not _is_supported_equivalence(expr):
|
| 287 |
+
raise UserError(
|
| 288 |
+
UserErrorType.INVALID_INPUT,
|
| 289 |
+
f"Expected derived expressions to be linear expressions, got {expr}",
|
| 290 |
+
)
|
| 291 |
+
modulus, remainder = sympy.polys.polytools.div(expr, symbol)
|
| 292 |
+
ddim = dim_cache[name]
|
| 293 |
+
if modulus != 1:
|
| 294 |
+
ddim = int(modulus) * ddim # type: ignore[assignment, operator]
|
| 295 |
+
if remainder != 0:
|
| 296 |
+
ddim = ddim + int(remainder) # type: ignore[assignment, operator]
|
| 297 |
+
dim_cache[_expr] = ddim # cache derived dims
|
| 298 |
+
|
| 299 |
+
def deserialize_shape(
|
| 300 |
+
val: Union[None, int, str],
|
| 301 |
+
) -> Union[None, int, Dim, _DimHint]:
|
| 302 |
+
if val is None or isinstance(val, int):
|
| 303 |
+
return val
|
| 304 |
+
elif val == "_DimHint.AUTO":
|
| 305 |
+
return _DimHint.AUTO()
|
| 306 |
+
elif val == "_DimHint.DYNAMIC":
|
| 307 |
+
return _DimHint.DYNAMIC()
|
| 308 |
+
elif val == "_DimHint.STATIC":
|
| 309 |
+
return _DimHint.STATIC()
|
| 310 |
+
if not isinstance(val, str):
|
| 311 |
+
raise UserError(
|
| 312 |
+
UserErrorType.INVALID_INPUT,
|
| 313 |
+
"Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, "
|
| 314 |
+
f" or derived expressions, got {val}",
|
| 315 |
+
)
|
| 316 |
+
if val not in dim_cache:
|
| 317 |
+
raise UserError(
|
| 318 |
+
UserErrorType.INVALID_INPUT,
|
| 319 |
+
"Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, "
|
| 320 |
+
f"got {val} which is not in {dims.keys()}",
|
| 321 |
+
)
|
| 322 |
+
return dim_cache[val] # type: ignore[return-value]
|
| 323 |
+
|
| 324 |
+
return tree_map(deserialize_shape, dynamic_shapes)
|
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/serde/export_schema.thrift
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// @generated by update_schema.py
|
| 2 |
+
// checksum<<0e870e558fb4362f69b825842ab606cf0becd10a008003ac676156becf20b65b>>
|
| 3 |
+
|
| 4 |
+
namespace py3 torch._export
|
| 5 |
+
namespace cpp2 torch._export.schema
|
| 6 |
+
|
| 7 |
+
enum ArgumentKind {
|
| 8 |
+
UNKNOWN = 0,
|
| 9 |
+
POSITIONAL = 1,
|
| 10 |
+
KEYWORD = 2,
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
enum Layout {
|
| 15 |
+
Unknown = 0,
|
| 16 |
+
SparseCoo = 1,
|
| 17 |
+
SparseCsr = 2,
|
| 18 |
+
SparseCsc = 3,
|
| 19 |
+
SparseBsr = 4,
|
| 20 |
+
SparseBsc = 5,
|
| 21 |
+
_mkldnn = 6,
|
| 22 |
+
Strided = 7,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
enum MemoryFormat {
|
| 27 |
+
Unknown = 0,
|
| 28 |
+
ContiguousFormat = 1,
|
| 29 |
+
ChannelsLast = 2,
|
| 30 |
+
ChannelsLast3d = 3,
|
| 31 |
+
PreserveFormat = 4,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
enum ScalarType {
|
| 36 |
+
UNKNOWN = 0,
|
| 37 |
+
BYTE = 1,
|
| 38 |
+
CHAR = 2,
|
| 39 |
+
SHORT = 3,
|
| 40 |
+
INT = 4,
|
| 41 |
+
LONG = 5,
|
| 42 |
+
HALF = 6,
|
| 43 |
+
FLOAT = 7,
|
| 44 |
+
DOUBLE = 8,
|
| 45 |
+
COMPLEXHALF = 9,
|
| 46 |
+
COMPLEXFLOAT = 10,
|
| 47 |
+
COMPLEXDOUBLE = 11,
|
| 48 |
+
BOOL = 12,
|
| 49 |
+
BFLOAT16 = 13,
|
| 50 |
+
UINT16 = 28,
|
| 51 |
+
FLOAT8E4M3FN = 29,
|
| 52 |
+
FLOAT8E5M2 = 30,
|
| 53 |
+
FLOAT8E4M3FNUZ = 31,
|
| 54 |
+
FLOAT8E5M2FNUZ = 32,
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
struct Device {
|
| 59 |
+
10: string type;
|
| 60 |
+
20: optional i64 index;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
union SymExprHint {
|
| 64 |
+
10: i64 as_int;
|
| 65 |
+
20: bool as_bool;
|
| 66 |
+
30: double as_float;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
struct SymExpr {
|
| 70 |
+
10: string expr_str;
|
| 71 |
+
20: optional SymExprHint hint;
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
union SymInt {
|
| 75 |
+
10: SymExpr as_expr;
|
| 76 |
+
20: i64 as_int;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
union SymFloat {
|
| 80 |
+
10: SymExpr as_expr;
|
| 81 |
+
20: double as_float;
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
union SymBool {
|
| 85 |
+
10: SymExpr as_expr;
|
| 86 |
+
20: bool as_bool;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
struct TensorMeta {
|
| 90 |
+
10: ScalarType dtype;
|
| 91 |
+
20: list<SymInt> sizes;
|
| 92 |
+
30: bool requires_grad;
|
| 93 |
+
40: Device device;
|
| 94 |
+
50: list<SymInt> strides;
|
| 95 |
+
60: SymInt storage_offset;
|
| 96 |
+
70: Layout layout;
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
union SymIntArgument {
|
| 100 |
+
10: string as_name;
|
| 101 |
+
20: i64 as_int;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
union SymFloatArgument {
|
| 105 |
+
10: string as_name;
|
| 106 |
+
20: double as_float;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
union SymBoolArgument {
|
| 110 |
+
10: string as_name;
|
| 111 |
+
20: bool as_bool;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
struct TensorArgument {
|
| 115 |
+
10: string name;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
struct TokenArgument {
|
| 119 |
+
10: string name;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
union OptionalTensorArgument {
|
| 123 |
+
20: TensorArgument as_tensor;
|
| 124 |
+
10: bool as_none;
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
struct GraphArgument {
|
| 128 |
+
10: string name;
|
| 129 |
+
20: Graph graph;
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
struct CustomObjArgument {
|
| 133 |
+
10: string name;
|
| 134 |
+
20: string class_fqn;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
struct ComplexValue {
|
| 138 |
+
10: double real;
|
| 139 |
+
20: double imag;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
union Argument {
|
| 143 |
+
10: bool as_none;
|
| 144 |
+
20: TensorArgument as_tensor;
|
| 145 |
+
30: list<TensorArgument> as_tensors;
|
| 146 |
+
50: i64 as_int;
|
| 147 |
+
70: list<i64> as_ints;
|
| 148 |
+
80: double as_float;
|
| 149 |
+
90: list<double> as_floats;
|
| 150 |
+
100: string as_string;
|
| 151 |
+
101: list<string> as_strings;
|
| 152 |
+
110: SymIntArgument as_sym_int;
|
| 153 |
+
120: list<SymIntArgument> as_sym_ints;
|
| 154 |
+
130: ScalarType as_scalar_type;
|
| 155 |
+
140: MemoryFormat as_memory_format;
|
| 156 |
+
150: Layout as_layout;
|
| 157 |
+
160: Device as_device;
|
| 158 |
+
170: bool as_bool;
|
| 159 |
+
180: list<bool> as_bools;
|
| 160 |
+
182: SymBoolArgument as_sym_bool;
|
| 161 |
+
184: list<SymBoolArgument> as_sym_bools;
|
| 162 |
+
200: GraphArgument as_graph;
|
| 163 |
+
190: list<OptionalTensorArgument> as_optional_tensors;
|
| 164 |
+
210: CustomObjArgument as_custom_obj;
|
| 165 |
+
220: string as_operator;
|
| 166 |
+
230: SymFloatArgument as_sym_float;
|
| 167 |
+
240: list<SymFloatArgument> as_sym_floats;
|
| 168 |
+
250: OptionalTensorArgument as_optional_tensor;
|
| 169 |
+
260: ComplexValue as_complex;
|
| 170 |
+
280: list<list<i64>> as_int_lists;
|
| 171 |
+
290: map<string, Argument> as_string_to_argument;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
struct NamedArgument {
|
| 175 |
+
10: string name;
|
| 176 |
+
20: Argument arg;
|
| 177 |
+
30: optional ArgumentKind kind;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
struct Node {
|
| 181 |
+
10: string target;
|
| 182 |
+
20: list<NamedArgument> inputs;
|
| 183 |
+
30: list<Argument> outputs;
|
| 184 |
+
40: map<string, string> metadata;
|
| 185 |
+
50: optional bool is_hop_single_tensor_return;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
struct Graph {
|
| 189 |
+
10: list<Argument> inputs;
|
| 190 |
+
20: list<Argument> outputs;
|
| 191 |
+
30: list<Node> nodes;
|
| 192 |
+
40: map<string, TensorMeta> tensor_values;
|
| 193 |
+
50: map<string, SymInt> sym_int_values;
|
| 194 |
+
60: map<string, SymBool> sym_bool_values;
|
| 195 |
+
70: bool is_single_tensor_return;
|
| 196 |
+
80: map<string, CustomObjArgument> custom_obj_values;
|
| 197 |
+
90: map<string, SymFloat> sym_float_values;
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
struct UserInputSpec {
|
| 201 |
+
10: Argument arg;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
union ConstantValue {
|
| 205 |
+
10: bool as_none;
|
| 206 |
+
20: i64 as_int;
|
| 207 |
+
30: double as_float;
|
| 208 |
+
40: string as_string;
|
| 209 |
+
50: bool as_bool;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
struct InputToConstantInputSpec {
|
| 213 |
+
10: string name;
|
| 214 |
+
20: ConstantValue value;
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
struct InputToParameterSpec {
|
| 218 |
+
10: TensorArgument arg;
|
| 219 |
+
20: string parameter_name;
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
struct InputToBufferSpec {
|
| 223 |
+
10: TensorArgument arg;
|
| 224 |
+
20: string buffer_name;
|
| 225 |
+
30: bool persistent;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
struct InputToTensorConstantSpec {
|
| 229 |
+
10: TensorArgument arg;
|
| 230 |
+
20: string tensor_constant_name;
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
struct InputToCustomObjSpec {
|
| 234 |
+
10: CustomObjArgument arg;
|
| 235 |
+
20: string custom_obj_name;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
struct InputTokenSpec {
|
| 239 |
+
10: TokenArgument arg;
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
union InputSpec {
|
| 243 |
+
10: UserInputSpec user_input;
|
| 244 |
+
20: InputToParameterSpec parameter;
|
| 245 |
+
30: InputToBufferSpec buffer;
|
| 246 |
+
40: InputToTensorConstantSpec tensor_constant;
|
| 247 |
+
50: InputToCustomObjSpec custom_obj;
|
| 248 |
+
70: InputTokenSpec token;
|
| 249 |
+
60: InputToConstantInputSpec constant_input;
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
struct UserOutputSpec {
|
| 253 |
+
10: Argument arg;
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
struct LossOutputSpec {
|
| 257 |
+
10: TensorArgument arg;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
struct BufferMutationSpec {
|
| 261 |
+
10: TensorArgument arg;
|
| 262 |
+
20: string buffer_name;
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
struct ParameterMutationSpec {
|
| 266 |
+
10: TensorArgument arg;
|
| 267 |
+
20: string parameter_name;
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
struct GradientToParameterSpec {
|
| 271 |
+
10: TensorArgument arg;
|
| 272 |
+
20: string parameter_name;
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
struct GradientToUserInputSpec {
|
| 276 |
+
10: TensorArgument arg;
|
| 277 |
+
20: string user_input_name;
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
struct UserInputMutationSpec {
|
| 281 |
+
10: TensorArgument arg;
|
| 282 |
+
20: string user_input_name;
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
struct OutputTokenSpec {
|
| 286 |
+
10: TokenArgument arg;
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
union OutputSpec {
|
| 290 |
+
10: UserOutputSpec user_output;
|
| 291 |
+
20: LossOutputSpec loss_output;
|
| 292 |
+
30: BufferMutationSpec buffer_mutation;
|
| 293 |
+
40: GradientToParameterSpec gradient_to_parameter;
|
| 294 |
+
50: GradientToUserInputSpec gradient_to_user_input;
|
| 295 |
+
60: UserInputMutationSpec user_input_mutation;
|
| 296 |
+
70: OutputTokenSpec token;
|
| 297 |
+
80: ParameterMutationSpec parameter_mutation;
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
struct GraphSignature {
|
| 301 |
+
10: list<InputSpec> input_specs;
|
| 302 |
+
20: list<OutputSpec> output_specs;
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
struct RangeConstraint {
|
| 306 |
+
10: optional i64 min_val;
|
| 307 |
+
20: optional i64 max_val;
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
struct ModuleCallSignature {
|
| 311 |
+
10: list<Argument> inputs;
|
| 312 |
+
20: list<Argument> outputs;
|
| 313 |
+
30: string in_spec;
|
| 314 |
+
40: string out_spec;
|
| 315 |
+
50: optional list<string> forward_arg_names;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
struct ModuleCallEntry {
|
| 319 |
+
10: string fqn;
|
| 320 |
+
30: optional ModuleCallSignature signature;
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
struct NamedTupleDef {
|
| 324 |
+
10: list<string> field_names;
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
struct GraphModule {
|
| 328 |
+
10: Graph graph;
|
| 329 |
+
50: GraphSignature signature;
|
| 330 |
+
60: list<ModuleCallEntry> module_call_graph;
|
| 331 |
+
40: map<string, string> metadata;
|
| 332 |
+
70: map<string, NamedTupleDef> treespec_namedtuple_fields;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
struct SchemaVersion {
|
| 336 |
+
10: i64 major;
|
| 337 |
+
20: i64 minor;
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
struct ExportedProgram {
|
| 341 |
+
10: GraphModule graph_module;
|
| 342 |
+
20: map<string, i64> opset_version;
|
| 343 |
+
30: map<string, RangeConstraint> range_constraints;
|
| 344 |
+
60: SchemaVersion schema_version;
|
| 345 |
+
70: list<string> verifiers;
|
| 346 |
+
80: string torch_version;
|
| 347 |
+
90: list<string> guards_code;
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
struct PayloadMeta {
|
| 351 |
+
10: string path_name;
|
| 352 |
+
20: bool is_param;
|
| 353 |
+
30: bool use_pickle;
|
| 354 |
+
40: optional TensorMeta tensor_meta;
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
struct PayloadConfig {
|
| 358 |
+
10: map<string, PayloadMeta> config;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
struct AOTInductorModelPickleData {
|
| 362 |
+
1: string library_basename;
|
| 363 |
+
2: list<string> input_names;
|
| 364 |
+
3: list<string> output_names;
|
| 365 |
+
4: optional i64 floating_point_input_dtype;
|
| 366 |
+
5: optional i64 floating_point_output_dtype;
|
| 367 |
+
6: optional bool aot_inductor_model_is_cpu;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
struct ExternKernelNode {
|
| 371 |
+
10: string name;
|
| 372 |
+
20: Node node;
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
struct ExternKernelNodes {
|
| 376 |
+
10: list<ExternKernelNode> nodes;
|
| 377 |
+
}
|