Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/autograd.py +274 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__pycache__/_conversions.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__init__.py +1230 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite.py +526 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/weight_utils.py +275 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/graph_signature.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_remove_auto_functionalized_pass.py +93 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py +14 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py +29 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py +4 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py +3 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py +11 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_manipulation.py +110 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/net_min_base.py +731 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/pass_manager.py +257 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_module.py +514 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/splitter_base.py +871 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tools_common.py +273 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/cpp.py +88 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/functional.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py +15 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py +1 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py +5 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py +7 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py +5 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/activation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/distance.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/linear.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/normalization.cpython-311.pyc +0 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_custom_op/autograd.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.utils._pytree as pytree
|
| 3 |
+
from collections import namedtuple
|
| 4 |
+
import functools
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# NOTE [CustomOp autograd kernel indirection]
|
| 8 |
+
# We register `inner` as the autograd kernel for this custom_op.
|
| 9 |
+
# `inner` either calls the autograd formula registered by the user,
|
| 10 |
+
# or goes into an `autograd_not_implemented` kernel.
|
| 11 |
+
#
|
| 12 |
+
# The reason why this indirection exists is
|
| 13 |
+
# so that we can swap out the autograd kernel (the PyTorch dispatcher
|
| 14 |
+
# doesn't actually allow us to do this). By default, we want
|
| 15 |
+
# the `autograd_not_implemented` behavior, but then the user may come
|
| 16 |
+
# and register something that is actually a backward formula
|
| 17 |
+
def autograd_kernel_indirection(custom_op):
|
| 18 |
+
autograd_fallback = autograd_not_implemented(custom_op)
|
| 19 |
+
|
| 20 |
+
def inner(*args, **kwargs):
|
| 21 |
+
if custom_op._has_impl('autograd'):
|
| 22 |
+
kernel = custom_op._get_impl('autograd').func
|
| 23 |
+
return kernel(*args, **kwargs)
|
| 24 |
+
# As explained in NOTE ["backward", "save_for_backward", and "autograd"],
|
| 25 |
+
# after the user gives us "backward" and "save_for_backward", we generate
|
| 26 |
+
# the "autograd" impl. If the user only provided one, then we tell
|
| 27 |
+
# the user they've done something wrong.
|
| 28 |
+
if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'):
|
| 29 |
+
missing = (
|
| 30 |
+
'save_for_backward' if custom_op._has_impl('backward')
|
| 31 |
+
else 'backward'
|
| 32 |
+
)
|
| 33 |
+
found = 'save_for_backward' if missing == 'backward' else 'backward'
|
| 34 |
+
loc = custom_op._get_impl(found).location
|
| 35 |
+
raise RuntimeError(
|
| 36 |
+
f"We found a '{found}' registration for {custom_op} at "
|
| 37 |
+
f"{loc} but were unable to find a '{missing}' registration. "
|
| 38 |
+
f"To use the CustomOp API to register a backward formula, "
|
| 39 |
+
f"please provide us both a backward function and a "
|
| 40 |
+
f"'save for backward' function via `impl_backward` and "
|
| 41 |
+
f"`impl_save_for_backward` respectively.")
|
| 42 |
+
return autograd_fallback(*args, **kwargs)
|
| 43 |
+
return inner
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
|
| 47 |
+
# or change the default autograd fallback to the autograd not implemented fallback.
|
| 48 |
+
def autograd_not_implemented(custom_op):
|
| 49 |
+
def kernel(*args, **kwargs):
|
| 50 |
+
if torch.is_grad_enabled() and pytree.tree_any(
|
| 51 |
+
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
|
| 52 |
+
):
|
| 53 |
+
raise RuntimeError("Autograd has not been implemented for operator")
|
| 54 |
+
with torch._C._AutoDispatchBelowAutograd():
|
| 55 |
+
return custom_op(*args, **kwargs)
|
| 56 |
+
return kernel
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def mark_non_differentiable(ctx, output, output_differentiability):
|
| 60 |
+
# Output types are restricted to be:
|
| 61 |
+
# - Tensor
|
| 62 |
+
# - Tensor[]
|
| 63 |
+
# - int, bool, Scalar, float
|
| 64 |
+
# See _check_can_register_backward
|
| 65 |
+
if output_differentiability is not None:
|
| 66 |
+
if not isinstance(output, tuple):
|
| 67 |
+
tuple_output = (output,)
|
| 68 |
+
else:
|
| 69 |
+
tuple_output = output # type: ignore[assignment]
|
| 70 |
+
assert len(output_differentiability) == len(tuple_output)
|
| 71 |
+
non_differentiable_tensors = []
|
| 72 |
+
for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
|
| 73 |
+
if isinstance(out, torch.Tensor):
|
| 74 |
+
if not differentiable:
|
| 75 |
+
non_differentiable_tensors.append(out)
|
| 76 |
+
continue
|
| 77 |
+
if isinstance(out, list):
|
| 78 |
+
if not differentiable:
|
| 79 |
+
non_differentiable_tensors.extend(out)
|
| 80 |
+
continue
|
| 81 |
+
if differentiable:
|
| 82 |
+
raise RuntimeError(
|
| 83 |
+
f"With output_differentiability={output_differentiability}. "
|
| 84 |
+
f"At idx {idx}, we received an object of type {type(out)} that "
|
| 85 |
+
f"is not a Tensor, so it cannot have be marked as differentiable in "
|
| 86 |
+
f"output_differentiability.")
|
| 87 |
+
if non_differentiable_tensors:
|
| 88 |
+
ctx.mark_non_differentiable(*non_differentiable_tensors)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def construct_autograd_kernel(
|
| 92 |
+
schema,
|
| 93 |
+
output_differentiability,
|
| 94 |
+
custom_op,
|
| 95 |
+
op_overload,
|
| 96 |
+
save_for_backward_fn,
|
| 97 |
+
backward_fn):
|
| 98 |
+
|
| 99 |
+
def apply(*args):
|
| 100 |
+
flat_args, spec = pytree.tree_flatten(args)
|
| 101 |
+
out_spec = None
|
| 102 |
+
|
| 103 |
+
def forward(ctx, *flat_args):
|
| 104 |
+
ctx.set_materialize_grads(True)
|
| 105 |
+
args = pytree.tree_unflatten(list(flat_args), spec)
|
| 106 |
+
with torch._C._AutoDispatchBelowAutograd():
|
| 107 |
+
output = op_overload(*args)
|
| 108 |
+
|
| 109 |
+
# We use the info about args to give better error messages in backward
|
| 110 |
+
args_info = namedtuple_args(
|
| 111 |
+
schema, pytree.tree_map(type, args))
|
| 112 |
+
|
| 113 |
+
save_for_backward_fn_inputs = namedtuple_args(schema, args)
|
| 114 |
+
to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)
|
| 115 |
+
|
| 116 |
+
save_pytree_for_backward(ctx, (to_save, args_info))
|
| 117 |
+
mark_non_differentiable(ctx, output, output_differentiability)
|
| 118 |
+
|
| 119 |
+
nonlocal out_spec
|
| 120 |
+
flat_output, out_spec = pytree.tree_flatten(output)
|
| 121 |
+
return tuple(flat_output)
|
| 122 |
+
|
| 123 |
+
def backward(ctx, *flat_grad_output):
|
| 124 |
+
assert out_spec is not None
|
| 125 |
+
grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
|
| 126 |
+
saved, args_info = unpack_saved(ctx)
|
| 127 |
+
# There is nothing on the ctx object for now, it is just there so
|
| 128 |
+
# that we can add additional things in the future.
|
| 129 |
+
inner_ctx = object()
|
| 130 |
+
if not isinstance(grads, tuple):
|
| 131 |
+
grads = (grads,)
|
| 132 |
+
grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)
|
| 133 |
+
|
| 134 |
+
# Massage the grad_inputs_dict to a form acceptable by
|
| 135 |
+
# autograd.Function.
|
| 136 |
+
validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
|
| 137 |
+
return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)
|
| 138 |
+
|
| 139 |
+
generated_cls = gen_autograd_function(
|
| 140 |
+
custom_op._opname + '_customop', forward, backward)
|
| 141 |
+
|
| 142 |
+
flat_output = generated_cls.apply(*flat_args)
|
| 143 |
+
assert out_spec is not None
|
| 144 |
+
return pytree.tree_unflatten(list(flat_output), out_spec)
|
| 145 |
+
return apply
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def gen_autograd_function(name, forward, backward):
|
| 149 |
+
generated_cls = type(
|
| 150 |
+
name,
|
| 151 |
+
(torch.autograd.Function,),
|
| 152 |
+
{
|
| 153 |
+
'forward': staticmethod(forward),
|
| 154 |
+
'backward': staticmethod(backward),
|
| 155 |
+
}
|
| 156 |
+
)
|
| 157 |
+
return generated_cls
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@functools.lru_cache
|
| 161 |
+
def namedtuple_args_cls(schema):
|
| 162 |
+
attribs = [arg.name for arg in schema.arguments.flat_all]
|
| 163 |
+
name = str(schema.name) + "_args"
|
| 164 |
+
# mypy doesn't support dynamic namedtuple name
|
| 165 |
+
tuple_cls = namedtuple(name, attribs) # type: ignore[misc]
|
| 166 |
+
return tuple_cls
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def namedtuple_args(schema, args):
|
| 170 |
+
assert isinstance(args, tuple)
|
| 171 |
+
tuple_cls = namedtuple_args_cls(schema)
|
| 172 |
+
return tuple_cls(*args)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
|
| 176 |
+
def error(what):
|
| 177 |
+
backward = forward_op._get_impl('backward')
|
| 178 |
+
raise RuntimeError(
|
| 179 |
+
f"In the backward function defined for {forward_op} at "
|
| 180 |
+
f"{backward.location} using the CustomOp API, {what}")
|
| 181 |
+
|
| 182 |
+
if not isinstance(grad_inputs_dict, dict):
|
| 183 |
+
error(f"expected the output of the backward function to be a dict but "
|
| 184 |
+
f"got {type(grad_inputs_dict)}")
|
| 185 |
+
|
| 186 |
+
expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
|
| 187 |
+
if arg.type.is_tensor_like()}
|
| 188 |
+
actual_keys = grad_inputs_dict.keys()
|
| 189 |
+
if expected_keys != actual_keys:
|
| 190 |
+
error(f"expected the returned grad_input dict to have keys "
|
| 191 |
+
f"{expected_keys} but got {actual_keys}. The backward "
|
| 192 |
+
f"function must return a gradient (can be None) for each arg "
|
| 193 |
+
f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
|
| 194 |
+
f"Args declared to be non-Tensor-like types should not appear "
|
| 195 |
+
f"in the grad_input dict")
|
| 196 |
+
|
| 197 |
+
for name, grad in grad_inputs_dict.items():
|
| 198 |
+
arg_info = getattr(args_info, name)
|
| 199 |
+
|
| 200 |
+
if isinstance(arg_info, list):
|
| 201 |
+
if not isinstance(grad, (tuple, list)):
|
| 202 |
+
error(f"for input '{name}' expected the grad_input dict to "
|
| 203 |
+
f"hold a list of gradients but got object of type "
|
| 204 |
+
f"{type(grad)}.")
|
| 205 |
+
if not len(grad) == len(arg_info):
|
| 206 |
+
error(f"for input '{name}' expected the grad_input dict to "
|
| 207 |
+
f"hold a list of {len(arg_info)} gradients but got "
|
| 208 |
+
f"{len(grad)}")
|
| 209 |
+
for idx, (g, info) in enumerate(zip(grad, arg_info)):
|
| 210 |
+
if g is None:
|
| 211 |
+
continue
|
| 212 |
+
if not isinstance(g, torch.Tensor):
|
| 213 |
+
error(f"for input '{name}' expected the grad_input dict to "
|
| 214 |
+
f"hold a list of None or Tensor gradients but got "
|
| 215 |
+
f"object of {type(g)} at index {idx}")
|
| 216 |
+
if not issubclass(info, torch.Tensor):
|
| 217 |
+
error(f"for input '{name}', got a Tensor as the gradient "
|
| 218 |
+
f"for the {idx}-th value but expected None because "
|
| 219 |
+
f"the {idx}-th value was not a Tensor (it was "
|
| 220 |
+
f"type {arg_info}")
|
| 221 |
+
continue
|
| 222 |
+
|
| 223 |
+
if grad is None:
|
| 224 |
+
continue
|
| 225 |
+
if not isinstance(grad, torch.Tensor):
|
| 226 |
+
error(f"got object of type {type(grad)} as the gradient for input "
|
| 227 |
+
f"'{name}', "
|
| 228 |
+
f"but expected the gradient to be either None or a Tensor")
|
| 229 |
+
if not issubclass(arg_info, torch.Tensor):
|
| 230 |
+
error(f"got a Tensor as the gradient for input '{name}' but "
|
| 231 |
+
f"expected None as the gradient because input '{name}' "
|
| 232 |
+
f"was not a Tensor (it was type {arg_info}).")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
|
| 236 |
+
result = []
|
| 237 |
+
for name, arg_info in args_info._asdict().items():
|
| 238 |
+
if name not in grad_inputs_dict:
|
| 239 |
+
result.append(pytree.tree_map(lambda x: None, arg_info))
|
| 240 |
+
continue
|
| 241 |
+
result.append(grad_inputs_dict[name])
|
| 242 |
+
return tuple(pytree.tree_leaves(result))
|
| 243 |
+
|
| 244 |
+
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
|
| 245 |
+
# autograd.Function prefers that users use ctx.save_for_backward to
|
| 246 |
+
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
|
| 247 |
+
# ctx object.
|
| 248 |
+
def save_pytree_for_backward(ctx, stuff):
|
| 249 |
+
flat_stuff, spec = pytree.tree_flatten(stuff)
|
| 250 |
+
num_elts = len(flat_stuff)
|
| 251 |
+
tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
|
| 252 |
+
if isinstance(thing, torch.Tensor)]
|
| 253 |
+
non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff)
|
| 254 |
+
if not isinstance(thing, torch.Tensor)]
|
| 255 |
+
tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)]
|
| 256 |
+
non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)]
|
| 257 |
+
|
| 258 |
+
ctx.spec = spec
|
| 259 |
+
ctx.num_elts = num_elts
|
| 260 |
+
ctx.save_for_backward(*tensors)
|
| 261 |
+
ctx.tensor_idxs = tensor_idxs
|
| 262 |
+
ctx.saved_non_tensors = non_tensors
|
| 263 |
+
ctx.non_tensor_idxs = non_tensor_idxs
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# Inverse operation to save_pytree_for_backward
|
| 267 |
+
def unpack_saved(ctx):
|
| 268 |
+
flat_stuff = [None] * ctx.num_elts
|
| 269 |
+
for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs):
|
| 270 |
+
flat_stuff[idx] = tensor
|
| 271 |
+
for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs):
|
| 272 |
+
flat_stuff[idx] = non_tensor
|
| 273 |
+
stuff = pytree.tree_unflatten(flat_stuff, ctx.spec)
|
| 274 |
+
return stuff
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__init__.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/__pycache__/_conversions.cpython-311.pyc
ADDED
|
Binary file (4.64 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_refs/nn/functional/__init__.py
ADDED
|
@@ -0,0 +1,1230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from functools import wraps
|
| 3 |
+
from typing import Callable, Optional, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch._prims as prims
|
| 7 |
+
import torch._prims_common as utils
|
| 8 |
+
import torch._refs as refs
|
| 9 |
+
from torch._decomp import register_decomposition
|
| 10 |
+
from torch._prims_common import (
|
| 11 |
+
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
| 12 |
+
NumberType,
|
| 13 |
+
ShapeType,
|
| 14 |
+
TensorLike,
|
| 15 |
+
TensorLikeType,
|
| 16 |
+
)
|
| 17 |
+
from torch._prims_common.wrappers import (
|
| 18 |
+
elementwise_type_promotion_wrapper,
|
| 19 |
+
elementwise_unary_scalar_wrapper,
|
| 20 |
+
out_wrapper,
|
| 21 |
+
)
|
| 22 |
+
from torch._refs import _make_inplace
|
| 23 |
+
|
| 24 |
+
__all__ = [
|
| 25 |
+
"alpha_dropout",
|
| 26 |
+
"celu",
|
| 27 |
+
"celu_",
|
| 28 |
+
"dropout",
|
| 29 |
+
"elu",
|
| 30 |
+
"elu_",
|
| 31 |
+
"gelu",
|
| 32 |
+
"glu",
|
| 33 |
+
"group_norm",
|
| 34 |
+
"hardshrink",
|
| 35 |
+
"hardtanh",
|
| 36 |
+
"hinge_embedding_loss",
|
| 37 |
+
"huber_loss",
|
| 38 |
+
"l1_loss",
|
| 39 |
+
"layer_norm",
|
| 40 |
+
"leaky_relu",
|
| 41 |
+
"log_softmax",
|
| 42 |
+
"margin_ranking_loss",
|
| 43 |
+
"mish",
|
| 44 |
+
"mish_",
|
| 45 |
+
"mse_loss",
|
| 46 |
+
"nll_loss",
|
| 47 |
+
"pairwise_distance",
|
| 48 |
+
"pdist",
|
| 49 |
+
"poisson_nll_loss",
|
| 50 |
+
"prelu",
|
| 51 |
+
"relu",
|
| 52 |
+
"relu6",
|
| 53 |
+
"selu",
|
| 54 |
+
"selu_",
|
| 55 |
+
"smooth_l1_loss",
|
| 56 |
+
"softmax",
|
| 57 |
+
"softmin",
|
| 58 |
+
"softplus",
|
| 59 |
+
"softshrink",
|
| 60 |
+
"tanhshrink",
|
| 61 |
+
"threshold",
|
| 62 |
+
"threshold_",
|
| 63 |
+
"triplet_margin_loss",
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
Tensor = torch.Tensor
|
| 67 |
+
aten = torch._ops.ops.aten
|
| 68 |
+
DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _dropout_helper(
|
| 72 |
+
self: TensorLikeType,
|
| 73 |
+
val: float,
|
| 74 |
+
) -> TensorLikeType:
|
| 75 |
+
"""
|
| 76 |
+
Helper function for all dropout-type operators. During training,
|
| 77 |
+
some of the elements of the input tensor are randomly masked.
|
| 78 |
+
|
| 79 |
+
Returns the masked tensor of the boolean values.
|
| 80 |
+
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
return (
|
| 84 |
+
refs._uniform_helper(
|
| 85 |
+
self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device
|
| 86 |
+
)
|
| 87 |
+
< val
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@register_decomposition(aten.alpha_dropout)
|
| 92 |
+
def alpha_dropout(
|
| 93 |
+
self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False
|
| 94 |
+
) -> TensorLikeType:
|
| 95 |
+
if inplace:
|
| 96 |
+
raise NotImplementedError
|
| 97 |
+
|
| 98 |
+
if not training:
|
| 99 |
+
return self
|
| 100 |
+
|
| 101 |
+
torch._check(
|
| 102 |
+
p <= 1 and p >= 0,
|
| 103 |
+
lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
if p == 1:
|
| 107 |
+
return torch.zeros_like(self)
|
| 108 |
+
|
| 109 |
+
if p == 0:
|
| 110 |
+
return self
|
| 111 |
+
|
| 112 |
+
dropout_mask = _dropout_helper(self, 1 - p)
|
| 113 |
+
|
| 114 |
+
# From paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf)
|
| 115 |
+
# alpha = - SELU.alpha * SELU.scale, here
|
| 116 |
+
# SELU.alpha = 1.6732632423543772848170429916717 and
|
| 117 |
+
# SELU.scale = 1.0507009873554804934193349852946
|
| 118 |
+
alpha = -1.7580993408473766
|
| 119 |
+
|
| 120 |
+
a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p))
|
| 121 |
+
b = torch.logical_not(dropout_mask)
|
| 122 |
+
b = b * (alpha * a) + alpha * a * p
|
| 123 |
+
dropout_mask = a * dropout_mask
|
| 124 |
+
|
| 125 |
+
return self * dropout_mask + b
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _inplace_wrapper(fn):
|
| 129 |
+
"""
|
| 130 |
+
Given a nn.functional non-linearity, implements its `inplace: bool` argument
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
# nb. We use the name of the first argument used in the unary references
|
| 134 |
+
@wraps(fn)
|
| 135 |
+
def _fn(a, *args, inplace=False, **kwargs):
|
| 136 |
+
if inplace:
|
| 137 |
+
torch._check(
|
| 138 |
+
"out" not in kwargs,
|
| 139 |
+
lambda: "Cannot set inplace=True and pass out= at the same time",
|
| 140 |
+
)
|
| 141 |
+
return fn(a, *args, inplace=False, out=a, **kwargs)
|
| 142 |
+
else:
|
| 143 |
+
return fn(a, *args, inplace=False, **kwargs)
|
| 144 |
+
|
| 145 |
+
return _fn
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# celu is implemented specially because it has an alpha argument
# celu is very similar to elu
@register_decomposition(aten.celu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def celu(
    a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.celu
    """
    # `inplace` is serviced by the _inplace_wrapper decorator; reaching this
    # body with inplace=True is unsupported.
    if inplace:
        raise NotImplementedError

    # celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    rhs: TensorLikeType
    if alpha is not None:
        # alpha must be representable in the tensor's element type.
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        rhs = alpha * torch.expm1(torch.true_divide(a, alpha))  # type: ignore[arg-type]
    else:
        # alpha defaults to 1, for which the expression simplifies.
        rhs = torch.expm1(a)

    return torch.where(a > 0, a, rhs)
@_inplace_wrapper
@out_wrapper()
def dropout(
    a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.dropout.

    During training, zeroes elements of ``a`` with probability ``p`` and
    scales the survivors by ``1 / (1 - p)`` (inverted dropout) so the
    expected value is preserved.
    """
    # `inplace` is serviced by the _inplace_wrapper decorator.
    if inplace:
        raise NotImplementedError

    # Identity in eval mode.
    if not training:
        return a

    torch._check(
        p <= 1 and p >= 0,
        lambda: f"dropout probability has to be between 0 and 1, but got, {p}",
    )

    # Degenerate probabilities avoid sampling a mask at all.
    if p == 1:
        return torch.zeros_like(a)

    if p == 0:
        return a

    scale = 1 / (1 - p)
    dropout_mask = _dropout_helper(a, 1 - p)

    return a * dropout_mask * scale
@register_decomposition(aten.elu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def elu(
    a: TensorLikeType,
    alpha: NumberType = 1.0,
    scale: NumberType = 1.0,
    input_scale: NumberType = 1.0,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.elu
    """
    # `inplace` is serviced by the _inplace_wrapper decorator.
    if inplace:
        raise NotImplementedError

    # nb. This should be factored out into a can_cast aux function
    # Each scalar argument must be representable in the tensor's element type.
    python_type = utils.dtype_to_type(a.dtype)
    torch._check(
        utils.is_weakly_lesser_type(type(input_scale), python_type),
        lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!",
    )
    torch._check(
        utils.is_weakly_lesser_type(type(scale), python_type),
        lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!",
    )
    torch._check(
        utils.is_weakly_lesser_type(type(alpha), python_type),
        lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
    )

    # elu(x) = scale * x                                 for x > 0
    #        = scale * alpha * (exp(x * input_scale) - 1) otherwise
    return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale))
@register_decomposition(aten.relu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.relu
    """
    # `inplace` is serviced by the _inplace_wrapper decorator.
    if inplace:
        raise NotImplementedError

    # relu(x) = 0 where x <= 0, x otherwise.
    return torch.where(torch.le(a, 0), 0, a)
def group_norm(
|
| 265 |
+
input: Tensor,
|
| 266 |
+
num_groups: int,
|
| 267 |
+
weight: Optional[Tensor] = None,
|
| 268 |
+
bias: Optional[Tensor] = None,
|
| 269 |
+
eps: float = 1e-5,
|
| 270 |
+
) -> Tensor:
|
| 271 |
+
"""
|
| 272 |
+
Reference implementation of :func:`torch.nn.functional.group_norm`.
|
| 273 |
+
"""
|
| 274 |
+
torch._check(
|
| 275 |
+
input.ndim >= 2,
|
| 276 |
+
lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
batch_size = input.shape[0]
|
| 280 |
+
num_channels = input.shape[1]
|
| 281 |
+
torch._check(
|
| 282 |
+
num_channels % num_groups == 0,
|
| 283 |
+
lambda: "Expected number of channels in input to be divisible by num_groups, "
|
| 284 |
+
+ f"but got input of shape {input.shape} and num_groups = {num_groups}",
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# input shape is (N, C, *), so we flatten all inner dimensions except (N, C)
|
| 288 |
+
flattened_inner_size = 1
|
| 289 |
+
for dim_length in input.shape[2:]:
|
| 290 |
+
flattened_inner_size *= dim_length
|
| 291 |
+
|
| 292 |
+
return torch.native_group_norm(
|
| 293 |
+
input,
|
| 294 |
+
weight,
|
| 295 |
+
bias,
|
| 296 |
+
batch_size,
|
| 297 |
+
num_channels,
|
| 298 |
+
flattened_inner_size,
|
| 299 |
+
num_groups,
|
| 300 |
+
eps,
|
| 301 |
+
)[0]
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def layer_norm(
    input: Tensor,
    normalized_shape: ShapeType,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
) -> Tensor:
    """
    Reference implementation of :func:`torch.nn.functional.layer_norm`.
    """
    # native_layer_norm returns (output, mean, rstd); callers of this op
    # only receive the normalized output.
    out, _mean, _rstd = torch.native_layer_norm(
        input, normalized_shape, weight, bias, eps
    )
    return out
@register_decomposition(aten.leaky_relu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def leaky_relu(
    a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.leaky_relu
    """
    # `inplace` is serviced by the _inplace_wrapper decorator.
    if inplace:
        raise NotImplementedError

    # negative_slope must be representable in the tensor's element type.
    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(negative_slope), python_type):
        msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!"
        raise ValueError(msg)
    # leaky_relu(x) = x for x > 0, negative_slope * x otherwise.
    return torch.where(torch.gt(a, 0), a, torch.mul(a, negative_slope))
@register_decomposition(aten.mish)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.mish
    """
    # `inplace` is serviced by the _inplace_wrapper decorator.
    if inplace:
        raise NotImplementedError
    # mish(x) = x * tanh(softplus(x))
    return a * torch.tanh(torch.nn.functional.softplus(a))
@register_decomposition(aten.selu)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.selu
    """
    # `inplace` is serviced by the _inplace_wrapper decorator.
    if inplace:
        raise NotImplementedError

    # Fixed SELU constants (Klambauer et al., Self-Normalizing Neural Networks).
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    # selu(x) = scale * (x for x > 0, alpha * (exp(x) - 1) otherwise)
    rhs = alpha * torch.expm1(a)

    return scale * torch.where(a > 0, a, rhs)
# Forwarding alias: the functional variant doesn't support the out kwarg
|
| 381 |
+
# CompositeImplicitAutograd - don't register decomp
|
| 382 |
+
def softmax(
|
| 383 |
+
a: TensorLikeType,
|
| 384 |
+
dim: Optional[int] = None,
|
| 385 |
+
_stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True)
|
| 386 |
+
dtype: Optional[torch.dtype] = None,
|
| 387 |
+
) -> TensorLikeType:
|
| 388 |
+
# The error is for compat with regular PyTorch, which has this behavior
|
| 389 |
+
# deprecated. For PrimTorch, it's fine to drop support for deprecated
|
| 390 |
+
# behavior because it requires explicit opt in. This error is to inform
|
| 391 |
+
# users how to update their calls.
|
| 392 |
+
torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
|
| 393 |
+
return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# CompositeImplicitAutograd - don't register decomp
|
| 397 |
+
def softmin(
|
| 398 |
+
a: TensorLikeType,
|
| 399 |
+
dim: Optional[int] = None,
|
| 400 |
+
_stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True)
|
| 401 |
+
dtype: Optional[torch.dtype] = None,
|
| 402 |
+
) -> TensorLikeType:
|
| 403 |
+
# The error is for compat with regular PyTorch, which has this behavior
|
| 404 |
+
# deprecated. For PrimTorch, it's fine to drop support for deprecated
|
| 405 |
+
# behavior because it requires explicit opt in. This error is to inform
|
| 406 |
+
# users how to update their calls.
|
| 407 |
+
torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
|
| 408 |
+
return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload]
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
# softplus is implemented specially because it has beta and threshold arguments
@register_decomposition(aten.softplus)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def softplus(
    a: TensorLikeType,
    beta: Optional[NumberType] = None,
    threshold: NumberType = 20,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.softplus
    """
    # `inplace` is serviced by the _inplace_wrapper decorator.
    if inplace:
        raise NotImplementedError

    # softplus(x) = (1 / beta) * log(1 + exp(beta * x)); above `threshold`
    # the function is numerically linear, so the input is passed through.
    rhs: TensorLikeType
    if beta is not None:
        # beta must be representable in the tensor's element type.
        python_type = utils.dtype_to_type(a.dtype)
        if not utils.is_weakly_lesser_type(type(beta), python_type):
            msg = f"beta argument of type {type(beta)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        scaled_input = a * beta
        rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta)  # type: ignore[arg-type]

    else:
        # beta defaults to 1, for which the expression simplifies.
        scaled_input = a
        rhs = torch.log1p(torch.exp(scaled_input))

    return torch.where(scaled_input > threshold, a, rhs)
@aten.hardshrink.default.py_impl(DispatchKey.Autograd)
@register_decomposition(aten.hardshrink)
@out_wrapper()
def hardshrink(a: TensorLikeType, lambd: float = 0.5):
    """Reference implementation of torch.nn.functional.hardshrink."""
    # Formula for reference,
    # hardshrink(x) = x if x > lambd
    #               = x if x < -lambd
    #               = 0 otherwise
    return torch.where(torch.abs(a) <= lambd, 0, a)
@aten.softshrink.default.py_impl(DispatchKey.Autograd)
@register_decomposition(aten.softshrink)
@out_wrapper()
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    """Reference implementation of torch.nn.functional.softshrink."""
    # Formula for reference,
    # softshrink(x) = x - lambd if x > lambd
    #               = x + lambd if x < -lambd
    #               = 0 otherwise
    torch._check(
        lambd >= 0,
        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
    )
    # We implement this in one torch.where to generate better code in the backward
    # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211
    return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, 0)
# Losses
|
| 477 |
+
def _reduction_int_to_str(reduction: int) -> str:
|
| 478 |
+
from torch._decomp.decompositions import Reduction
|
| 479 |
+
|
| 480 |
+
if reduction == Reduction.NONE.value:
|
| 481 |
+
return "none"
|
| 482 |
+
elif reduction == Reduction.MEAN.value:
|
| 483 |
+
return "mean"
|
| 484 |
+
elif reduction == Reduction.SUM.value:
|
| 485 |
+
return "sum"
|
| 486 |
+
else:
|
| 487 |
+
raise ValueError(f"{reduction} is not a valid value for reduction")
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def _apply_loss_reduction(loss: TensorLikeType, reduction: str) -> TensorLikeType:
|
| 491 |
+
if reduction == "sum":
|
| 492 |
+
return torch.sum(loss)
|
| 493 |
+
elif reduction == "mean":
|
| 494 |
+
return torch.mean(loss)
|
| 495 |
+
else: # reduction == "none"
|
| 496 |
+
return loss
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def _check_reduction_value(reduction: str):
|
| 500 |
+
if reduction not in ("mean", "sum", "none"):
|
| 501 |
+
raise ValueError(f"{reduction} is not a valid value for reduction")
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
# This helper function maps depreciated arguments, "size_average" and "reduce"
|
| 505 |
+
# to their corresponding "reduction" string argument
|
| 506 |
+
def _get_string_reduction_arg(
|
| 507 |
+
*, size_average: Optional[bool], reduce: Optional[bool]
|
| 508 |
+
) -> str:
|
| 509 |
+
if size_average is None:
|
| 510 |
+
size_average = True
|
| 511 |
+
if reduce is None:
|
| 512 |
+
reduce = True
|
| 513 |
+
if size_average and reduce:
|
| 514 |
+
ret = "mean"
|
| 515 |
+
elif reduce:
|
| 516 |
+
ret = "sum"
|
| 517 |
+
else:
|
| 518 |
+
ret = "none"
|
| 519 |
+
return ret
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
# CompositeImplicitAutograd - don't register decomp
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def l1_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.l1_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    # Elementwise absolute error, then the requested reduction.
    loss = torch.abs(input - target)
    return _apply_loss_reduction(loss, reduction)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def smooth_l1_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
    beta: float = 1.0,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.smooth_l1_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)

    if beta == 0.0:
        # With beta == 0 the quadratic zone vanishes and the loss degenerates
        # to plain L1.
        return torch.nn.functional.l1_loss(
            input, target, size_average=size_average, reduce=reduce, reduction=reduction
        )
    else:
        # Quadratic for |diff| < beta, linear (offset by 0.5 * beta) beyond.
        loss = torch.abs(input - target)
        loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
        return _apply_loss_reduction(loss, reduction)
# Forwarding alias: the functional variant doesn't support the out kwarg
|
| 580 |
+
# CompositeImplicitAutograd - don't register decomp
|
| 581 |
+
def log_softmax(
|
| 582 |
+
a: TensorLikeType,
|
| 583 |
+
dim: Optional[int] = None,
|
| 584 |
+
_stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True)
|
| 585 |
+
dtype: Optional[torch.dtype] = None,
|
| 586 |
+
) -> TensorLikeType:
|
| 587 |
+
# The error is for compat with regular PyTorch, which has this behavior
|
| 588 |
+
# deprecated. For PrimTorch, it's fine to drop support for deprecated
|
| 589 |
+
# behavior because it requires explicit opt in. This error is to inform
|
| 590 |
+
# users how to update their calls.
|
| 591 |
+
torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X")
|
| 592 |
+
return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
@register_decomposition(aten.margin_ranking_loss)
def margin_ranking_loss(
    input1: TensorLikeType,
    input2: TensorLikeType,
    target: TensorLikeType,
    margin: float = 0.0,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.margin_ranking_loss
    """
    # loss_without_reduction = max(0, −target * (input1 − input2) + margin)
    if input1.ndim != input2.ndim or input1.ndim != target.ndim:
        raise RuntimeError(
            "margin_ranking_loss : All input tensors should have same dimension but got sizes: "
            f"input1: {input1.shape}, input2: {input2.shape}, target: {target.shape} "
        )
    _check_reduction_value(reduction)
    # clamp_min implements the max(0, ...) hinge.
    loss = torch.clamp_min(-target * (input1 - input2) + margin, 0)
    return _apply_loss_reduction(loss, reduction)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def mse_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.mse_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    # Elementwise squared error, then the requested reduction.
    loss = torch.pow(input - target, 2)
    return _apply_loss_reduction(loss, reduction)
@register_decomposition(aten.hinge_embedding_loss)
def hinge_embedding_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    margin: float = 1.0,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.hinge_embedding_loss
    """
    # loss_without_reduction = input                     if y == 1
    #                        = max(0, margin - input)    if y == -1
    _check_reduction_value(reduction)
    # Branchless construction: each where() zeroes the term that does not
    # apply for the given target label, and the two terms are summed.
    margin_clamp = torch.clamp_min(margin - input, 0)
    output_margin = torch.where(target != 1, margin_clamp, 0)
    output_self = torch.where(target != -1, input, 0)
    loss = output_margin + output_self
    return _apply_loss_reduction(loss, reduction)
def _nll_loss_nd(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType],
    reduction: str,
    ignore_index: int,
) -> TensorLikeType:
    """
    Shared implementation of nll_loss for inputs of rank 1, 2 or 3.

    ``input`` holds per-class values indexed by ``target``; entries whose
    target equals ``ignore_index`` contribute zero weight to the loss and
    to the mean's denominator.
    """
    torch._check(
        input.ndim > 0 and input.ndim <= 3,
        lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.",
    )

    torch._check(
        (input.ndim == 1) or (input.shape[0] == target.shape[0]),
        lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.",
    )

    _check_reduction_value(reduction)

    # Flattened view of target; marks which entries are to be ignored.
    flat_target = torch.flatten(target)
    ignore_classes_mask = torch.eq(flat_target, ignore_index)

    # TODO: Enable data-dependent checks with debug mode
    # TODO: This check does not work with FakeTensor inputs; See Issue #85834
    # Explicit cast for class_check to bool; See Issue #78071
    """
    from torch._subclasses.fake_tensor import FakeTensor
    num_classes = input.shape[1] if input.ndim > 1 else input.shape[0]
    valid_classes_mask = torch.logical_and(
        (flat_target >= 0), (flat_target < num_classes)
    )
    class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask))
    torch._check(
        isinstance(target, FakeTensor) or bool(class_check.item()),
        lambda: "A target class is out-of-bounds and not the ignore index.",
    )
    """

    # Per-element weight: 0 for ignored targets, otherwise the per-class
    # weight (or 1 when no weight tensor was given).
    ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device)
    class_weight = (
        torch.scalar_tensor(1, dtype=input.dtype, device=input.device)
        if weight is None
        else weight[flat_target]
    )
    current_weight = torch.where(
        ignore_classes_mask,
        ignore_class_weight,
        class_weight,
    )

    if input.ndim == 1:
        # implicit batch size = 1
        # input (1 batch size, C classes)
        loss = -input[target] * current_weight
    elif input.ndim == 2:
        # input (N batch size, C classes)
        batch_size = input.shape[0]
        loss = -input[torch.arange(batch_size), target] * current_weight
    else:
        # 3D case (N batch size, C classe, K dimensions)
        # input (N batch size, C classes, K)
        # Build flat (batch, k) index pairs so a single advanced-indexing
        # gather selects every target entry at once.
        batch_size = input.shape[0]
        extent = input.shape[2]
        numel = batch_size * extent
        indices = torch.arange(numel)
        bdx = indices // extent
        kdx = indices % extent
        loss = -input[bdx, flat_target, kdx] * current_weight
        loss = torch.reshape(loss, target.shape)

    if reduction == "none":
        return loss
    elif reduction == "sum":
        return torch.sum(loss)
    else:
        # calculate weighted mean of the loss function
        return torch.sum(loss) / torch.sum(current_weight)
@register_decomposition(aten.nll_loss)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    weight: Optional[TensorLikeType] = None,
    size_average: Optional[bool] = None,
    ignore_index: int = -100,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.nll_loss
    """
    torch._check(
        input.ndim > 0,
        lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})",
    )

    # TODO: raise exception instead of converting value
    # msg = "size_average and reduce args are deprecated, please use reduction argument."
    # Convert these options for consistency with the eager mode
    if size_average is not None or reduce is not None:
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

    # The expected behavior when the target and input have zero elements:
    #   reduction = 'none' --- tensor([])
    #   reduction = 'sum'  --- tensor(0.)
    #   reduction = 'mean' --- tensor(nan)
    # Mean reduction on empty tensors produces NaN. See the discussion in
    # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if input.numel() == 0 and target.numel() == 0:
        if reduction == "none":
            return torch.zeros_like(target)
        elif reduction == "sum":
            return torch.empty_like(target)
        else:
            return torch.full_like(target, float("nan"))

    # The _nll_loss_nd helper function handles the most common cases.
    # ndim == 1 (Single Example)
    #   => Batch Size: 1, Input: (C), Target: ()
    # ndim == 2 (k = 1)
    #   => Batch Size: N, Input: (N, C), Target: (N)
    # ndim == 3 (k > 1)
    #   => Batch Size: N, Input: (N, C, K), Target: (N, K)
    if input.ndim <= 3:
        return _nll_loss_nd(input, target, weight, reduction, ignore_index)

    # For ndim > 3, we reshape the input and target to 3-D case.
    # Input (N batch-size, C classes, k-dimensions)
    # Target (N batch-size, k-dimensions)
    torch._check(
        input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:],
        lambda: (
            "Expected input and target to both have ndim > 0 and "
            "target.shape[1:] == input.shape[2:], but got "
            f"target.shape {target.shape} and input.shape {input.shape}"
        ),
    )

    batch_size = input.shape[0]
    num_classes = input.shape[1]
    # Shape the unreduced result must have, saved before flattening.
    out_size = [batch_size] + list(target.shape[1:])

    input = torch.reshape(input, [batch_size, num_classes, -1])
    target = torch.reshape(target, [batch_size, -1])
    if reduction != "none":
        return _nll_loss_nd(input, target, weight, reduction, ignore_index)
    else:
        result = _nll_loss_nd(input, target, weight, reduction, ignore_index)
        # reshape flattened inner-dim to original k-dimensions
        return torch.reshape(result, out_size)
# TODO: This ref supports int reduction and out kwarg to be compatible with ATen:
# https://github.com/pytorch/pytorch/issues/83931
# TODO: Could be rewritten to support complex:
# https://github.com/pytorch/pytorch/pull/85041
@register_decomposition(aten.huber_loss)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def huber_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    reduction: Union[str, int] = "mean",
    delta: float = 1.0,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.huber_loss
    """
    # ATen passes the reduction as an enum int; normalize to a string first.
    if type(reduction) is int:
        reduction = _reduction_int_to_str(reduction)
    _check_reduction_value(reduction)  # type: ignore[arg-type]
    torch._check(
        delta > 0,
        lambda: "huber_loss does not support non-positive values for delta.",
    )
    # Quadratic where |input - target| < delta, linear outside.
    z = (input - target).abs()
    loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
    return _apply_loss_reduction(loss, reduction)  # type: ignore[arg-type]
# tanhshrink does not use _make_elementwise_unary_reference because it does not support out
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def tanhshrink(a: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.tanhshrink
    """
    if not isinstance(a, TensorLike):
        raise RuntimeError(
            "Expected a tensor input for an elementwise unary operation!"
        )
    # tanhshrink(x) = x - tanh(x)
    return a - torch.tanh(a)
@register_decomposition(aten.threshold)
@_inplace_wrapper
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def threshold(
    a: TensorLikeType,
    threshold: NumberType,
    value: Union[bool, int, float],
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.threshold
    """
    # `inplace` is serviced by the _inplace_wrapper decorator.
    if inplace:
        raise NotImplementedError

    # Elements at or below the threshold are replaced by `value`.
    return torch.where(a <= threshold, value, a)
# CompositeImplicitAutograd - don't register decomp
# No elementwise type promotion - core op doesn't explicitly type promote
def triplet_margin_loss(
    anchor: TensorLikeType,
    positive: TensorLikeType,
    negative: TensorLikeType,
    margin: float = 1.0,
    p: float = 2,
    eps: float = 1e-6,
    swap: bool = False,
    size_average: Optional[bool] = None,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """Reference implementation of torch.nn.functional.triplet_margin_loss."""
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)

    # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined
    # since it's a pure Python implementation. Use this helper instead.
    def pairwise(x, y):
        return torch.pairwise_distance(x, y, p, eps)

    return _triplet_margin_with_distance_loss(
        anchor=anchor,
        positive=positive,
        negative=negative,
        distance_function=pairwise,
        margin=margin,
        swap=swap,
        reduction=reduction,
    )
# Pure Python impl - don't register decomp and don't add a ref. Defined as a
# helper here since triplet_margin_loss can be nicely implemented with it.
def _triplet_margin_with_distance_loss(
    anchor: TensorLikeType,
    positive: TensorLikeType,
    negative: TensorLikeType,
    *,
    distance_function: Optional[
        Callable[[TensorLikeType, TensorLikeType], TensorLikeType]
    ] = None,
    margin: float = 1.0,
    swap: bool = False,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Shared implementation behind the triplet margin losses.

    Computes ``max(margin + d(anchor, positive) - d(anchor, negative), 0)``
    and applies the requested reduction.
    """
    _check_reduction_value(reduction)

    anchor_ndim = anchor.ndim
    positive_ndim = positive.ndim
    negative_ndim = negative.ndim
    torch._check(
        anchor_ndim == positive_ndim == negative_ndim,
        lambda: (
            f"The anchor, positive, and negative tensors are expected to have "
            f"the same number of dimensions, but got: anchor {anchor_ndim}D, "
            f"positive {positive_ndim}D, and negative {negative_ndim}D inputs"
        ),
    )

    if distance_function is None:
        distance_function = torch.pairwise_distance

    dist_pos = distance_function(anchor, positive)
    dist_neg = distance_function(anchor, negative)
    # The distance swap is described in the paper "Learning shallow
    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
    # Riba et al. If True, and if the positive example is closer to the
    # negative example than the anchor is, swaps the positive example and the
    # anchor in the loss computation.
    if swap:
        dist_neg = torch.minimum(dist_neg, distance_function(positive, negative))
    hinge = torch.clamp_min(margin + dist_pos - dist_neg, 0)
    return _apply_loss_reduction(hinge, reduction)
|
| 957 |
+
|
| 958 |
+
|
| 959 |
+
@register_decomposition(aten.hardtanh)
@_inplace_wrapper
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    # Fix: ("a") is just the string "a"; a one-element tuple is intended, as
    # in every other elementwise_type_promotion_wrapper call in this file.
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def hardtanh(
    a: TensorLikeType,
    min_val: NumberType = -1,
    max_val: NumberType = 1,
    inplace: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.hardtanh

    Clamps ``a`` elementwise into ``[min_val, max_val]``. Bool inputs are
    rejected, and integer inputs keep integer bounds (no type promotion from
    the boundary values).
    """
    if inplace:
        raise NotImplementedError
    if utils.is_boolean_dtype(a.dtype):
        raise RuntimeError("Bool inputs not supported for hardtanh")

    # preserve legacy behavior of boundaries not causing type promotion
    if utils.is_integer_dtype(a.dtype):
        min_val = int(min_val)  # type: ignore[arg-type]
        max_val = int(max_val)  # type: ignore[arg-type]
        if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)):
            raise RuntimeError(
                "Cannot do hardtanh on an unsigned type with negative limits"
            )
    return torch.clamp(a, min_val, max_val)  # type: ignore[arg-type]
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
@register_decomposition(aten.gelu)
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.gelu

    ``approximate`` selects the exact erf-based form ("none") or the tanh
    approximation ("tanh").
    """
    if not isinstance(a, TensorLike):
        raise RuntimeError(
            "Expected a tensor input for an elementwise unary operation!"
        )
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == "none":
        # Exact: 0.5 * a * (1 + erf(a / sqrt(2)))
        return a * 0.5 * (1 + torch.erf(a * M_SQRT1_2))
    if approximate == "tanh":
        beta = M_SQRT2 * M_2_SQRTPI * 0.5
        kappa = 0.044715
        a_cubed = a * a * a
        inner = beta * (a + kappa * a_cubed)
        return 0.5 * a * (1 + torch.tanh(inner))
    raise RuntimeError("approximate argument must be either none or tanh.")
|
| 1021 |
+
|
| 1022 |
+
|
| 1023 |
+
# CompositeImplicitAutograd - don't register decomp
@elementwise_type_promotion_wrapper(
    type_promoting_args=("input", "target"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def poisson_nll_loss(
    input: TensorLikeType,
    target: TensorLikeType,
    log_input: bool = True,
    full: bool = False,
    size_average: Optional[bool] = None,
    eps: float = 1e-8,
    reduce: Optional[bool] = None,
    reduction: str = "mean",
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.poisson_nll_loss
    """
    if size_average is not None or reduce is not None:
        # TODO: Raise exception instead of converting value. This is only for
        # primTorch since it can drop support for deprecated arguments.
        # msg = "size_average and reduce args are deprecated, please use reduction argument."
        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
    _check_reduction_value(reduction)
    loss = (
        torch.exp(input) - target * input
        if log_input
        else input - target * torch.log(input + eps)
    )

    if full:
        # Stirling approximation of log(target!); only applied where target > 1.
        stirling_term = (
            target * torch.log(target) - target + 0.5 * torch.log(2 * torch.pi * target)
        )
        # avoid inplace add
        loss = loss + stirling_term.masked_fill(target <= 1, 0)
    return _apply_loss_reduction(loss, reduction)
|
| 1059 |
+
|
| 1060 |
+
|
| 1061 |
+
@register_decomposition(aten.prelu)
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "weight"),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu

    Returns ``a`` where ``a > 0`` and ``a * weight`` elsewhere. ``weight`` is
    either a single shared slope or a 1D tensor with one slope per channel
    (dim 1 of ``a``, or dim 0 for a 1D input).
    """
    torch._check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    torch._check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        # Per-channel slopes: weight must match the channel dimension of `a`.
        torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        torch._check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    torch._check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    if a.ndim == 0:
        # 0-dim input: reduce a one-element 1D weight to a scalar.
        weight = weight[0] if weight.ndim == 1 else weight
    else:
        # Broadcast weight across `a`, aligning a 1D weight with the channel
        # dimension (dim 1 in general, dim 0 for a 1D input).
        weight = prims.broadcast_in_dim(
            weight, a.shape, tuple() if weight.ndim == 0 else (0 if a.ndim == 1 else 1,)
        )

    return torch.where(a > 0, a, a * weight)
|
| 1101 |
+
|
| 1102 |
+
|
| 1103 |
+
@register_decomposition(aten.relu6)
@_inplace_wrapper
@out_wrapper()
def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.relu6

    Equivalent to ``hardtanh(a, 0, 6)``.
    """
    if not inplace:
        # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126
        # clamp might look more natural, but hardtanh replicates the behavior
        # of the pre-existing implementation.
        return torch.nn.functional.hardtanh(a, 0, 6)
    raise NotImplementedError
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
@register_decomposition(aten.glu)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.glu

    Splits ``a`` in half along ``dim`` and returns ``first * sigmoid(second)``.
    """
    dim = utils.canonicalize_dims(a.ndim, dim)
    torch._check(
        a.shape[dim] % 2 == 0,
        lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}",
    )
    first_half, second_half = torch.tensor_split(a, 2, dim)
    return first_half * torch.sigmoid(second_half)
|
| 1134 |
+
|
| 1135 |
+
|
| 1136 |
+
@register_decomposition(aten.pairwise_distance)
@out_wrapper()
def pairwise_distance(
    x1: TensorLikeType,
    x2: TensorLikeType,
    p: NumberType = 2.0,
    eps: NumberType = 1e-6,
    keepdim=False,
) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.pairwise_distance

    Computes the p-norm of ``x1 - x2`` (with ``eps`` added for numerical
    stability) along the last dimension.
    """
    diff = x1 - x2 + eps
    return torch.linalg.vector_norm(diff, ord=p, dim=-1, keepdim=keepdim)
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
@register_decomposition(aten.pdist)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.pdist

    Returns the condensed (strict upper-triangular) vector of pairwise
    p-distances between the rows of the 2D tensor ``a``.
    """
    torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
    torch._check(p >= 0, lambda: "pdist only supports non-negative p values")
    # For p == 2 we can use an efficient Gram-matrix formulation; other values
    # of p require creating a much bigger intermediate tensor.
    if p == 2:
        gram = torch.mm(a, a.T)
        sq_norms = torch.diag(gram)
        dists = torch.sqrt(
            torch.clamp(sq_norms + sq_norms.unsqueeze(-1) - 2 * gram, min=0)
        )
    else:
        dists = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2)
    idx = torch.triu_indices(dists.shape[0], dists.shape[1], offset=1, device=a.device)
    return dists.flatten().index_select(0, idx[0] * dists.shape[0] + idx[1])
|
| 1167 |
+
|
| 1168 |
+
|
| 1169 |
+
@register_decomposition(aten.pixel_shuffle)
@out_wrapper()
def pixel_shuffle(self: Tensor, upscale_factor: int):
    """
    Reference decomposition of torch.pixel_shuffle.

    Rearranges a (*, C*r^2, H, W) tensor into (*, C, H*r, W*r) with
    r = upscale_factor, by viewing the channel dim as (C, r, r) and
    interleaving the r factors into the spatial dims.
    """
    torch._check(
        self.dim() >= 3,
        # Fix: call self.dim() — previously the bound method's repr (not the
        # dimension count) was interpolated into the message.
        lambda: f"pixel_shuffle expects input to have at least 3 dimensions, but got input with {self.dim()} dimension(s)",
    )
    batch = self.shape[:-3]
    C_out = self.shape[-3] // upscale_factor**2
    HW_out = (self.shape[-2] * upscale_factor, self.shape[-1] * upscale_factor)
    n = len(batch)
    B_dims = range(n)
    C_dim, r1_dim, r2_dim, H_dim, W_dim = range(n, n + 5)
    return (
        self.view(
            *batch,
            C_out,
            upscale_factor,
            upscale_factor,
            self.shape[-2],
            self.shape[-1],
        )
        .permute(*B_dims, C_dim, H_dim, r1_dim, W_dim, r2_dim)
        .reshape(*batch, C_out, *HW_out)
        .clone(memory_format=utils.suggest_memory_format(self))
    )
|
| 1195 |
+
|
| 1196 |
+
|
| 1197 |
+
@register_decomposition(aten.pixel_unshuffle)
@out_wrapper()
def pixel_unshuffle(self: Tensor, downscale_factor: int):
    """
    Reference decomposition of torch.pixel_unshuffle.

    Inverse of pixel_shuffle: rearranges a (*, C, H*r, W*r) tensor into
    (*, C*r^2, H, W) with r = downscale_factor.
    """
    torch._check(
        self.dim() >= 3,
        # Fix: call self.dim() — previously the bound method's repr (not the
        # dimension count) was interpolated into the message.
        lambda: f"pixel_unshuffle expects input to have at least 3 dimensions, but got input with {self.dim()} dimension(s)",
    )
    batch = self.shape[:-3]
    C_out = self.shape[-3] * downscale_factor**2
    HW_out = (self.shape[-2] // downscale_factor, self.shape[-1] // downscale_factor)
    n = len(batch)
    B_dims = range(n)
    C_dim, H_dim, r1_dim, W_dim, r2_dim = range(n, n + 5)
    return (
        self.view(
            *batch,
            self.shape[-3],
            HW_out[0],
            downscale_factor,
            HW_out[1],
            downscale_factor,
        )
        .permute(*B_dims, C_dim, r1_dim, r2_dim, H_dim, W_dim)
        .reshape(*batch, C_out, *HW_out)
        .clone(memory_format=utils.suggest_memory_format(self))
    )
|
| 1223 |
+
|
| 1224 |
+
|
| 1225 |
+
# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg)
# In-place variants of the refs defined above, generated via _make_inplace.
celu_ = _make_inplace(celu)
elu_ = _make_inplace(elu)
mish_ = _make_inplace(mish)
selu_ = _make_inplace(selu)
threshold_ = _make_inplace(threshold)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-311.pyc
ADDED
|
Binary file (40.4 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/_numeric_suite.py
ADDED
|
@@ -0,0 +1,526 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.ao.nn.quantized as nnq
|
| 4 |
+
import torch.ao.nn.quantized.dynamic as nnqd
|
| 5 |
+
from torch.ao.quantization import prepare
|
| 6 |
+
from typing import Dict, List, Optional, Any, Union, Callable, Set
|
| 7 |
+
|
| 8 |
+
from torch.ao.quantization.quantization_mappings import (
|
| 9 |
+
get_default_compare_output_module_list,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
# Module types that are allowed to receive observers even though they are not
# leaf modules. (NOTE(review): exact consumer of this allow-list is outside
# this chunk — presumably the numeric-suite prepare step; confirm.)
NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
    nnqd.Linear,
    nnq.Linear,
    nnqd.LSTM,
    nn.LSTM,
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _find_match(
|
| 21 |
+
str_list: Union[Dict[str, Any], List[str]], key_str: str,
|
| 22 |
+
postfix: str,
|
| 23 |
+
) -> Optional[str]:
|
| 24 |
+
split_str = key_str.split(".")
|
| 25 |
+
if split_str[-1] == postfix:
|
| 26 |
+
match_string = "".join(key_str.split(".")[0:-1])
|
| 27 |
+
for s2 in str_list:
|
| 28 |
+
pattern1 = "".join(s2.split(".")[0:-1])
|
| 29 |
+
pattern2 = "".join(s2.split(".")[0:-2])
|
| 30 |
+
if match_string == pattern1:
|
| 31 |
+
return s2
|
| 32 |
+
if match_string == pattern2:
|
| 33 |
+
return s2
|
| 34 |
+
|
| 35 |
+
# For matching "fc.weight" and "fc._packed_params._packed_params"
|
| 36 |
+
if postfix == "_packed_params":
|
| 37 |
+
match_string = "".join(key_str.split(".")[0:-2])
|
| 38 |
+
if len(match_string) == 0:
|
| 39 |
+
return None
|
| 40 |
+
for s2 in str_list:
|
| 41 |
+
pattern1 = "".join(s2.split(".")[0:-1])
|
| 42 |
+
pattern2 = "".join(s2.split(".")[0:-2])
|
| 43 |
+
if match_string == pattern1:
|
| 44 |
+
return s2
|
| 45 |
+
if match_string == pattern2:
|
| 46 |
+
return s2
|
| 47 |
+
return None
|
| 48 |
+
else:
|
| 49 |
+
return None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def compare_weights(
    float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
) -> Dict[str, Dict[str, torch.Tensor]]:
    r"""Compare the weights of the float module with its corresponding quantized
    module. Return a dict with key corresponding to module names and each entry being
    a dictionary with two keys 'float' and 'quantized', containing the float and
    quantized weights. This dict can be used to compare and compute the quantization
    error of the weights of float and quantized models.

    Example usage::

        wt_compare_dict = compare_weights(
            float_model.state_dict(), qmodel.state_dict())
        for key in wt_compare_dict:
            print(
                key,
                compute_error(
                    wt_compare_dict[key]['float'],
                    wt_compare_dict[key]['quantized'].dequantize()
                )
            )

    Args:
        float_dict: state dict of the float model
        quantized_dict: state dict of the quantized model

    Return:
        weight_dict: dict with key corresponding to module names and each entry being
        a dictionary with two keys 'float' and 'quantized', containing the float and
        quantized weights
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
    weight_dict: Dict[str, Dict] = {}
    for key in quantized_dict:
        # Case 1: direct name match on a ".weight" entry.
        match_key = _find_match(float_dict, key, "weight")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            weight_dict[key]["quantized"] = quantized_dict[key]
            continue

        # For matching "fc.weight" and "fc._packed_params._packed_params"
        match_key = _find_match(float_dict, key, "_packed_params")
        if match_key is not None:
            weight_dict[key] = {}
            weight_dict[key]["float"] = float_dict[match_key]
            # Packed params store the weight as the first element.
            weight_dict[key]["quantized"] = quantized_dict[key][0]

        # For LSTM
        split_str = key.split(".")
        if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
            layer = split_str[-2]
            module_name = ".".join(split_str[:-3])
            float_weight_ih_key = module_name + ".weight_ih_l" + layer
            float_weight_hh_key = module_name + ".weight_hh_l" + layer
            if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
                weight_dict[key] = {}
                # __getstate__ digs the packed ih/hh weights out of the
                # quantized LSTM's CellParams (internal representation).
                weight_dict[key]["float"] = float_dict[float_weight_ih_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
                )
                # NOTE(review): the ih entries assigned above are immediately
                # overwritten by the hh entries below, so only hh survives —
                # confirm whether both were meant to be kept under separate keys.
                weight_dict[key]["float"] = float_dict[float_weight_hh_key]
                weight_dict[key]["quantized"] = (
                    quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
                )

    return weight_dict
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _get_logger_dict_helper(
    mod: nn.Module, target_dict: Dict[str, Any],
    prefix: str = "",
) -> None:
    r"""Recursively collect logger stats from ``mod`` into ``target_dict``.

    Args:
        mod: module we want to save all logger stats
        target_dict: the dictionary used to save all logger stats
        prefix: prefix for the current module
    """
    dotted = prefix if prefix == "" else prefix + "."

    # Record the stats of the first Logger child (if any) under this prefix.
    for _, child in mod.named_children():
        if isinstance(child, Logger):
            target_dict[dotted + "stats"] = child.stats
            break

    # Recurse into every child with an extended prefix.
    for name, child in mod.named_children():
        child_prefix = dotted + name if prefix else name
        _get_logger_dict_helper(child, target_dict, child_prefix)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
    r"""Traverse the modules and save all logger stats into target dict.
    This is mainly used for quantization accuracy debug.

    Type of loggers supported:
        ShadowLogger: used to log the outputs of the quantized module and its matching float shadow module,
        OutputLogger: used to log the outputs of the modules

    Args:
        mod: module we want to save all logger stats
        prefix: prefix for the current module

    Return:
        target_dict: the dictionary used to save all logger stats

    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")

    collected: Dict[str, Dict] = {}
    _get_logger_dict_helper(mod, collected, prefix)
    return collected
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class Logger(nn.Module):
    r"""Base class for stats logging
    """

    def __init__(self):
        super().__init__()
        # Accumulated statistics; subclasses decide the layout.
        self.stats = dict()
        # We only insert observer if the op is quantized with static quantization,
        # which is identified by activation_observer.dtype == quint8. This is needed
        # when attaching Logger as observer for FX mode
        self.dtype = torch.quint8

    def forward(self, x):
        """
        """  # blank docblock to make autodoc happy
        # Base class records nothing; subclasses override.
        pass
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class ShadowLogger(Logger):
    r"""Class used in Shadow module to record the outputs of the original and
    shadow modules.
    """

    def __init__(self):
        super().__init__()
        self.stats.update({"float": [], "quantized": []})

    def forward(self, x, y):
        """
        """  # blank docblock to make autodoc happy
        # Multi-output ops pass tuples/sequences; keep only the first entry.
        if len(x) > 1:
            x = x[0]
        if len(y) > 1:
            y = y[0]
        self.stats["quantized"].append(x.detach())
        self.stats["float"].append(y.detach())
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class OutputLogger(Logger):
    r"""Class used to log the outputs of the module
    """

    def __init__(self):
        super().__init__()
        self.stats.update({"tensor_val": []})

    def forward(self, x):
        """
        """  # blank docblock to make autodoc happy
        # Record the raw output and pass it through unchanged.
        self.stats["tensor_val"].append(x)
        return x
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def _convert_tuple_to_list(t: Any) -> Any:
|
| 225 |
+
return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _dequantize_tensor_list(t: Any) -> Any:
|
| 229 |
+
return (
|
| 230 |
+
[_dequantize_tensor_list(x) for x in t]
|
| 231 |
+
if type(t) is list
|
| 232 |
+
else t.dequantize()
|
| 233 |
+
if t.is_quantized
|
| 234 |
+
else t
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class Shadow(nn.Module):
    r"""Shadow module attaches the float module to its matching quantized module
    as the shadow. Then it uses Logger module to process the outputs of both
    modules.

    Args:
        q_module: module quantized from float_module that we want to shadow
        float_module: float module used to shadow q_module
        logger_cls: type of logger used to process the outputs of q_module and
            float_module. ShadowLogger or custom loggers can be used.
    """

    def __init__(self, q_module, float_module, logger_cls):
        super().__init__()
        self.orig_module = q_module
        self.shadow_module = float_module
        self.dequant = nnq.DeQuantize()
        self.logger = logger_cls()

    def forward(self, *x) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        # Run the quantized module on the original inputs, then replay the
        # dequantized inputs through the float shadow and log both outputs.
        args = _convert_tuple_to_list(x)
        q_out = self.orig_module(*args)
        float_args = _dequantize_tensor_list(args)
        shadow_out = self.shadow_module(*float_args)
        self.logger(q_out, shadow_out)
        return q_out

    def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        q_out = self.orig_module.add(x, y)
        shadow_out = self.shadow_module.add(x.dequantize(), y.dequantize())
        self.logger(q_out, shadow_out)
        return q_out

    def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        q_out = self.orig_module.add_scalar(x, y)
        shadow_out = self.shadow_module.add_scalar(x.dequantize(), y)
        self.logger(q_out, shadow_out)
        return q_out

    def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        q_out = self.orig_module.mul(x, y)
        shadow_out = self.shadow_module.mul(x.dequantize(), y.dequantize())
        self.logger(q_out, shadow_out)
        return q_out

    def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        q_out = self.orig_module.mul_scalar(x, y)
        shadow_out = self.shadow_module.mul_scalar(x.dequantize(), y)
        self.logger(q_out, shadow_out)
        return q_out

    def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        q_out = self.orig_module.cat(x, dim)
        shadow_out = self.shadow_module.cat([t.dequantize() for t in x], dim)
        self.logger(q_out, shadow_out)
        return q_out

    def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        """  # blank docblock to make autodoc happy
        q_out = self.orig_module.add_relu(x, y)
        shadow_out = self.shadow_module.add_relu(x.dequantize(), y.dequantize())
        self.logger(q_out, shadow_out)
        return q_out
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def prepare_model_with_stubs(
    float_module: nn.Module, q_module: nn.Module,
    module_swap_list: Set[type], logger_cls: Callable,
) -> None:
    r"""Prepare the model by attaching the float module to its matching quantized
    module as the shadow if the float module type is in module_swap_list.

    Example usage::

        prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
        q_model(data)
        ob_dict = get_logger_dict(q_model)

    Args:
        float_module: float module used to generate the q_module
        q_module: module quantized from float_module
        module_swap_list: list of float module types to attach the shadow
        logger_cls: type of logger to be used in shadow module to process the outputs of
            quantized module and its float shadow module
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_with_stubs")

    float_module_children = {}
    for name, mod in float_module.named_children():
        float_module_children[name] = mod

    # Collect swaps first and apply them after iteration, so named_children()
    # is not mutated while being traversed.
    reassign = {}
    for name, mod in q_module.named_children():

        # Only children that exist in both models can be shadowed.
        if name not in float_module_children:
            continue

        float_mod = float_module_children[name]

        # Recurse into children whose float type is not being swapped.
        if type(float_mod) not in module_swap_list:
            prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)

        # Insert shadow module only if the module is not of the same type as
        # the floating point module
        if type(float_mod) in module_swap_list and not _is_identical_module_type(mod, float_mod):
            reassign[name] = Shadow(mod, float_mod, logger_cls)

    for key, value in reassign.items():
        q_module._modules[key] = value
|
| 369 |
+
|
| 370 |
+
def _is_identical_module_type(mod1, mod2):
|
| 371 |
+
# Compare if two modules have the same dtype
|
| 372 |
+
mod1_module_types = [type(mod) for mod in mod1.modules()]
|
| 373 |
+
mod2_module_types = [type(mod) for mod in mod2.modules()]
|
| 374 |
+
return mod1_module_types == mod2_module_types
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def compare_model_stub(
|
| 379 |
+
float_model: nn.Module, q_model: nn.Module, module_swap_list: Set[type],
|
| 380 |
+
*data, logger_cls=ShadowLogger
|
| 381 |
+
) -> Dict[str, Dict]:
|
| 382 |
+
r"""Compare quantized module in a model with its floating point counterpart,
|
| 383 |
+
feeding both of them the same input. Return a dict with key corresponding to
|
| 384 |
+
module names and each entry being a dictionary with two keys 'float' and
|
| 385 |
+
'quantized', containing the output tensors of quantized and its matching
|
| 386 |
+
float shadow module. This dict can be used to compare and compute the module
|
| 387 |
+
level quantization error.
|
| 388 |
+
|
| 389 |
+
This function first call prepare_model_with_stubs() to swap the quantized
|
| 390 |
+
module that we want to compare with the Shadow module, which takes quantized
|
| 391 |
+
module, corresponding float module and logger as input, and creates a forward
|
| 392 |
+
path inside to make the float module to shadow quantized module sharing the
|
| 393 |
+
same input. The logger can be customizable, default logger is ShadowLogger
|
| 394 |
+
and it will save the outputs of the quantized module and float module that
|
| 395 |
+
can be used to compute the module level quantization error.
|
| 396 |
+
|
| 397 |
+
Example usage::
|
| 398 |
+
|
| 399 |
+
module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
|
| 400 |
+
ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
|
| 401 |
+
for key in ob_dict:
|
| 402 |
+
print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
|
| 403 |
+
|
| 404 |
+
Args:
|
| 405 |
+
float_model: float model used to generate the q_model
|
| 406 |
+
q_model: model quantized from float_model
|
| 407 |
+
module_swap_list: list of float module types at which shadow modules will
|
| 408 |
+
be attached.
|
| 409 |
+
data: input data used to run the prepared q_model
|
| 410 |
+
logger_cls: type of logger to be used in shadow module to process the outputs of
|
| 411 |
+
quantized module and its float shadow module
|
| 412 |
+
"""
|
| 413 |
+
torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
|
| 414 |
+
prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
|
| 415 |
+
q_model(*data)
|
| 416 |
+
ob_dict = get_logger_dict(q_model)
|
| 417 |
+
return ob_dict
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def get_matching_activations(
|
| 421 |
+
float_module: nn.Module, q_module: nn.Module,
|
| 422 |
+
) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 423 |
+
r"""Find the matching activation between float and quantized modules.
|
| 424 |
+
|
| 425 |
+
Args:
|
| 426 |
+
float_module: float module used to generate the q_module
|
| 427 |
+
q_module: module quantized from float_module
|
| 428 |
+
|
| 429 |
+
Return:
|
| 430 |
+
act_dict: dict with key corresponding to quantized module names and each
|
| 431 |
+
entry being a dictionary with two keys 'float' and 'quantized', containing
|
| 432 |
+
the matching float and quantized activations
|
| 433 |
+
"""
|
| 434 |
+
torch._C._log_api_usage_once("quantization_api._numeric_suite.get_matching_activations")
|
| 435 |
+
float_dict = get_logger_dict(float_module)
|
| 436 |
+
quantized_dict = get_logger_dict(q_module)
|
| 437 |
+
act_dict: Dict[str, Dict] = {}
|
| 438 |
+
for key in quantized_dict:
|
| 439 |
+
if len(quantized_dict[key]["tensor_val"]) == 0:
|
| 440 |
+
continue
|
| 441 |
+
match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
|
| 442 |
+
if match_key is not None:
|
| 443 |
+
act_dict[key] = {}
|
| 444 |
+
act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
|
| 445 |
+
act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
|
| 446 |
+
return act_dict
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def prepare_model_outputs(
|
| 450 |
+
float_module: nn.Module,
|
| 451 |
+
q_module: nn.Module,
|
| 452 |
+
logger_cls=OutputLogger,
|
| 453 |
+
allow_list=None
|
| 454 |
+
) -> None:
|
| 455 |
+
r"""Prepare the model by attaching the logger to both float module
|
| 456 |
+
and quantized module if they are in the allow_list.
|
| 457 |
+
|
| 458 |
+
Args:
|
| 459 |
+
float_module: float module used to generate the q_module
|
| 460 |
+
q_module: module quantized from float_module
|
| 461 |
+
logger_cls: type of logger to be attached to float_module and q_module
|
| 462 |
+
allow_list: list of module types to attach logger
|
| 463 |
+
"""
|
| 464 |
+
torch._C._log_api_usage_once("quantization_api._numeric_suite.prepare_model_outputs")
|
| 465 |
+
if allow_list is None:
|
| 466 |
+
allow_list = get_default_compare_output_module_list()
|
| 467 |
+
|
| 468 |
+
qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None)
|
| 469 |
+
float_module.qconfig = qconfig_debug # type: ignore[assignment]
|
| 470 |
+
prepare(float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={})
|
| 471 |
+
q_module.qconfig = qconfig_debug # type: ignore[assignment]
|
| 472 |
+
prepare(
|
| 473 |
+
q_module,
|
| 474 |
+
inplace=True,
|
| 475 |
+
allow_list=allow_list,
|
| 476 |
+
observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
|
| 477 |
+
prepare_custom_config_dict={}
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def compare_model_outputs(
|
| 482 |
+
float_model: nn.Module,
|
| 483 |
+
q_model: nn.Module,
|
| 484 |
+
*data,
|
| 485 |
+
logger_cls=OutputLogger,
|
| 486 |
+
allow_list=None
|
| 487 |
+
) -> Dict[str, Dict[str, torch.Tensor]]:
|
| 488 |
+
r"""Compare output activations between float and quantized models at
|
| 489 |
+
corresponding locations for the same input. Return a dict with key corresponding
|
| 490 |
+
to quantized module names and each entry being a dictionary with two keys
|
| 491 |
+
'float' and 'quantized', containing the activations of quantized model and
|
| 492 |
+
float model at matching locations. This dict can be used to compare and
|
| 493 |
+
compute the propagation quantization error.
|
| 494 |
+
|
| 495 |
+
Example usage::
|
| 496 |
+
|
| 497 |
+
act_compare_dict = compare_model_outputs(float_model, qmodel, data)
|
| 498 |
+
for key in act_compare_dict:
|
| 499 |
+
print(
|
| 500 |
+
key,
|
| 501 |
+
compute_error(
|
| 502 |
+
act_compare_dict[key]['float'],
|
| 503 |
+
act_compare_dict[key]['quantized'].dequantize()
|
| 504 |
+
)
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
Args:
|
| 508 |
+
float_model: float model used to generate the q_model
|
| 509 |
+
q_model: model quantized from float_model
|
| 510 |
+
data: input data used to run the prepared float_model and q_model
|
| 511 |
+
logger_cls: type of logger to be attached to float_module and q_module
|
| 512 |
+
allow_list: list of module types to attach logger
|
| 513 |
+
|
| 514 |
+
Return:
|
| 515 |
+
act_compare_dict: dict with key corresponding to quantized module names
|
| 516 |
+
and each entry being a dictionary with two keys 'float' and 'quantized',
|
| 517 |
+
containing the matching float and quantized activations
|
| 518 |
+
"""
|
| 519 |
+
torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_outputs")
|
| 520 |
+
if allow_list is None:
|
| 521 |
+
allow_list = get_default_compare_output_module_list()
|
| 522 |
+
prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
|
| 523 |
+
float_model(*data)
|
| 524 |
+
q_model(*data)
|
| 525 |
+
act_compare_dict = get_matching_activations(float_model, q_model)
|
| 526 |
+
return act_compare_dict
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/graph_matcher.cpython-311.pyc
ADDED
|
Binary file (18.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/mappings.cpython-311.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-311.pyc
ADDED
|
Binary file (42.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/__pycache__/weight_utils.cpython-311.pyc
ADDED
|
Binary file (14.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/ns/fx/weight_utils.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.ao.nn.quantized.dynamic as nnqd
|
| 5 |
+
import torch.ao.nn.quantized as nnq
|
| 6 |
+
import torch.ao.nn.intrinsic.qat as nniqat
|
| 7 |
+
import torch.ao.nn.qat as nnqat
|
| 8 |
+
import torch.ao.nn.intrinsic as nni
|
| 9 |
+
import torch.ao.nn.intrinsic.quantized as nniq
|
| 10 |
+
toq = torch.ops.quantized
|
| 11 |
+
from torch.fx import GraphModule
|
| 12 |
+
from torch.fx.graph import Node
|
| 13 |
+
|
| 14 |
+
from .utils import (
|
| 15 |
+
get_target_type_str,
|
| 16 |
+
getattr_from_fqn,
|
| 17 |
+
return_first_non_observer_node,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from .ns_types import (
|
| 21 |
+
NSSingleResultValuesType,
|
| 22 |
+
NSSingleResultType,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from typing import List, Optional, Dict, Callable
|
| 26 |
+
|
| 27 |
+
def mod_weight_detach(mod: nn.Module) -> torch.Tensor:
|
| 28 |
+
return mod.weight.detach() # type: ignore[operator]
|
| 29 |
+
|
| 30 |
+
def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor:
|
| 31 |
+
return mod[0].weight.detach() # type: ignore[index]
|
| 32 |
+
|
| 33 |
+
def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor:
|
| 34 |
+
return mod._weight_bias()[0] # type: ignore[operator]
|
| 35 |
+
|
| 36 |
+
def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
|
| 37 |
+
res = []
|
| 38 |
+
for idx, param_name in enumerate(mod._flat_weights_names): # type: ignore[arg-type]
|
| 39 |
+
if 'weight_ih_l' in param_name or 'weight_hh_l' in param_name:
|
| 40 |
+
param_value = mod._flat_weights[idx].detach() # type: ignore[index]
|
| 41 |
+
res.append(param_value)
|
| 42 |
+
return res
|
| 43 |
+
|
| 44 |
+
def get_qlstm_weight(mod: nn.Module) -> List[torch.Tensor]:
|
| 45 |
+
res = []
|
| 46 |
+
for weight_value in mod._all_weight_values: # type: ignore[union-attr]
|
| 47 |
+
res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
|
| 48 |
+
res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
|
| 49 |
+
return res
|
| 50 |
+
|
| 51 |
+
def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
|
| 52 |
+
if (
|
| 53 |
+
isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d))
|
| 54 |
+
):
|
| 55 |
+
return mod.weight.detach()
|
| 56 |
+
elif (
|
| 57 |
+
isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d))
|
| 58 |
+
):
|
| 59 |
+
return mod[0].weight.detach()
|
| 60 |
+
else:
|
| 61 |
+
return mod._weight_bias()[0] # type: ignore[operator]
|
| 62 |
+
|
| 63 |
+
def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
|
| 64 |
+
if isinstance(mod, nn.Linear):
|
| 65 |
+
return mod.weight.detach()
|
| 66 |
+
elif isinstance(mod, nni.LinearReLU):
|
| 67 |
+
return mod[0].weight.detach()
|
| 68 |
+
else:
|
| 69 |
+
return mod._weight_bias()[0] # type: ignore[operator]
|
| 70 |
+
|
| 71 |
+
def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
|
| 72 |
+
# TODO(future PR): make more generic, handle everything
|
| 73 |
+
if isinstance(mod, nn.LSTM):
|
| 74 |
+
res = []
|
| 75 |
+
for idx, param_name in enumerate(mod._flat_weights_names):
|
| 76 |
+
if 'weight_ih_l' in param_name or 'weight_hh_l' in param_name:
|
| 77 |
+
param_value = mod._flat_weights[idx].detach()
|
| 78 |
+
res.append(param_value)
|
| 79 |
+
return res
|
| 80 |
+
else:
|
| 81 |
+
assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
|
| 82 |
+
res = []
|
| 83 |
+
for weight_value in mod._all_weight_values:
|
| 84 |
+
res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
|
| 85 |
+
res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
|
| 86 |
+
return res
|
| 87 |
+
|
| 88 |
+
def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
|
| 89 |
+
# traverse backwards from the weight arg, accounting for any observers
|
| 90 |
+
weight_arg_node = node.args[1]
|
| 91 |
+
assert isinstance(weight_arg_node, Node)
|
| 92 |
+
weight_node = return_first_non_observer_node(weight_arg_node, gm)
|
| 93 |
+
assert isinstance(weight_node, Node)
|
| 94 |
+
assert weight_node.op == 'get_attr'
|
| 95 |
+
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
|
| 96 |
+
return weight.detach()
|
| 97 |
+
|
| 98 |
+
def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
|
| 99 |
+
# qconv state is arg 1
|
| 100 |
+
qconv_state_node = node.args[1]
|
| 101 |
+
assert isinstance(qconv_state_node, Node)
|
| 102 |
+
assert qconv_state_node.op == 'get_attr'
|
| 103 |
+
qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target) # type: ignore[arg-type]
|
| 104 |
+
return qconv_state_obj.weight()
|
| 105 |
+
|
| 106 |
+
def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
|
| 107 |
+
# traverse backwards from the weight arg, accounting for any observers
|
| 108 |
+
# supported patterns:
|
| 109 |
+
# weight -> obs -> linear
|
| 110 |
+
# weight -> to(torch.float16) -> dequantize -> linear
|
| 111 |
+
linear_second_arg = node.args[1]
|
| 112 |
+
assert isinstance(linear_second_arg, Node)
|
| 113 |
+
|
| 114 |
+
if linear_second_arg.op == 'call_module':
|
| 115 |
+
# weight -> obs -> linear
|
| 116 |
+
weight_arg_node = node.args[1]
|
| 117 |
+
assert isinstance(weight_arg_node, Node)
|
| 118 |
+
weight_node = weight_arg_node.args[0]
|
| 119 |
+
assert isinstance(weight_node, Node)
|
| 120 |
+
assert weight_node.op == 'get_attr'
|
| 121 |
+
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
|
| 122 |
+
return weight.detach()
|
| 123 |
+
elif linear_second_arg.op == 'call_method':
|
| 124 |
+
# weight -> to(torch.float16) -> dequantize -> linear
|
| 125 |
+
assert linear_second_arg.op == 'call_method'
|
| 126 |
+
dequant_node = node.args[1]
|
| 127 |
+
assert isinstance(dequant_node, Node)
|
| 128 |
+
to_fp16_node = dequant_node.args[0]
|
| 129 |
+
assert isinstance(to_fp16_node, Node)
|
| 130 |
+
# extract the dtype, so we can cast to it before returning
|
| 131 |
+
target_dtype = to_fp16_node.args[1]
|
| 132 |
+
weight_node = to_fp16_node.args[0]
|
| 133 |
+
assert isinstance(weight_node, Node)
|
| 134 |
+
assert weight_node.op == 'get_attr'
|
| 135 |
+
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
|
| 136 |
+
# return the weight with fp16 cast
|
| 137 |
+
return weight.detach().to(target_dtype)
|
| 138 |
+
else:
|
| 139 |
+
assert linear_second_arg.op == 'get_attr'
|
| 140 |
+
weight = getattr_from_fqn(gm, linear_second_arg.target) # type: ignore[arg-type]
|
| 141 |
+
return weight.detach()
|
| 142 |
+
|
| 143 |
+
def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
|
| 144 |
+
# packed weight is arg 1
|
| 145 |
+
packed_weight_node = node.args[1]
|
| 146 |
+
assert isinstance(packed_weight_node, Node)
|
| 147 |
+
assert packed_weight_node.op == 'get_attr'
|
| 148 |
+
packed_weight = getattr_from_fqn(gm, packed_weight_node.target) # type: ignore[arg-type]
|
| 149 |
+
# TODO(future PR): why does packed_weight.unpack() not work?
|
| 150 |
+
(weight, _bias), _name = packed_weight.__getstate__()
|
| 151 |
+
return weight
|
| 152 |
+
|
| 153 |
+
def get_op_to_type_to_weight_extraction_fn() -> Dict[str, Dict[Callable, Callable]]:
|
| 154 |
+
|
| 155 |
+
op_to_type_to_weight_extraction_fn: Dict[str, Dict[Callable, Callable]] = {
|
| 156 |
+
'call_module': {
|
| 157 |
+
# Conv1d
|
| 158 |
+
nn.Conv1d: mod_weight_detach,
|
| 159 |
+
nni.ConvReLU1d: mod_0_weight_detach,
|
| 160 |
+
nnq.Conv1d: mod_weight_bias_0,
|
| 161 |
+
nnqat.Conv1d: mod_weight_detach,
|
| 162 |
+
nniqat.ConvBn1d: mod_weight_detach,
|
| 163 |
+
nniqat.ConvBnReLU1d: mod_weight_detach,
|
| 164 |
+
nniqat.ConvReLU1d: mod_weight_detach,
|
| 165 |
+
nniq.ConvReLU1d: mod_weight_bias_0,
|
| 166 |
+
# Conv2d
|
| 167 |
+
nn.Conv2d: mod_weight_detach,
|
| 168 |
+
nni.ConvReLU2d: mod_0_weight_detach,
|
| 169 |
+
nnq.Conv2d: mod_weight_bias_0,
|
| 170 |
+
nnqat.Conv2d: mod_weight_detach,
|
| 171 |
+
nniqat.ConvBn2d: mod_weight_detach,
|
| 172 |
+
nniqat.ConvBnReLU2d: mod_weight_detach,
|
| 173 |
+
nniqat.ConvReLU2d: mod_weight_detach,
|
| 174 |
+
nniq.ConvReLU2d: mod_weight_bias_0,
|
| 175 |
+
# Conv3d
|
| 176 |
+
nn.Conv3d: mod_weight_detach,
|
| 177 |
+
nni.ConvReLU3d: mod_0_weight_detach,
|
| 178 |
+
nnq.Conv3d: mod_weight_bias_0,
|
| 179 |
+
nnqat.Conv3d: mod_weight_detach,
|
| 180 |
+
nniqat.ConvBn3d: mod_weight_detach,
|
| 181 |
+
nniqat.ConvBnReLU3d: mod_weight_detach,
|
| 182 |
+
nniqat.ConvReLU3d: mod_weight_detach,
|
| 183 |
+
nniq.ConvReLU3d: mod_weight_bias_0,
|
| 184 |
+
# Linear
|
| 185 |
+
nn.Linear: mod_weight_detach,
|
| 186 |
+
nnq.Linear: mod_weight_bias_0,
|
| 187 |
+
nni.LinearReLU: mod_0_weight_detach,
|
| 188 |
+
nniq.LinearReLU: mod_weight_bias_0,
|
| 189 |
+
nnqat.Linear: mod_weight_detach,
|
| 190 |
+
nnqd.Linear: mod_weight_bias_0,
|
| 191 |
+
nniqat.LinearReLU: mod_weight_detach,
|
| 192 |
+
nniqat.LinearBn1d: mod_weight_detach,
|
| 193 |
+
nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
|
| 194 |
+
# LSTM
|
| 195 |
+
nn.LSTM: get_lstm_weight,
|
| 196 |
+
nnqd.LSTM: get_qlstm_weight,
|
| 197 |
+
},
|
| 198 |
+
'call_function': {
|
| 199 |
+
# Conv
|
| 200 |
+
F.conv1d: get_conv_fun_weight,
|
| 201 |
+
F.conv2d: get_conv_fun_weight,
|
| 202 |
+
F.conv3d: get_conv_fun_weight,
|
| 203 |
+
toq.conv1d: get_qconv_fun_weight,
|
| 204 |
+
toq.conv2d: get_qconv_fun_weight,
|
| 205 |
+
toq.conv3d: get_qconv_fun_weight,
|
| 206 |
+
toq.conv1d_relu: get_qconv_fun_weight,
|
| 207 |
+
toq.conv2d_relu: get_qconv_fun_weight,
|
| 208 |
+
toq.conv3d_relu: get_qconv_fun_weight,
|
| 209 |
+
# Linear
|
| 210 |
+
F.linear: get_linear_fun_weight,
|
| 211 |
+
toq.linear: get_qlinear_fun_weight,
|
| 212 |
+
toq.linear_relu: get_qlinear_fun_weight,
|
| 213 |
+
},
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
return op_to_type_to_weight_extraction_fn
|
| 217 |
+
|
| 218 |
+
def extract_weight_from_node(
|
| 219 |
+
node: Node,
|
| 220 |
+
gm: GraphModule,
|
| 221 |
+
op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
|
| 222 |
+
) -> Optional[NSSingleResultType]:
|
| 223 |
+
res_type = NSSingleResultValuesType.WEIGHT.value
|
| 224 |
+
|
| 225 |
+
# Not all graphmodules have _node_name_to_scope, so only fill it
|
| 226 |
+
# out if it exists.
|
| 227 |
+
fqn = None
|
| 228 |
+
if hasattr(gm, '_node_name_to_scope'):
|
| 229 |
+
fqn = gm._node_name_to_scope[node.name][0] # type: ignore[index]
|
| 230 |
+
|
| 231 |
+
if op_to_type_to_weight_extraction_fn is None:
|
| 232 |
+
op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()
|
| 233 |
+
|
| 234 |
+
ref_node_type = get_target_type_str(node, gm)
|
| 235 |
+
# for extracting weights, these are always the same
|
| 236 |
+
prev_node_type = ref_node_type
|
| 237 |
+
|
| 238 |
+
if node.op == 'call_function':
|
| 239 |
+
function_mapping = op_to_type_to_weight_extraction_fn['call_function']
|
| 240 |
+
for target_fn_type, weight_extraction_fn in function_mapping.items():
|
| 241 |
+
if node.target == target_fn_type:
|
| 242 |
+
weight = weight_extraction_fn(node, gm)
|
| 243 |
+
return {
|
| 244 |
+
'type': res_type,
|
| 245 |
+
'values': [weight],
|
| 246 |
+
'prev_node_name': node.name,
|
| 247 |
+
'prev_node_target_type': prev_node_type,
|
| 248 |
+
'ref_node_name': node.name,
|
| 249 |
+
'ref_node_target_type': ref_node_type,
|
| 250 |
+
'index_within_arg': 0,
|
| 251 |
+
'index_of_arg': 0,
|
| 252 |
+
'fqn': fqn,
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
elif node.op == 'call_module':
|
| 256 |
+
# for call_module, we need to look up the modules to do the type check
|
| 257 |
+
assert isinstance(node.target, str)
|
| 258 |
+
mod = getattr_from_fqn(gm, node.target)
|
| 259 |
+
module_mapping = op_to_type_to_weight_extraction_fn['call_module']
|
| 260 |
+
for target_mod_type, weight_extraction_fn in module_mapping.items():
|
| 261 |
+
if type(mod) == target_mod_type:
|
| 262 |
+
weight = weight_extraction_fn(mod)
|
| 263 |
+
return {
|
| 264 |
+
'type': res_type,
|
| 265 |
+
'values': [weight],
|
| 266 |
+
'prev_node_name': node.name,
|
| 267 |
+
'prev_node_target_type': prev_node_type,
|
| 268 |
+
'ref_node_name': node.name,
|
| 269 |
+
'ref_node_target_type': ref_node_type,
|
| 270 |
+
'index_within_arg': 0,
|
| 271 |
+
'index_of_arg': 0,
|
| 272 |
+
'fqn': fqn,
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
return None
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/fake_quantize.cpython-311.pyc
ADDED
|
Binary file (25.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/qconfig.cpython-311.pyc
ADDED
|
Binary file (22.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/ao/quantization/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (30.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/__pycache__/graph_signature.cpython-311.pyc
ADDED
|
Binary file (27.6 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/export/_remove_auto_functionalized_pass.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import operator
|
| 8 |
+
from typing import List
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch._higher_order_ops.auto_functionalize import (
|
| 12 |
+
auto_functionalized,
|
| 13 |
+
get_mutable_arg_names,
|
| 14 |
+
)
|
| 15 |
+
from torch.export import ExportedProgram
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def unsafe_remove_auto_functionalized_pass(
|
| 19 |
+
ep: ExportedProgram,
|
| 20 |
+
) -> ExportedProgram:
|
| 21 |
+
"""
|
| 22 |
+
This pass removes an instances of the higher order op 'auto_functionalized',
|
| 23 |
+
and modifies the calling EP inplace to have the original mutator op.
|
| 24 |
+
This pass doesn't perform safety checks to make sure that this inplace mutation is safe.
|
| 25 |
+
"""
|
| 26 |
+
auto_functionalize_nodes: List[torch.fx.Node] = []
|
| 27 |
+
for module in ep.graph_module.modules():
|
| 28 |
+
if not isinstance(module, torch.fx.GraphModule):
|
| 29 |
+
continue
|
| 30 |
+
for node in ep.graph.nodes:
|
| 31 |
+
if node.op == "call_function" and node.target is auto_functionalized:
|
| 32 |
+
auto_functionalize_nodes.append(node)
|
| 33 |
+
|
| 34 |
+
# Update every use of the HOP
|
| 35 |
+
for node in reversed(auto_functionalize_nodes):
|
| 36 |
+
func = node.args[0]
|
| 37 |
+
original_kwargs = node.kwargs
|
| 38 |
+
assert isinstance(func, torch._ops.OpOverload)
|
| 39 |
+
|
| 40 |
+
with ep.graph.inserting_before(node):
|
| 41 |
+
# This makes the call_function refer to every arg as a kwarg, this is weird but probably fine?
|
| 42 |
+
new_node = ep.graph.call_function(func, kwargs=node.kwargs)
|
| 43 |
+
for k, v in node.meta.items():
|
| 44 |
+
new_node.meta[k] = v
|
| 45 |
+
|
| 46 |
+
# Replace auto_functionalize(func, args) with just func(args)
|
| 47 |
+
node.replace_all_uses_with(new_node)
|
| 48 |
+
|
| 49 |
+
mutable_args_names = get_mutable_arg_names(new_node.target)
|
| 50 |
+
output_specs = ep.graph_signature.output_specs
|
| 51 |
+
|
| 52 |
+
# update the users of the auto_func node (the getitem nodes)
|
| 53 |
+
for user in list(new_node.users.keys()):
|
| 54 |
+
assert user.target == operator.getitem
|
| 55 |
+
# getitem corresponding to a mutated input, just replace all uses with the original input
|
| 56 |
+
if user.args[1] >= len(func._schema.returns):
|
| 57 |
+
assert user.args[1] <= len(func._schema.returns) + len(
|
| 58 |
+
mutable_args_names
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# If the result of getitem was used in an output node, update the output spec with the correct name
|
| 62 |
+
adusted_index = user.args[1] - len(func._schema.returns)
|
| 63 |
+
original_arg = original_kwargs[mutable_args_names[adusted_index]]
|
| 64 |
+
for spec in output_specs:
|
| 65 |
+
if spec.arg.name == user.name:
|
| 66 |
+
spec.arg.name = original_arg.name # pyre-ignore
|
| 67 |
+
break
|
| 68 |
+
|
| 69 |
+
# This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
|
| 70 |
+
# of the getitem calls following the HOP.
|
| 71 |
+
user.replace_all_uses_with(
|
| 72 |
+
original_kwargs[mutable_args_names[adusted_index]]
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if len(func._schema.returns) == 1:
|
| 76 |
+
# If the function has 1 return then it will just directly return the
|
| 77 |
+
# result -- we don't need a getitem. So we can replace all the
|
| 78 |
+
# getitem(auto_functionalized, 0) with just the note itself.
|
| 79 |
+
for user in list(new_node.users.keys()):
|
| 80 |
+
if user.args[1] == 0:
|
| 81 |
+
user.replace_all_uses_with(new_node)
|
| 82 |
+
|
| 83 |
+
# Same case as above, update the output spec if getitem result used in an output node
|
| 84 |
+
for spec in output_specs:
|
| 85 |
+
if spec.arg.name == user.name:
|
| 86 |
+
spec.arg.name = new_node.name
|
| 87 |
+
break
|
| 88 |
+
|
| 89 |
+
new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
|
| 90 |
+
ep.graph.erase_node(node)
|
| 91 |
+
|
| 92 |
+
ep.graph.eliminate_dead_code()
|
| 93 |
+
return ep
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
op_add = '+'
|
| 2 |
+
op_sub = '-'
|
| 3 |
+
op_mul = '*'
|
| 4 |
+
op_div = '/'
|
| 5 |
+
op_eq = '='
|
| 6 |
+
op_neq = '!='
|
| 7 |
+
op_imp = '=>'
|
| 8 |
+
op_matching = '⊳'
|
| 9 |
+
op_consistency = '~'
|
| 10 |
+
op_precision = '⊑'
|
| 11 |
+
op_leq = '≤'
|
| 12 |
+
op_lt = '<'
|
| 13 |
+
op_gt = '>'
|
| 14 |
+
op_mod = '%'
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
import z3 # type: ignore[import]
|
| 3 |
+
HAS_Z3 = True
|
| 4 |
+
# dynamic type
|
| 5 |
+
dyn = z3.DeclareSort('Dyn')
|
| 6 |
+
dyn_type = z3.Const('dyn', dyn)
|
| 7 |
+
|
| 8 |
+
# dimension
|
| 9 |
+
dim = z3.Datatype('dim')
|
| 10 |
+
dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort()))
|
| 11 |
+
dim = dim.create()
|
| 12 |
+
|
| 13 |
+
# tensors
|
| 14 |
+
tensor_type = z3.Datatype('TensorType')
|
| 15 |
+
tensor_type.declare('Dyn', ('dyn', dyn))
|
| 16 |
+
tensor_type.declare('tensor1', ('0', dim))
|
| 17 |
+
tensor_type.declare('tensor2', ('0', dim), ('1', dim))
|
| 18 |
+
tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim))
|
| 19 |
+
tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim))
|
| 20 |
+
tensor_type = tensor_type.create()
|
| 21 |
+
|
| 22 |
+
# create dimension
|
| 23 |
+
D = dim.dim
|
| 24 |
+
|
| 25 |
+
z3_dyn = tensor_type.Dyn(dyn_type)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
except ImportError:
|
| 29 |
+
HAS_Z3 = False
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: disable-error-code=attr-defined
|
| 2 |
+
from .core import unify, reify # noqa: F403
|
| 3 |
+
from .more import unifiable # noqa: F403
|
| 4 |
+
from .variable import var, isvar, vars, variables, Var # noqa: F403
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc
ADDED
|
Binary file (7.12 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .core import dispatch
|
| 2 |
+
from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
|
| 3 |
+
MDNotImplementedError)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc
ADDED
|
Binary file (8.72 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (6.37 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import graph_drawer
|
| 2 |
+
from . import graph_manipulation
|
| 3 |
+
from . import net_min_base
|
| 4 |
+
from . import operator_support
|
| 5 |
+
from . import param_fetch
|
| 6 |
+
from . import reinplace
|
| 7 |
+
from . import shape_prop
|
| 8 |
+
from . import split_module
|
| 9 |
+
from . import split_utils
|
| 10 |
+
from . import splitter_base
|
| 11 |
+
from . import tools_common
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (231 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_manipulation.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, NamedTuple, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.fx._compatibility import compatibility
|
| 5 |
+
from torch.fx.graph import Graph
|
| 6 |
+
from torch.fx.graph_module import GraphModule
|
| 7 |
+
from torch.fx.node import (
|
| 8 |
+
map_arg,
|
| 9 |
+
Node,
|
| 10 |
+
Target,
|
| 11 |
+
)
|
| 12 |
+
from torch.fx.passes.shape_prop import ShapeProp
|
| 13 |
+
|
| 14 |
+
__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta',
|
| 15 |
+
'get_size_of_node']
|
| 16 |
+
|
| 17 |
+
@compatibility(is_backward_compatible=False)
|
| 18 |
+
def replace_target_nodes_with(
|
| 19 |
+
fx_module: GraphModule,
|
| 20 |
+
old_op: str,
|
| 21 |
+
old_target: Target,
|
| 22 |
+
new_op: str,
|
| 23 |
+
new_target: Target,
|
| 24 |
+
):
|
| 25 |
+
"""Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
|
| 26 |
+
and updates them to match the new op code and target"""
|
| 27 |
+
new_graph = Graph()
|
| 28 |
+
val_map: Dict[Node, Node] = {}
|
| 29 |
+
for node in fx_module.graph.nodes:
|
| 30 |
+
if node.op == old_op and node.target == old_target:
|
| 31 |
+
args = map_arg(node.args, lambda n: val_map[n])
|
| 32 |
+
kwargs = map_arg(node.kwargs, lambda n: val_map[n])
|
| 33 |
+
assert isinstance(args, tuple)
|
| 34 |
+
assert isinstance(kwargs, dict)
|
| 35 |
+
val_map[node] = new_graph.create_node(
|
| 36 |
+
new_op, new_target, args, kwargs, node.name
|
| 37 |
+
)
|
| 38 |
+
else:
|
| 39 |
+
val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
|
| 40 |
+
fx_module.graph = new_graph
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@compatibility(is_backward_compatible=False)
|
| 44 |
+
class size_bytes(NamedTuple):
|
| 45 |
+
output_size: int
|
| 46 |
+
total_size: int
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@compatibility(is_backward_compatible=False)
|
| 50 |
+
def get_size_of_all_nodes(
|
| 51 |
+
fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
|
| 52 |
+
) -> None:
|
| 53 |
+
"""Given a fx graph module, update each node with its total size (weights + bias + output)
|
| 54 |
+
and its output_size(output). For a non-module node, the total size is the output size.
|
| 55 |
+
return total size"""
|
| 56 |
+
if args is not None:
|
| 57 |
+
# Mark shape and dtype for each node (node.shape and node.dtype)
|
| 58 |
+
ShapeProp(fx_module).propagate(*args)
|
| 59 |
+
# Calculate the total size of the whole fx graph
|
| 60 |
+
total_size_of_graph = 0.0
|
| 61 |
+
for node in fx_module.graph.nodes:
|
| 62 |
+
if node.op == "output":
|
| 63 |
+
break
|
| 64 |
+
node.size_bytes = get_size_of_node(fx_module, node)
|
| 65 |
+
return
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@compatibility(is_backward_compatible=False)
|
| 69 |
+
def get_tensor_meta(node: Node) -> Any:
|
| 70 |
+
tensor_meta = node.meta.get("tensor_meta")
|
| 71 |
+
|
| 72 |
+
if not tensor_meta:
|
| 73 |
+
raise RuntimeError(
|
| 74 |
+
f"Node {node} has no tensor metadata associated with it! "
|
| 75 |
+
f"Check that shape propagation has run."
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
return tensor_meta
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@compatibility(is_backward_compatible=False)
|
| 82 |
+
def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
|
| 83 |
+
"""Given a node with node.dtype and node.shape, return its total size and its output size.
|
| 84 |
+
total_size = weights + bias + output_size
|
| 85 |
+
"""
|
| 86 |
+
# Total num of elements
|
| 87 |
+
total_num_of_elems = 0
|
| 88 |
+
# For a module, conside all parameters
|
| 89 |
+
if node.op == "call_module":
|
| 90 |
+
submodule_dict = dict(fx_module.named_modules())
|
| 91 |
+
submodule = submodule_dict[node.target]
|
| 92 |
+
parameters = submodule.named_parameters()
|
| 93 |
+
# Parameters are named tuples
|
| 94 |
+
for name, p in parameters:
|
| 95 |
+
total_num_of_elems += p.numel()
|
| 96 |
+
# Don't forget the output size
|
| 97 |
+
# node.shape is the shape of this node's output
|
| 98 |
+
tensor_meta = get_tensor_meta(node)
|
| 99 |
+
output_elem = tensor_meta.shape.numel()
|
| 100 |
+
total_num_of_elems += output_elem
|
| 101 |
+
# Assume for now if it's quantized then it's qint8 or quint8
|
| 102 |
+
if tensor_meta.is_quantized:
|
| 103 |
+
size_per_elem_bytes = torch._empty_affine_quantized(
|
| 104 |
+
[], dtype=tensor_meta.dtype
|
| 105 |
+
).element_size()
|
| 106 |
+
else:
|
| 107 |
+
size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size()
|
| 108 |
+
total_size = size_per_elem_bytes * total_num_of_elems
|
| 109 |
+
output_size = size_per_elem_bytes * output_elem
|
| 110 |
+
return size_bytes(output_size, total_size)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/net_min_base.py
ADDED
|
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.fx
|
| 7 |
+
|
| 8 |
+
from torch.fx._compatibility import compatibility
|
| 9 |
+
from torch.fx.node import map_arg
|
| 10 |
+
|
| 11 |
+
from .shape_prop import ShapeProp
|
| 12 |
+
from .split_utils import split_by_tags
|
| 13 |
+
from .tools_common import (
|
| 14 |
+
CALLABLE_NODE_OPS,
|
| 15 |
+
FxNetAccFusionsFinder,
|
| 16 |
+
Names,
|
| 17 |
+
NodeList,
|
| 18 |
+
NodeSet,
|
| 19 |
+
TensorOrTensors,
|
| 20 |
+
Tensors,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
__all__ = [
|
| 24 |
+
"FxNetMinimizerBadModuleError",
|
| 25 |
+
"FxNetMinimizerRunFuncError",
|
| 26 |
+
"FxNetMinimizerResultMismatchError",
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
_LOGGER = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@compatibility(is_backward_compatible=False)
|
| 33 |
+
class FxNetMinimizerBadModuleError(Exception):
|
| 34 |
+
"""
|
| 35 |
+
Raised if failed to split out a minimize module
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
pass
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@compatibility(is_backward_compatible=False)
|
| 42 |
+
class FxNetMinimizerRunFuncError(Exception):
|
| 43 |
+
"""
|
| 44 |
+
Raised if error occurs during run_a or run_b functions
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
pass
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@compatibility(is_backward_compatible=False)
|
| 51 |
+
class FxNetMinimizerResultMismatchError(Exception):
|
| 52 |
+
"""
|
| 53 |
+
Raised if comparing function thinks the results are mismatching.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
pass
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
|
| 60 |
+
class _MinimizerSettingBase:
|
| 61 |
+
"""
|
| 62 |
+
Args:
|
| 63 |
+
`accumulate_error`: Instead of using a's input for both converted module to verify
|
| 64 |
+
, use the previous outputs of each converted module as input to accumulate the
|
| 65 |
+
errors.
|
| 66 |
+
|
| 67 |
+
`traverse_method`: "sequential" or "binary" or "accumulate"
|
| 68 |
+
Determine the way of traverse the nodes in FX module.
|
| 69 |
+
|
| 70 |
+
`find_all`: Minimizer will go through the entire model and return all problematic nodes.
|
| 71 |
+
|
| 72 |
+
`return_intermediate`: If true, when using `run_nodes()` function to run the
|
| 73 |
+
model, intermediate results of all the ops will be returned as output.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
accumulate_error: bool = False
|
| 77 |
+
traverse_method: str = "sequential"
|
| 78 |
+
find_all: bool = False
|
| 79 |
+
return_intermediate: bool = False
|
| 80 |
+
|
| 81 |
+
def __str__(self):
|
| 82 |
+
settings_str = "FX Minimizer Settings:\n"
|
| 83 |
+
|
| 84 |
+
for k, v in vars(self).items():
|
| 85 |
+
settings_str += f"\t{k}: {v}\n"
|
| 86 |
+
|
| 87 |
+
return settings_str
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class _MinimizerBase:
|
| 91 |
+
"""
|
| 92 |
+
This class is used to automatically find problematic nodes in a model. It takes a FX
|
| 93 |
+
graphmodule and generate some submodules while traverse the graph. Then two functions
|
| 94 |
+
`run_a` and `run_b` will be used to run the same submodule and a function `compare_fn`
|
| 95 |
+
will be used to compare the results.
|
| 96 |
+
|
| 97 |
+
Currently we provides two ways to traverse the graph and generate submodules.
|
| 98 |
+
1. Sequential traversal: this will traverse the graph node by node and generate
|
| 99 |
+
one submodule with one sigle node.
|
| 100 |
+
2. Binary searching: this will do a binary search style traversal on the graph.
|
| 101 |
+
|
| 102 |
+
For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
module: torch.fx.GraphModule,
|
| 108 |
+
sample_input: Tensors,
|
| 109 |
+
compare_fn: Callable[
|
| 110 |
+
[TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
|
| 111 |
+
],
|
| 112 |
+
settings: _MinimizerSettingBase,
|
| 113 |
+
module_exporter: Optional[
|
| 114 |
+
Callable[
|
| 115 |
+
[List[torch.Tensor], torch.fx.GraphModule, str],
|
| 116 |
+
None
|
| 117 |
+
]
|
| 118 |
+
] = None,
|
| 119 |
+
):
|
| 120 |
+
assert isinstance(module, torch.fx.GraphModule)
|
| 121 |
+
|
| 122 |
+
self.module = module
|
| 123 |
+
self.sample_input = sample_input
|
| 124 |
+
self.compare_fn = compare_fn
|
| 125 |
+
self.module_exporter = module_exporter
|
| 126 |
+
self.settings = settings
|
| 127 |
+
|
| 128 |
+
# Stores outputs of run_a function
|
| 129 |
+
self.a_outputs: Dict[str, Any] = {}
|
| 130 |
+
|
| 131 |
+
# Stores outputs of run_b function
|
| 132 |
+
self.b_outputs: Dict[str, Any] = {}
|
| 133 |
+
|
| 134 |
+
# Stores the results of compare_fn
|
| 135 |
+
self.results: Dict[Any, Any] = {}
|
| 136 |
+
|
| 137 |
+
# Stores the report for the runs
|
| 138 |
+
self.reports: List[List[str]] = []
|
| 139 |
+
|
| 140 |
+
# Current iteration
|
| 141 |
+
self.iteration: int = 0
|
| 142 |
+
|
| 143 |
+
callable_nodes = {
|
| 144 |
+
node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
|
| 145 |
+
}
|
| 146 |
+
ShapeProp(self.module).propagate(*self.sample_input)
|
| 147 |
+
self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)()
|
| 148 |
+
|
| 149 |
+
# Check if number of input in sample_input matches the number of placeholders
|
| 150 |
+
placeholders = [
|
| 151 |
+
node.name for node in self.module.graph.nodes if node.op == "placeholder"
|
| 152 |
+
]
|
| 153 |
+
assert len(placeholders) == len(self.sample_input)
|
| 154 |
+
|
| 155 |
+
# Store sample_input
|
| 156 |
+
for i, name in enumerate(placeholders):
|
| 157 |
+
self.a_outputs[name] = sample_input[i]
|
| 158 |
+
self.b_outputs[name] = sample_input[i]
|
| 159 |
+
|
| 160 |
+
def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors:
|
| 161 |
+
"""
|
| 162 |
+
Run `mod` with `inputs` and generate output. The output will be compared with
|
| 163 |
+
output of run_b().
|
| 164 |
+
"""
|
| 165 |
+
raise RuntimeError("run_a() is not implemented.")
|
| 166 |
+
|
| 167 |
+
def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors:
|
| 168 |
+
"""
|
| 169 |
+
Run `mod` with `inputs` and generate output. The output will be compared with
|
| 170 |
+
output of run_a().
|
| 171 |
+
"""
|
| 172 |
+
raise RuntimeError("run_b() is not implemented.")
|
| 173 |
+
|
| 174 |
+
def _store_outputs(
|
| 175 |
+
self,
|
| 176 |
+
a_result: TensorOrTensors,
|
| 177 |
+
b_result: TensorOrTensors,
|
| 178 |
+
submodule: torch.fx.GraphModule,
|
| 179 |
+
):
|
| 180 |
+
"""
|
| 181 |
+
Store the outputs of self.run_a() and self.run_b() into self.a_outputs and
|
| 182 |
+
self.b_outputs, so that we can use them when execute preceding nodes that
|
| 183 |
+
use those outputs as inputs.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
a_result: Output of self.run_a(). Could be a tensor or tensors.
|
| 187 |
+
b_result: Output of self.run_b(). Could be a tensor or tensors.
|
| 188 |
+
submodule: The module that generates a_result and b_result.
|
| 189 |
+
"""
|
| 190 |
+
output_node = next(
|
| 191 |
+
node for node in submodule.graph.nodes if node.op == "output"
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# Only one output
|
| 195 |
+
if isinstance(output_node.args[0], torch.fx.Node):
|
| 196 |
+
self.a_outputs[output_node.args[0].name] = a_result
|
| 197 |
+
self.b_outputs[output_node.args[0].name] = b_result
|
| 198 |
+
# Multiple outputs
|
| 199 |
+
else:
|
| 200 |
+
for i, arg in enumerate(output_node.args[0]):
|
| 201 |
+
self.a_outputs[arg.name] = a_result[i]
|
| 202 |
+
self.b_outputs[arg.name] = b_result[i]
|
| 203 |
+
|
| 204 |
+
def _get_submod_inputs(
|
| 205 |
+
self, main_module: torch.fx.GraphModule, submod_path: str
|
| 206 |
+
) -> Tuple[Tensors, Tensors]:
|
| 207 |
+
"""
|
| 208 |
+
Try get submodule inputs from stored outputs. If not found then use
|
| 209 |
+
torch_glow.get_submod_inputs to get the inputs.
|
| 210 |
+
|
| 211 |
+
If accumulate_error is False, use a_input for run_a() and run_b()
|
| 212 |
+
otherwise use a_input for run_a and b_input for run_b.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
main_module: Top-levlel fx module.
|
| 216 |
+
submod_path: Path to the submodule we want to run and compare results.
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
a_input: List of tensor(s) that will be used by run_a() as submodule inputs.
|
| 220 |
+
b_input: List of tensor(s) that will be used by run_b() as submodule inputs.
|
| 221 |
+
"""
|
| 222 |
+
a_input = []
|
| 223 |
+
b_input = []
|
| 224 |
+
submodule = getattr(main_module, submod_path)
|
| 225 |
+
placeholders = [
|
| 226 |
+
node.name for node in submodule.graph.nodes if node.op == "placeholder"
|
| 227 |
+
]
|
| 228 |
+
|
| 229 |
+
# If all placeholder can be found in stored outputs, use stored
|
| 230 |
+
# outputs as inputs. Otherwise, use `torch_glow.get_submod_inputs`
|
| 231 |
+
# to get the inputs.
|
| 232 |
+
if set(placeholders) <= self.a_outputs.keys():
|
| 233 |
+
for name in placeholders:
|
| 234 |
+
a_input.append(self.a_outputs[name])
|
| 235 |
+
b_input.append(self.b_outputs[name])
|
| 236 |
+
else:
|
| 237 |
+
if self.settings.accumulate_error:
|
| 238 |
+
print(f"Can't find previous stored outputs named {placeholders}!")
|
| 239 |
+
|
| 240 |
+
def get_inputs(self: torch.nn.Module, inputs: Any):
|
| 241 |
+
nonlocal a_input
|
| 242 |
+
a_input = inputs
|
| 243 |
+
|
| 244 |
+
# Use forward hook to get the inputs to the submodule
|
| 245 |
+
handle = submodule.register_forward_pre_hook(get_inputs)
|
| 246 |
+
main_module(*self.sample_input)
|
| 247 |
+
handle.remove()
|
| 248 |
+
|
| 249 |
+
b_input = a_input
|
| 250 |
+
|
| 251 |
+
if not self.settings.accumulate_error:
|
| 252 |
+
return a_input, a_input
|
| 253 |
+
|
| 254 |
+
return a_input, b_input
|
| 255 |
+
|
| 256 |
+
def _tag_nodes(self, selected_nodes: NodeSet):
|
| 257 |
+
"""
|
| 258 |
+
Tag selected nodes with tag "minimize". Nodes with the same tags will
|
| 259 |
+
be split to the same submodule afterwards.
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
selected_nodes: Nodes that we want to minimize. We will tag those nodes
|
| 263 |
+
with "minimize", all preceding nodes with "main_0" and all following
|
| 264 |
+
nodes with "main_1".
|
| 265 |
+
"""
|
| 266 |
+
for node in self.module.graph.nodes:
|
| 267 |
+
if node.op not in CALLABLE_NODE_OPS:
|
| 268 |
+
continue
|
| 269 |
+
|
| 270 |
+
if node in selected_nodes:
|
| 271 |
+
node.tag = "minimize"
|
| 272 |
+
elif any(
|
| 273 |
+
n.tag in {"minimize", "main_1"}
|
| 274 |
+
for n in node.all_input_nodes
|
| 275 |
+
if n.op in CALLABLE_NODE_OPS
|
| 276 |
+
):
|
| 277 |
+
node.tag = "main_1"
|
| 278 |
+
else:
|
| 279 |
+
node.tag = "main_0"
|
| 280 |
+
|
| 281 |
+
def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
|
| 282 |
+
"""
|
| 283 |
+
Split self.module so that one submodule consists of `nodes` and only `nodes`.
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
nodes: Nodes that we want to include in the minimize submodule.
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
split_module (torch.fx.GraphModule): the module after split.
|
| 290 |
+
submodule_name (str): the name of the submodule that consists of `nodes`.
|
| 291 |
+
"""
|
| 292 |
+
# Color provided nodes
|
| 293 |
+
self._tag_nodes(nodes)
|
| 294 |
+
|
| 295 |
+
# Split module based on coloring
|
| 296 |
+
split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"])
|
| 297 |
+
|
| 298 |
+
# Find submodule containing colored nodes
|
| 299 |
+
submodule_name: str = ""
|
| 300 |
+
for child_name, _ in split_module.named_children():
|
| 301 |
+
# Skip submodules we're not interested in at the moment
|
| 302 |
+
if "minimize" not in child_name:
|
| 303 |
+
continue
|
| 304 |
+
|
| 305 |
+
if submodule_name == "":
|
| 306 |
+
submodule_name = child_name
|
| 307 |
+
else:
|
| 308 |
+
raise FxNetMinimizerBadModuleError(
|
| 309 |
+
f"Expected only one minimize submodule with nodes {nodes}"
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
if submodule_name == "":
|
| 313 |
+
raise FxNetMinimizerBadModuleError(
|
| 314 |
+
f"Minimize submodule was not found with nodes {nodes}"
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
return split_module, submodule_name
|
| 318 |
+
|
| 319 |
+
def _run_and_compare(
|
| 320 |
+
self, split_module: torch.fx.GraphModule, submod_name: str, output_names: Names
|
| 321 |
+
):
|
| 322 |
+
"""
|
| 323 |
+
Run the submodule in `split_module` that has name `submod_name`
|
| 324 |
+
using `self.run_a` and `self.run_b` and compare their results.
|
| 325 |
+
|
| 326 |
+
Args:
|
| 327 |
+
split_module: Main module that contains the minimize submodule.
|
| 328 |
+
submod_name: Name of the minimize submodule.
|
| 329 |
+
output_names: Names of the node we want to output. If None, we
|
| 330 |
+
will use the original output.
|
| 331 |
+
"""
|
| 332 |
+
submodule = getattr(split_module, submod_name)
|
| 333 |
+
a_input, b_input = self._get_submod_inputs(split_module, submod_name)
|
| 334 |
+
|
| 335 |
+
if len(self.reports) == 0:
|
| 336 |
+
self.reports.append([])
|
| 337 |
+
self.iteration = 1
|
| 338 |
+
|
| 339 |
+
report = self.reports[self.iteration - 1]
|
| 340 |
+
report.append("Run and compare ...")
|
| 341 |
+
|
| 342 |
+
if output_names:
|
| 343 |
+
output_nodes: NodeList = []
|
| 344 |
+
for node in submodule.graph.nodes:
|
| 345 |
+
if node.op == "output":
|
| 346 |
+
submodule.graph.erase_node(node)
|
| 347 |
+
|
| 348 |
+
if node.name in output_names:
|
| 349 |
+
output_nodes.append(node)
|
| 350 |
+
|
| 351 |
+
submodule.graph.output(
|
| 352 |
+
output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes)
|
| 353 |
+
)
|
| 354 |
+
submodule.graph.lint()
|
| 355 |
+
submodule.recompile()
|
| 356 |
+
|
| 357 |
+
# Use name of args in output node as key to store comparison result
|
| 358 |
+
for node in submodule.graph.nodes:
|
| 359 |
+
if node.op == "output":
|
| 360 |
+
result_key = map_arg(node.args, lambda x: x.name)
|
| 361 |
+
|
| 362 |
+
try:
|
| 363 |
+
a_result = self.run_a(submodule, a_input)
|
| 364 |
+
b_result = self.run_b(submodule, b_input)
|
| 365 |
+
self._store_outputs(a_result, b_result, submodule)
|
| 366 |
+
except Exception as e:
|
| 367 |
+
report.append(f"Exception raised when running {submod_name}: {e}")
|
| 368 |
+
raise FxNetMinimizerRunFuncError( # noqa: TRY200
|
| 369 |
+
f"Exception raised when running {submod_name}: {e}"
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# Compare results
|
| 373 |
+
names: Names = output_names
|
| 374 |
+
if output_names is None:
|
| 375 |
+
names = [str(v) for v in result_key] # type: ignore[possibly-undefined]
|
| 376 |
+
|
| 377 |
+
numeric_result, bool_result = self.compare_fn(a_result, b_result, names)
|
| 378 |
+
|
| 379 |
+
self.results[result_key] = numeric_result # type: ignore[possibly-undefined]
|
| 380 |
+
report.append(f"Numerical accuracy = {numeric_result}")
|
| 381 |
+
if not bool_result:
|
| 382 |
+
report.append(f"Result mismatch for {result_key}")
|
| 383 |
+
if self.module_exporter:
|
| 384 |
+
self.module_exporter(
|
| 385 |
+
List[torch.Tensor](a_input), submodule, str(result_key[0]) + "_cpu",
|
| 386 |
+
)
|
| 387 |
+
self.module_exporter(
|
| 388 |
+
List[torch.Tensor](b_input), submodule, str(result_key[0]) + "_acc",
|
| 389 |
+
)
|
| 390 |
+
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")
|
| 391 |
+
|
| 392 |
+
def _binary_search_impl(
|
| 393 |
+
self, all_nodes: NodeList, start_idx: int, end_idx: int
|
| 394 |
+
) -> NodeSet:
|
| 395 |
+
"""
|
| 396 |
+
Recursive binary search implementation.
|
| 397 |
+
"""
|
| 398 |
+
nodes: NodeList = all_nodes[start_idx:end_idx]
|
| 399 |
+
|
| 400 |
+
report: List[str] = []
|
| 401 |
+
self.reports.append(report)
|
| 402 |
+
self.iteration += 1
|
| 403 |
+
report.append(f"Binary search iteration {self.iteration}.")
|
| 404 |
+
report.append(
|
| 405 |
+
f"From node index {start_idx} to {end_idx-1}. "
|
| 406 |
+
f"Size of the interested node list is {len(nodes)}"
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
cur_nodes: NodeSet = set(nodes)
|
| 410 |
+
|
| 411 |
+
for node in nodes:
|
| 412 |
+
if node in self.fusions:
|
| 413 |
+
cur_nodes.update(self.fusions[node])
|
| 414 |
+
|
| 415 |
+
try:
|
| 416 |
+
split_module, submod_name = self._build_submodule(cur_nodes)
|
| 417 |
+
self._run_and_compare(split_module, submod_name, [])
|
| 418 |
+
except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):
|
| 419 |
+
|
| 420 |
+
if len(nodes) == 1:
|
| 421 |
+
report.append(
|
| 422 |
+
f"This is the last node in the sub-module. "
|
| 423 |
+
f"Search in the current branch is successful with culprit = {cur_nodes}."
|
| 424 |
+
)
|
| 425 |
+
self.print_report(report)
|
| 426 |
+
return cur_nodes
|
| 427 |
+
|
| 428 |
+
report.append(
|
| 429 |
+
"Proceed to split and lower the halves of the current "
|
| 430 |
+
"sub-module individually."
|
| 431 |
+
)
|
| 432 |
+
self.print_report(report)
|
| 433 |
+
|
| 434 |
+
mid = len(nodes) // 2
|
| 435 |
+
culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid)
|
| 436 |
+
|
| 437 |
+
if len(culprits) != 0 and not self.settings.find_all:
|
| 438 |
+
return culprits
|
| 439 |
+
|
| 440 |
+
culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx)
|
| 441 |
+
|
| 442 |
+
if len(culprits) == 0:
|
| 443 |
+
report.append(
|
| 444 |
+
f"Further split and lowering found no errors. "
|
| 445 |
+
f"Unable to minimize the submodule with list of nodes: {nodes}"
|
| 446 |
+
)
|
| 447 |
+
self.print_report(report)
|
| 448 |
+
|
| 449 |
+
return culprits
|
| 450 |
+
else:
|
| 451 |
+
report.append("No discrepancy found.")
|
| 452 |
+
self.print_report(report)
|
| 453 |
+
return set()
|
| 454 |
+
|
| 455 |
+
def _binary_traverse(self, nodes: NodeList) -> NodeSet:
|
| 456 |
+
"""
|
| 457 |
+
Binary search on `nodes` for culprit.
|
| 458 |
+
"""
|
| 459 |
+
return self._binary_search_impl(nodes, 0, len(nodes))
|
| 460 |
+
|
| 461 |
+
def _sequential_traverse(self, nodes: NodeList) -> NodeSet:
|
| 462 |
+
"""
|
| 463 |
+
Traverse `nodes` one by one and determine if any of them is a culprit.
|
| 464 |
+
"""
|
| 465 |
+
culprits: NodeSet = set()
|
| 466 |
+
|
| 467 |
+
for node in nodes:
|
| 468 |
+
report: List[str] = []
|
| 469 |
+
self.reports.append(report)
|
| 470 |
+
self.iteration += 1
|
| 471 |
+
report.append(f"Sequential traverse iteration {self.iteration}.")
|
| 472 |
+
report.append(f"Visit node: {node.name}")
|
| 473 |
+
|
| 474 |
+
_LOGGER.info("Visit node: %s", node.name)
|
| 475 |
+
cur_nodes: NodeSet = {node}
|
| 476 |
+
|
| 477 |
+
if node in self.fusions:
|
| 478 |
+
cur_nodes = self.fusions[node]
|
| 479 |
+
|
| 480 |
+
try:
|
| 481 |
+
split_module, submod_name = self._build_submodule(cur_nodes)
|
| 482 |
+
self._run_and_compare(split_module, submod_name, [node.name])
|
| 483 |
+
self.print_report(report)
|
| 484 |
+
except (FxNetMinimizerResultMismatchError):
|
| 485 |
+
culprits.add(node)
|
| 486 |
+
report.append(f"Found culprit from numeric error: {node}")
|
| 487 |
+
self.print_report(report)
|
| 488 |
+
if not self.settings.find_all:
|
| 489 |
+
return culprits
|
| 490 |
+
except (FxNetMinimizerRunFuncError):
|
| 491 |
+
culprits.update(cur_nodes)
|
| 492 |
+
report.append(f"Found culprit from run error: {node}")
|
| 493 |
+
self.print_report(report)
|
| 494 |
+
if not self.settings.find_all:
|
| 495 |
+
return culprits
|
| 496 |
+
|
| 497 |
+
return culprits
|
| 498 |
+
|
| 499 |
+
def _defined_traverse(self, nodes: NodeList) -> NodeSet:
|
| 500 |
+
"""
|
| 501 |
+
run user defined `nodes` and determine if it is a culprit.
|
| 502 |
+
"""
|
| 503 |
+
culprits: NodeSet = set()
|
| 504 |
+
|
| 505 |
+
first_node_name = nodes[0].name
|
| 506 |
+
output_node_name = nodes[-1].name
|
| 507 |
+
report = [f"Defined graph from {first_node_name} to {output_node_name}"]
|
| 508 |
+
cur_nodes: NodeSet = set(nodes)
|
| 509 |
+
try:
|
| 510 |
+
split_module, submod_name = self._build_submodule(cur_nodes)
|
| 511 |
+
self._run_and_compare(split_module, submod_name, [output_node_name])
|
| 512 |
+
self.print_report(report)
|
| 513 |
+
except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
|
| 514 |
+
report.append(f"Found culprit {cur_nodes}")
|
| 515 |
+
self.print_report(report)
|
| 516 |
+
return culprits
|
| 517 |
+
|
| 518 |
+
return culprits
|
| 519 |
+
|
| 520 |
+
def _accumulate_traverse(self, nodes: NodeList) -> NodeSet:
|
| 521 |
+
culprits: NodeSet = set()
|
| 522 |
+
nodes_to_run: NodeSet = set()
|
| 523 |
+
|
| 524 |
+
# find_all is not supported for accumulate traversal because all the
|
| 525 |
+
# ops run on NNPI. So we return after the first op that raises error.
|
| 526 |
+
if self.settings.find_all:
|
| 527 |
+
print("'Find All' mode is not supported in accumulate traversal.")
|
| 528 |
+
return culprits
|
| 529 |
+
|
| 530 |
+
for node in nodes:
|
| 531 |
+
report: List[str] = []
|
| 532 |
+
self.reports.append(report)
|
| 533 |
+
self.iteration += 1
|
| 534 |
+
report.append(f"Accumulate traverse iteration {self.iteration}.")
|
| 535 |
+
|
| 536 |
+
nodes_to_run.add(node)
|
| 537 |
+
|
| 538 |
+
node_name = node.name
|
| 539 |
+
if node_name is not None and isinstance(node_name, tuple):
|
| 540 |
+
node_name = node_name[0]
|
| 541 |
+
assert node_name is not None and isinstance(
|
| 542 |
+
node_name, str
|
| 543 |
+
), f"minimize: node_name: {node_name}"
|
| 544 |
+
|
| 545 |
+
report.append(f"Add node: {node_name}")
|
| 546 |
+
|
| 547 |
+
try:
|
| 548 |
+
split_module, submod_name = self._build_submodule(nodes_to_run)
|
| 549 |
+
self._run_and_compare(split_module, submod_name, [node_name])
|
| 550 |
+
self.print_report(report)
|
| 551 |
+
except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
|
| 552 |
+
culprits.add(node)
|
| 553 |
+
report.append(f"Found culprit {node}")
|
| 554 |
+
self.print_report(report)
|
| 555 |
+
return culprits
|
| 556 |
+
|
| 557 |
+
return culprits
|
| 558 |
+
|
| 559 |
+
    def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet:
        """
        Build and check one contiguous block of nodes (the slice
        ``all_nodes[start_idx:end_idx]``) as a single submodule.

        Returns the full candidate node set (the slice plus any fusion
        groups) when the block fails, or an empty set when no discrepancy
        is found.
        """
        culprits: NodeSet = set()
        nodes: NodeList = all_nodes[start_idx:end_idx]

        report: List[str] = []
        self.reports.append(report)
        self.iteration += 1
        report.append(f" Nodes block {self.iteration}.")
        report.append(
            f"From node index {start_idx} to {end_idx-1}. "
            f"Size of the interested node list is {len(nodes)}"
        )

        # Include every fusion group that overlaps the block, so fused ops
        # are always split out together.
        cur_nodes: NodeSet = set(nodes)

        for node in nodes:
            if node in self.fusions:
                cur_nodes.update(self.fusions[node])

        try:
            split_module, submod_name = self._build_submodule(cur_nodes)
            # Empty output-name list: compare using the default outputs.
            self._run_and_compare(split_module, submod_name, [])
        except (FxNetMinimizerResultMismatchError):
            culprits.update(cur_nodes)
            report.append(f"Found culprit from numeric error: {cur_nodes}")
            self.print_report(report)
            return culprits
        except (FxNetMinimizerRunFuncError):
            culprits.update(cur_nodes)
            # NOTE(review): `node` here is the last loop variable from the
            # fusion scan above, not necessarily the failing node — confirm.
            report.append(f"Found culprit from run error: {node}")
            self.print_report(report)
            return culprits
        else:
            report.append("No discrepancy found.")
            self.print_report(report)
            return set()
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet:
|
| 601 |
+
"""
|
| 602 |
+
Skip certain nodes in graph based on settings
|
| 603 |
+
"""
|
| 604 |
+
start_idx = 0
|
| 605 |
+
num_nodes = len(all_nodes)
|
| 606 |
+
idx = 0
|
| 607 |
+
culprits = set()
|
| 608 |
+
while idx < num_nodes:
|
| 609 |
+
node = all_nodes[idx]
|
| 610 |
+
if (node.name in skip_nodes): # skip the node
|
| 611 |
+
if idx > start_idx:
|
| 612 |
+
culprits = self._skip_traverse_impl(all_nodes, start_idx, idx)
|
| 613 |
+
start_idx = idx + 1
|
| 614 |
+
elif idx == num_nodes - 1 and start_idx <= idx: # last node
|
| 615 |
+
culprits = self._skip_traverse_impl(all_nodes, start_idx, idx + 1)
|
| 616 |
+
idx += 1
|
| 617 |
+
|
| 618 |
+
return culprits
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList:
|
| 623 |
+
"""
|
| 624 |
+
Collect nodes in the model that between nodes with name of `start` and `end`.
|
| 625 |
+
These two nodes are also included.
|
| 626 |
+
"""
|
| 627 |
+
nodes: NodeList = []
|
| 628 |
+
add_node = start is None
|
| 629 |
+
|
| 630 |
+
for node in self.module.graph.nodes:
|
| 631 |
+
if node.op not in CALLABLE_NODE_OPS:
|
| 632 |
+
continue
|
| 633 |
+
|
| 634 |
+
if node.name == start:
|
| 635 |
+
add_node = True
|
| 636 |
+
|
| 637 |
+
if add_node:
|
| 638 |
+
nodes.append(node)
|
| 639 |
+
|
| 640 |
+
if node.name == end:
|
| 641 |
+
break
|
| 642 |
+
|
| 643 |
+
return nodes
|
| 644 |
+
|
| 645 |
+
def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None):
|
| 646 |
+
"""
|
| 647 |
+
Run part of the model from `start` node to `end` node. If `start` is None
|
| 648 |
+
then we start from the beginning of the model. If `end` is None then we
|
| 649 |
+
stop at the end of the model.
|
| 650 |
+
|
| 651 |
+
Args:
|
| 652 |
+
start: The name of the node which is the first node of the submodule
|
| 653 |
+
we want to run. If set to None, then we'll start with the first
|
| 654 |
+
node of the model.
|
| 655 |
+
end: The name of the node which is the last node of the submodule we
|
| 656 |
+
want to run. If set to None, we'll end with the last node of the
|
| 657 |
+
model.
|
| 658 |
+
"""
|
| 659 |
+
nodes = self._collect_nodes(start, end)
|
| 660 |
+
cur_nodes = set(nodes)
|
| 661 |
+
|
| 662 |
+
for node in nodes:
|
| 663 |
+
if node in self.fusions:
|
| 664 |
+
cur_nodes.update(self.fusions[node])
|
| 665 |
+
|
| 666 |
+
output_names = []
|
| 667 |
+
if self.settings.return_intermediate:
|
| 668 |
+
output_names = [node.name for node in nodes]
|
| 669 |
+
|
| 670 |
+
try:
|
| 671 |
+
split_module, submod_name = self._build_submodule(cur_nodes)
|
| 672 |
+
self._run_and_compare(split_module, submod_name, output_names)
|
| 673 |
+
except (
|
| 674 |
+
FxNetMinimizerRunFuncError,
|
| 675 |
+
FxNetMinimizerResultMismatchError,
|
| 676 |
+
) as e:
|
| 677 |
+
print(e)
|
| 678 |
+
|
| 679 |
+
def print_report(self, report: List[str]):
|
| 680 |
+
for i in range(len(report)):
|
| 681 |
+
if i > 0:
|
| 682 |
+
print(" . " + report[i])
|
| 683 |
+
else:
|
| 684 |
+
print(report[i])
|
| 685 |
+
|
| 686 |
+
def print_reports(self):
|
| 687 |
+
for report in self.reports:
|
| 688 |
+
self.print_report(report)
|
| 689 |
+
|
| 690 |
+
def minimize(
|
| 691 |
+
self, start: Optional[str] = None, end: Optional[str] = None, skip_nodes: Optional[List] = None,
|
| 692 |
+
) -> NodeSet:
|
| 693 |
+
"""
|
| 694 |
+
Minimizing the model from node with name `start` to node with name `end` base
|
| 695 |
+
on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or
|
| 696 |
+
FxNetMinimizerResultMismatchError errors.
|
| 697 |
+
|
| 698 |
+
Args:
|
| 699 |
+
start: The name of the node where we want to start minimizing. If set
|
| 700 |
+
to None, then we'll start with the first node of the model.
|
| 701 |
+
end: The name of the node where we want to terminate minimizing. If
|
| 702 |
+
set to None, we'll end with the last node of the model.
|
| 703 |
+
|
| 704 |
+
Returns:
|
| 705 |
+
nodes: A list of nodes that causes FxNetMinimizerRunFuncError or
|
| 706 |
+
FxNetMinimizerResultMismatchError errors during minimizing.
|
| 707 |
+
"""
|
| 708 |
+
|
| 709 |
+
print(self.settings)
|
| 710 |
+
print(self.module.graph)
|
| 711 |
+
|
| 712 |
+
nodes = self._collect_nodes(start, end)
|
| 713 |
+
|
| 714 |
+
if self.settings.traverse_method == "sequential":
|
| 715 |
+
return self._sequential_traverse(nodes)
|
| 716 |
+
|
| 717 |
+
if self.settings.traverse_method == "binary":
|
| 718 |
+
return self._binary_traverse(nodes)
|
| 719 |
+
|
| 720 |
+
if self.settings.traverse_method == "accumulate":
|
| 721 |
+
return self._accumulate_traverse(nodes)
|
| 722 |
+
|
| 723 |
+
if self.settings.traverse_method == "skip":
|
| 724 |
+
if (skip_nodes is None):
|
| 725 |
+
raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.")
|
| 726 |
+
return self._skip_traverse(nodes, skip_nodes)
|
| 727 |
+
|
| 728 |
+
if self.settings.traverse_method == "defined":
|
| 729 |
+
return self._defined_traverse(nodes)
|
| 730 |
+
|
| 731 |
+
raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!")
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/pass_manager.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import wraps
|
| 2 |
+
from inspect import unwrap
|
| 3 |
+
from typing import Callable, List, Optional
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"PassManager",
|
| 10 |
+
"inplace_wrapper",
|
| 11 |
+
"log_hook",
|
| 12 |
+
"loop_pass",
|
| 13 |
+
"this_before_that_pass_constraint",
|
| 14 |
+
"these_before_those_pass_constraint",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
# for callables which modify object inplace and return something other than
|
| 18 |
+
# the object on which they act
|
| 19 |
+
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    Args:
        fn (Callable[Object, Any]): pass that mutates its argument in place.

    Returns:
        wrapped_fn (Callable[Object, Object]): calls `fn` and returns the
        (mutated) input object; `fn`'s own return value is discarded.
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # The pass's return value is intentionally ignored (was previously
        # bound to an unused local); the mutated object is the result.
        fn(gm)
        return gm

    return wrapped_fn
|
| 37 |
+
|
| 38 |
+
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Logs callable output.

    This is useful for logging output of passes. Note inplace_wrapper replaces
    the pass output with the modified object. If we want to log the original
    output, apply this wrapper before inplace_wrapper.

    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2]): pass whose return value should be logged.
        level: logging level (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2]): behaves like `fn`, logging its
        return value as a side effect.
    """

    @wraps(fn)
    def wrapped_fn(gm):
        result = fn(gm)
        # Same module-level logger instance; getLogger returns the singleton.
        logging.getLogger(__name__).log(
            level, "Ran pass %s\t Return value: %s", fn, result
        )
        return result

    return wrapped_fn
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
    """
    Convenience wrapper for passes which need to be applied multiple times.

    Exactly one of `n_iter`or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to loop pass
        predicate (Callable[Object, bool], optional): loop while this
            predicate holds on the current output
    """
    assert (n_iter is not None) ^ (
        predicate is not None
    ), "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        result = source
        if n_iter is not None and n_iter > 0:
            remaining = n_iter
            while remaining:
                result = base_pass(result)
                remaining -= 1
        elif predicate is not None:
            while predicate(result):
                result = base_pass(result)
        else:
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )
        return result

    return new_pass
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# Pass Schedule Constraints:
|
| 115 |
+
#
|
| 116 |
+
# Implemented as 'depends on' operators. A constraint is satisfied iff a list
|
| 117 |
+
# has a valid partial ordering according to this comparison operator.
|
| 118 |
+
def _validate_pass_schedule_constraint(
|
| 119 |
+
constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
|
| 120 |
+
):
|
| 121 |
+
for i, a in enumerate(passes):
|
| 122 |
+
for j, b in enumerate(passes[i + 1 :]):
|
| 123 |
+
if constraint(a, b):
|
| 124 |
+
continue
|
| 125 |
+
raise RuntimeError(
|
| 126 |
+
f"pass schedule constraint violated. Expected {a} before {b}"
|
| 127 |
+
f" but found {a} at index {i} and {b} at index{j} in pass"
|
| 128 |
+
f" list."
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Build a 'depends on' predicate enforcing that `this` must occur
    before `that` in a pass schedule.
    """

    def depends_on(a: Callable, b: Callable):
        # The only forbidden ordering is `that` appearing before `this`.
        return not (a == that and b == this)

    return depends_on
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Build a 'depends on' predicate enforcing that `these` must occur before
    `those`, comparing the passes after `inspect.unwrap`-ing any wrappers.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]
    """

    def depends_on(a: Callable, b: Callable):
        # Unwrap decorators (e.g. loop_pass wrappers) before comparing.
        return not (unwrap(a) == those and unwrap(b) == these)

    return depends_on
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object
        constraint (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B) and returns
            True if A depends on B and False otherwise. See implementation of
            `this_before_that_pass_constraint` for example.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # Cached flag: True once the current schedule has passed validation.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        # TODO(alexbeloi): add constraint management/validation
        return PassManager(passes)

    def add_pass(self, _pass: Callable):
        """Append a pass and invalidate the cached schedule check."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Append a constraint and invalidate the cached schedule check."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: List[str]):
        """Drop every pass whose __name__ appears in `_passes`."""
        if _passes is None:
            return
        self.passes = [p for p in self.passes if p.__name__ not in _passes]
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """Swap in `_replacement` wherever a pass shares `_target`'s __name__."""
        self.passes = [
            _replacement if p.__name__ == _target.__name__ else p
            for p in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        self.validate()
        result = source
        for stage in self.passes:
            result = stage(result)
        return result
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_module.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.fx._compatibility import compatibility
|
| 8 |
+
from torch.fx.graph_module import GraphModule
|
| 9 |
+
from torch.fx.node import Node
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
import sympy # noqa: F401
|
| 13 |
+
|
| 14 |
+
__all__ = ["Partition", "split_module"]
|
| 15 |
+
_LOGGER = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
@compatibility(is_backward_compatible=True)
class Partition:
    """Bookkeeping for one partition produced by ``split_module``: the nodes
    assigned to it, its inputs/outputs, and its inter-partition dependencies."""

    def __init__(self, name: str):
        # Partition identifier (stringified partition number from split_callback).
        self.name: str = name
        # Attribute name of the generated submodule on the parent GraphModule.
        self.submod_name = f"submod_{name}"
        # Names of the original graph nodes placed in this partition.
        self.node_names: List[str] = []
        # Ordered-set semantics via dict keys: names of values consumed from
        # outside the partition.
        self.inputs: Dict[str, None] = {}
        # Names of values this partition produces for other partitions.
        self.outputs: Dict[str, None] = {}
        # Names of partitions this one depends on.
        self.dependencies: Dict[str, None] = {}
        # Names of partitions that depend on this one.
        self.dependents: Dict[str, None] = {}
        # The new FX graph being assembled for this partition's submodule.
        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
        # Mapping from original nodes to their copies in `self.graph`.
        self.environment: Dict[Node, Node] = {}
        # Targets (params/buffers/attrs) the submodule needs, keyed by name.
        self.targets: Dict[str, Any] = {}

    def __repr__(self) -> str:
        """Return a multi-line summary of the partition's bookkeeping state."""
        return (
            f"name: {self.name},\n"
            f" nodes: {self.node_names},\n"
            f" inputs: {self.inputs},\n"
            f" outputs: {self.outputs},\n"
            f" partitions depended on: {self.dependencies},\n"
            f" partition dependents: {self.dependents}"
        )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Creates subgraphs out of main graph
|
| 43 |
+
@compatibility(is_backward_compatible=True)
|
| 44 |
+
def split_module(
|
| 45 |
+
m: GraphModule,
|
| 46 |
+
root_m: torch.nn.Module,
|
| 47 |
+
split_callback: Callable[[Node], int],
|
| 48 |
+
qualname_map: Optional[Dict[str, str]] = None,
|
| 49 |
+
keep_original_order: Optional[bool] = False,
|
| 50 |
+
keep_original_node_name: Optional[bool] = False,
|
| 51 |
+
):
|
| 52 |
+
"""
|
| 53 |
+
Creates subgraphs out of main graph
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
m (GraphModule): Graph module to split
|
| 57 |
+
root_m (torch.nn.Module): root nn module. Not currently used. Included
|
| 58 |
+
because the root nn module is usually transformed via
|
| 59 |
+
torch.fx._symbolic_trace.symbolic_trace (see example below)
|
| 60 |
+
split_callback (Callable[[Node], int]): Callable function
|
| 61 |
+
that maps a given Node instance to a numeric partition identifier.
|
| 62 |
+
split_module will use this function as the policy for which operations
|
| 63 |
+
appear in which partitions in the output Module.
|
| 64 |
+
qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
|
| 65 |
+
mapping from new target names in the module after split to old target
|
| 66 |
+
names in the original module.
|
| 67 |
+
keep_original_order: Optional[bool]: keep the original order of the GraphModule
|
| 68 |
+
or use the Topological order of the new constructed GraphModule
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
GraphModule: the module after split.
|
| 73 |
+
|
| 74 |
+
Example:
|
| 75 |
+
|
| 76 |
+
This is a sample setup:
|
| 77 |
+
|
| 78 |
+
import torch
|
| 79 |
+
from torch.fx.symbolic_trace import symbolic_trace
|
| 80 |
+
from torch.fx.graph_module import GraphModule
|
| 81 |
+
from torch.fx.node import Node
|
| 82 |
+
from torch.fx.passes.split_module import split_module
|
| 83 |
+
|
| 84 |
+
class MyModule(torch.nn.Module):
|
| 85 |
+
def __init__(self):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.param = torch.nn.Parameter(torch.rand(3, 4))
|
| 88 |
+
self.linear = torch.nn.Linear(4, 5)
|
| 89 |
+
|
| 90 |
+
def forward(self, x, y):
|
| 91 |
+
z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
|
| 92 |
+
w = self.linear(y).clamp(min=0.0, max=1.0)
|
| 93 |
+
return z + w
|
| 94 |
+
|
| 95 |
+
# symbolically trace model
|
| 96 |
+
my_module = MyModule()
|
| 97 |
+
my_module_traced = symbolic_trace(my_module)
|
| 98 |
+
|
| 99 |
+
# random mod partitioning
|
| 100 |
+
partition_counter = 0
|
| 101 |
+
NPARTITIONS = 3
|
| 102 |
+
|
| 103 |
+
def mod_partition(node: Node):
|
| 104 |
+
global partition_counter
|
| 105 |
+
partition = partition_counter % NPARTITIONS
|
| 106 |
+
partition_counter = (partition_counter + 1) % NPARTITIONS
|
| 107 |
+
return partition
|
| 108 |
+
|
| 109 |
+
# split module in module with submodules
|
| 110 |
+
module_with_submodules = split_module(
|
| 111 |
+
my_module_traced, my_module, mod_partition
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
Output looks like this. Original graph is broken into partitions
|
| 115 |
+
|
| 116 |
+
> print(module_with_submodules)
|
| 117 |
+
GraphModule(
|
| 118 |
+
(submod_0): GraphModule(
|
| 119 |
+
(linear): Linear(in_features=4, out_features=5, bias=True)
|
| 120 |
+
)
|
| 121 |
+
(submod_1): GraphModule(
|
| 122 |
+
(linear): Linear(in_features=4, out_features=5, bias=True)
|
| 123 |
+
)
|
| 124 |
+
(submod_2): GraphModule()
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
def forward(self, x, y):
|
| 128 |
+
param = self.param
|
| 129 |
+
submod_0 = self.submod_0(x, param, y); x = param = y = None
|
| 130 |
+
getitem = submod_0[0]
|
| 131 |
+
getitem_1 = submod_0[1]; submod_0 = None
|
| 132 |
+
submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None
|
| 133 |
+
getitem_2 = submod_1[0]
|
| 134 |
+
getitem_3 = submod_1[1]; submod_1 = None
|
| 135 |
+
submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
|
| 136 |
+
return submod_2
|
| 137 |
+
|
| 138 |
+
Output of split module is the same as output of input traced module.
|
| 139 |
+
This is an example within a test setting:
|
| 140 |
+
|
| 141 |
+
> orig_out = my_module_traced(x, y)
|
| 142 |
+
> submodules_out = module_with_submodules(x, y)
|
| 143 |
+
> self.assertEqual(orig_out, submodules_out)
|
| 144 |
+
True
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
def construct_graph(
|
| 148 |
+
node: Node,
|
| 149 |
+
base_mod_env: Dict[str, Node],
|
| 150 |
+
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule],
|
| 151 |
+
):
|
| 152 |
+
if node.op == "placeholder":
|
| 153 |
+
default_value = (
|
| 154 |
+
node.args[0] if len(node.args) > 0 else inspect.Signature.empty
|
| 155 |
+
)
|
| 156 |
+
if keep_original_node_name:
|
| 157 |
+
args = () if default_value is inspect.Signature.empty else (default_value,)
|
| 158 |
+
base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)
|
| 159 |
+
else:
|
| 160 |
+
base_mod_env[node.name] = base_mod_graph.placeholder(
|
| 161 |
+
node.target, type_expr=node.type, default_value=default_value
|
| 162 |
+
)
|
| 163 |
+
base_mod_env[node.name].meta = node.meta.copy()
|
| 164 |
+
elif node.op == "get_attr":
|
| 165 |
+
base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
|
| 166 |
+
base_mod_env[node.name].meta = node.meta.copy()
|
| 167 |
+
attr_val = m
|
| 168 |
+
for atom in node.target.split("."): # type: ignore[union-attr]
|
| 169 |
+
if not hasattr(attr_val, atom):
|
| 170 |
+
raise AttributeError(f"Node target {node.target} not found!")
|
| 171 |
+
attr_val = getattr(attr_val, atom)
|
| 172 |
+
base_mod_attrs[node.target] = attr_val # type: ignore[index]
|
| 173 |
+
return base_mod_env, base_mod_attrs
|
| 174 |
+
|
| 175 |
+
partitions: Dict[str, Partition] = {}
|
| 176 |
+
orig_nodes: Dict[str, Node] = {}
|
| 177 |
+
symbol_to_node: Dict["sympy.Symbol", Node] = {}
|
| 178 |
+
|
| 179 |
+
def record_cross_partition_use(
|
| 180 |
+
def_node: Node, use_node: Optional[Node]
|
| 181 |
+
): # noqa: B950
|
| 182 |
+
from torch.fx.experimental.symbolic_shapes import free_symbols
|
| 183 |
+
|
| 184 |
+
defined = getattr(def_node, "_fx_partition", None)
|
| 185 |
+
used = getattr(use_node, "_fx_partition", None)
|
| 186 |
+
if defined != used:
|
| 187 |
+
if defined is not None:
|
| 188 |
+
def_partition = partitions[defined]
|
| 189 |
+
def_partition.outputs.setdefault(def_node.name)
|
| 190 |
+
if used is not None:
|
| 191 |
+
def_partition.dependents.setdefault(used)
|
| 192 |
+
|
| 193 |
+
if used is not None:
|
| 194 |
+
use_partition = partitions[used]
|
| 195 |
+
use_partition.inputs.setdefault(def_node.name)
|
| 196 |
+
if (def_val := def_node.meta.get("example_value")) is not None:
|
| 197 |
+
for s in sorted(free_symbols(def_val), key=str):
|
| 198 |
+
use_partition.inputs.setdefault(symbol_to_node[s].name)
|
| 199 |
+
if defined is not None:
|
| 200 |
+
use_partition.dependencies.setdefault(defined)
|
| 201 |
+
|
| 202 |
+
    def instantiate_node_partition_mapping(node):
        """Assign `node` to the partition selected by `split_callback`,
        creating the Partition object on first use."""
        partition_name = str(split_callback(node))

        # add node to partitions
        partition = partitions.get(partition_name)
        if partition is None:
            partitions[partition_name] = partition = Partition(partition_name)

        partition.node_names.append(node.name)
        # Tag the original node with its partition name so later passes can
        # recover the assignment via getattr(node, "_fx_partition", None).
        node._fx_partition = partition_name
|
| 212 |
+
|
| 213 |
+
# Global State Nodes are nodes which by their global state effects,
|
| 214 |
+
# "taint" all downstream nodes while they are active.
|
| 215 |
+
GLOBAL_STATE_NODES = [
|
| 216 |
+
torch.amp._enter_autocast,
|
| 217 |
+
torch.amp._exit_autocast,
|
| 218 |
+
torch._C._set_grad_enabled
|
| 219 |
+
]
|
| 220 |
+
|
| 221 |
+
# For grad regions:
|
| 222 |
+
# ------------------------
|
| 223 |
+
# 1. first region: we do nothing
|
| 224 |
+
# 2. subsequent regions: we insert the set_grad at the beginning
|
| 225 |
+
grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()
|
| 226 |
+
|
| 227 |
+
# For autocast regions:
|
| 228 |
+
# ------------------------
|
| 229 |
+
# 1. first region: we will only insert the _exit at the end
|
| 230 |
+
# 2. intermediate regions: we will insert both the
|
| 231 |
+
# _enter at the beginning and _exit at the end
|
| 232 |
+
# 3. last region: we will only insert _enter at the beginning
|
| 233 |
+
# We will do so in the order in which the autocasts were instantiated.
|
| 234 |
+
autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
|
| 235 |
+
autocast_exits: Dict[Node, Optional[Node]] = {}
|
| 236 |
+
|
| 237 |
+
active_grad = None
|
| 238 |
+
active_autocasts = set()
|
| 239 |
+
|
| 240 |
+
import sympy # noqa: F811
|
| 241 |
+
|
| 242 |
+
for node in m.graph.nodes:
|
| 243 |
+
if node.op in ["placeholder", "get_attr", "output"]:
|
| 244 |
+
if (
|
| 245 |
+
node.op == "placeholder" and
|
| 246 |
+
(val := node.meta.get("example_value")) is not None and
|
| 247 |
+
isinstance(val, torch.SymInt) and
|
| 248 |
+
isinstance(val.node.expr, sympy.Symbol)
|
| 249 |
+
):
|
| 250 |
+
symbol_to_node[val.node.expr] = node
|
| 251 |
+
continue
|
| 252 |
+
|
| 253 |
+
instantiate_node_partition_mapping(node)
|
| 254 |
+
|
| 255 |
+
if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
|
| 256 |
+
if node.target == torch._C._set_grad_enabled:
|
| 257 |
+
assert len(node.args) == 1
|
| 258 |
+
assert isinstance(node.args[0], bool)
|
| 259 |
+
active_grad = node
|
| 260 |
+
grad_regions[active_grad] = set({split_callback(node)})
|
| 261 |
+
elif node.target == torch.amp._enter_autocast:
|
| 262 |
+
# Should all be python constants
|
| 263 |
+
assert all(not isinstance(arg, Node) for arg in node.args)
|
| 264 |
+
active_autocasts.add(node)
|
| 265 |
+
autocast_regions[node] = set({split_callback(node)})
|
| 266 |
+
autocast_exits[node] = None
|
| 267 |
+
elif node.target == torch.amp._exit_autocast:
|
| 268 |
+
assert len(node.args) == 1
|
| 269 |
+
autocast_regions[node.args[0]].add(split_callback(node))
|
| 270 |
+
active_autocasts.remove(node.args[0])
|
| 271 |
+
autocast_exits[node.args[0]] = node
|
| 272 |
+
|
| 273 |
+
if active_grad is not None:
|
| 274 |
+
grad_regions[active_grad].add(split_callback(node))
|
| 275 |
+
|
| 276 |
+
for a in active_autocasts:
|
| 277 |
+
autocast_regions[a].add(split_callback(node))
|
| 278 |
+
|
| 279 |
+
assert all(v is not None for v in autocast_exits.values()), "autocast must exit"
|
| 280 |
+
|
| 281 |
+
autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
|
| 282 |
+
grad_regions = {k: sorted(v) for k, v in grad_regions.items()}
|
| 283 |
+
|
| 284 |
+
if _LOGGER.isEnabledFor(logging.DEBUG):
|
| 285 |
+
_LOGGER.debug("autocast_regions: %s", autocast_regions)
|
| 286 |
+
_LOGGER.debug("grad_regions: %s", grad_regions)
|
| 287 |
+
|
| 288 |
+
assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)
|
| 289 |
+
|
| 290 |
+
# split nodes into partitions
|
| 291 |
+
highest_partition = -1
|
| 292 |
+
for node in m.graph.nodes:
|
| 293 |
+
orig_nodes[node.name] = node
|
| 294 |
+
|
| 295 |
+
# TODO currently placeholders/parameters aren't put into random partitions,
|
| 296 |
+
# rather they're added to the graphs where they are used down below
|
| 297 |
+
if node.op in ["placeholder", "get_attr"]:
|
| 298 |
+
continue
|
| 299 |
+
if node.op == "output":
|
| 300 |
+
torch.fx.graph.map_arg(
|
| 301 |
+
node.args[0], lambda n: record_cross_partition_use(n, None)
|
| 302 |
+
)
|
| 303 |
+
continue
|
| 304 |
+
|
| 305 |
+
if assert_monotonically_increasing:
|
| 306 |
+
pid = split_callback(node)
|
| 307 |
+
assert highest_partition <= pid, \
|
| 308 |
+
("autocast or set_grad_enabled require monotonically increasing partitions:"
|
| 309 |
+
f"highest: {highest_partition}, this node's: {pid}")
|
| 310 |
+
highest_partition = pid
|
| 311 |
+
|
| 312 |
+
# do not capture cross-partition dependencies for global state nodes as they will be
|
| 313 |
+
# self-contained - their setup and unwind will be isolated to each partition submodule.
|
| 314 |
+
if node.target not in GLOBAL_STATE_NODES:
|
| 315 |
+
torch.fx.graph.map_arg(
|
| 316 |
+
node.args, lambda def_node: record_cross_partition_use(def_node, node)
|
| 317 |
+
)
|
| 318 |
+
torch.fx.graph.map_arg(
|
| 319 |
+
node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
|
| 320 |
+
) # noqa: B950
|
| 321 |
+
|
| 322 |
+
original_partition_order = list(partitions.keys())
|
| 323 |
+
# find partitions with no dependencies
|
| 324 |
+
root_partitions: List[str] = []
|
| 325 |
+
for partition_name, partition in partitions.items():
|
| 326 |
+
if not len(partition.dependencies):
|
| 327 |
+
root_partitions.append(partition_name)
|
| 328 |
+
|
| 329 |
+
# check partitions for circular dependencies and create topological partition ordering
|
| 330 |
+
sorted_partitions: List[str] = []
|
| 331 |
+
while root_partitions:
|
| 332 |
+
root_partition = root_partitions.pop()
|
| 333 |
+
sorted_partitions.append(root_partition)
|
| 334 |
+
for dependent in partitions[root_partition].dependents:
|
| 335 |
+
partitions[dependent].dependencies.pop(root_partition)
|
| 336 |
+
if not partitions[dependent].dependencies:
|
| 337 |
+
root_partitions.append(dependent)
|
| 338 |
+
if len(sorted_partitions) != len(partitions):
|
| 339 |
+
raise RuntimeError("cycle exists between partitions!")
|
| 340 |
+
|
| 341 |
+
# Enter prelude
|
| 342 |
+
for regions_mapping in [autocast_regions, grad_regions]:
|
| 343 |
+
for node, regions in regions_mapping.items():
|
| 344 |
+
assert len(regions) > 0
|
| 345 |
+
partitions[str(regions[0])].environment[node] = node
|
| 346 |
+
for r in regions[1:]:
|
| 347 |
+
partition = partitions[str(r)]
|
| 348 |
+
new_node = partition.graph.create_node(
|
| 349 |
+
op=node.op,
|
| 350 |
+
target=node.target,
|
| 351 |
+
args=tuple(arg for arg in node.args),
|
| 352 |
+
kwargs={},
|
| 353 |
+
type_expr=node.type,
|
| 354 |
+
)
|
| 355 |
+
new_node.meta = node.meta.copy() # is it really a good idea to copy this?
|
| 356 |
+
partition.environment[node] = new_node
|
| 357 |
+
|
| 358 |
+
# add placeholders to partition inputs
|
| 359 |
+
for partition_name in sorted_partitions:
|
| 360 |
+
partition = partitions[partition_name]
|
| 361 |
+
for inp in partition.inputs:
|
| 362 |
+
placeholder = partition.graph.placeholder(
|
| 363 |
+
inp,
|
| 364 |
+
type_expr=orig_nodes[inp].type,
|
| 365 |
+
)
|
| 366 |
+
placeholder.meta = orig_nodes[inp].meta.copy()
|
| 367 |
+
partition.environment[orig_nodes[inp]] = placeholder
|
| 368 |
+
|
| 369 |
+
# Transform nodes and collect targets for partition's submodule
|
| 370 |
+
for node in m.graph.nodes:
|
| 371 |
+
if hasattr(node, "_fx_partition"):
|
| 372 |
+
partition = partitions[node._fx_partition]
|
| 373 |
+
|
| 374 |
+
# swap out old graph nodes in kw/args with references to new nodes in this submodule
|
| 375 |
+
environment = partition.environment
|
| 376 |
+
gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
|
| 377 |
+
gathered_kwargs = torch.fx.graph.map_arg(
|
| 378 |
+
node.kwargs, lambda n: environment[n]
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
if node.op not in ["call_module", "get_attr"]:
|
| 382 |
+
target = node.target
|
| 383 |
+
else:
|
| 384 |
+
target_atoms = node.target.split(".")
|
| 385 |
+
target_attr = m
|
| 386 |
+
for atom in target_atoms:
|
| 387 |
+
if not hasattr(target_attr, atom):
|
| 388 |
+
raise AttributeError(f"Operator target {node.target} not found!")
|
| 389 |
+
target_attr = getattr(target_attr, atom)
|
| 390 |
+
# target = target_atoms[-1]
|
| 391 |
+
target = "_".join(target_atoms)
|
| 392 |
+
partition.targets[target] = target_attr
|
| 393 |
+
# Fill in the passed-in mapping from new qualname to old qualname
|
| 394 |
+
if qualname_map is not None:
|
| 395 |
+
# When creating the split module later, the submodules will have
|
| 396 |
+
# path prefix matching the corresponding partition's submod_name
|
| 397 |
+
qualname = f"{partition.submod_name}.{target}"
|
| 398 |
+
qualname_map[qualname] = node.target
|
| 399 |
+
|
| 400 |
+
assert isinstance(gathered_args, tuple)
|
| 401 |
+
assert isinstance(gathered_kwargs, dict)
|
| 402 |
+
name = node.name if keep_original_node_name else None
|
| 403 |
+
new_node = partition.graph.create_node(
|
| 404 |
+
op=node.op,
|
| 405 |
+
target=target,
|
| 406 |
+
args=gathered_args,
|
| 407 |
+
kwargs=gathered_kwargs,
|
| 408 |
+
type_expr=node.type,
|
| 409 |
+
name=name,
|
| 410 |
+
)
|
| 411 |
+
new_node.meta = node.meta.copy()
|
| 412 |
+
partition.environment[node] = new_node
|
| 413 |
+
|
| 414 |
+
# Exit epilogue
|
| 415 |
+
for regions_mapping in [autocast_regions]:
|
| 416 |
+
for node in reversed(regions_mapping):
|
| 417 |
+
regions = regions_mapping[node]
|
| 418 |
+
assert len(regions) > 0
|
| 419 |
+
for r in regions[:-1]:
|
| 420 |
+
partition = partitions[str(r)]
|
| 421 |
+
exit_node = autocast_exits[node]
|
| 422 |
+
assert exit_node is not None, "Missing exit node"
|
| 423 |
+
new_node = partition.graph.create_node(
|
| 424 |
+
op=exit_node.op,
|
| 425 |
+
target=exit_node.target,
|
| 426 |
+
args=(partition.environment[node],),
|
| 427 |
+
kwargs={},
|
| 428 |
+
type_expr=exit_node.type,
|
| 429 |
+
)
|
| 430 |
+
new_node.meta = exit_node.meta.copy() # is it really a good idea to copy this?
|
| 431 |
+
|
| 432 |
+
# original module environment dict mapping node names to nodes
|
| 433 |
+
orig_mod_env: Dict[str, Node] = {}
|
| 434 |
+
# Set up values to construct base module
|
| 435 |
+
base_mod_env: Dict[str, Node] = {}
|
| 436 |
+
base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
|
| 437 |
+
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
|
| 438 |
+
if not keep_original_order:
|
| 439 |
+
for node in m.graph.nodes:
|
| 440 |
+
base_mod_env, base_mod_attrs = construct_graph(
|
| 441 |
+
node, base_mod_env, base_mod_attrs
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
else:
|
| 445 |
+
# Go through the graph to construct the mapping dict
|
| 446 |
+
for node in m.graph.nodes:
|
| 447 |
+
orig_mod_env[node.name] = node
|
| 448 |
+
|
| 449 |
+
# Do some things iterating over the partitions in topological order again:
|
| 450 |
+
# 1) Finish off submodule Graphs by setting corresponding outputs
|
| 451 |
+
# 2) Construct GraphModules for each submodule
|
| 452 |
+
# 3) Construct the base graph by emitting calls to those submodules in
|
| 453 |
+
# topological order or original order specified by keep_original_order
|
| 454 |
+
|
| 455 |
+
construct_order_partitions = (
|
| 456 |
+
sorted_partitions if not keep_original_order else original_partition_order
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
already_constructed_attr_nodes = set()
|
| 460 |
+
for partition_name in construct_order_partitions:
|
| 461 |
+
partition = partitions[partition_name]
|
| 462 |
+
|
| 463 |
+
# Set correct output values
|
| 464 |
+
output_vals = tuple(
|
| 465 |
+
partition.environment[orig_nodes[name]] for name in partition.outputs
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
# skip output node generation if there are no output values
|
| 469 |
+
num_output_vals = len(output_vals)
|
| 470 |
+
if num_output_vals == 1:
|
| 471 |
+
partition.graph.output(output_vals[0])
|
| 472 |
+
elif num_output_vals > 1:
|
| 473 |
+
partition.graph.output(output_vals)
|
| 474 |
+
|
| 475 |
+
if keep_original_order:
|
| 476 |
+
# first get the attr nodes required by this partition
|
| 477 |
+
orig_mod_attr_nodes: List[Node] = [
|
| 478 |
+
orig_mod_env[key] for key in partition.inputs
|
| 479 |
+
]
|
| 480 |
+
# Construct GraphModule for this partition
|
| 481 |
+
for node in orig_mod_attr_nodes: # type: ignore[attr-defined]
|
| 482 |
+
if node in already_constructed_attr_nodes:
|
| 483 |
+
continue
|
| 484 |
+
base_mod_env, base_mod_attrs = construct_graph(
|
| 485 |
+
node, base_mod_env, base_mod_attrs
|
| 486 |
+
)
|
| 487 |
+
already_constructed_attr_nodes.add(node)
|
| 488 |
+
|
| 489 |
+
base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
|
| 490 |
+
partition.targets, partition.graph
|
| 491 |
+
) # noqa: B950
|
| 492 |
+
|
| 493 |
+
# Emit call in base graph to this submodule
|
| 494 |
+
output_val = base_mod_graph.call_module(
|
| 495 |
+
partition.submod_name,
|
| 496 |
+
tuple(base_mod_env[name] for name in partition.inputs),
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
num_outputs = len(partition.outputs)
|
| 500 |
+
if num_outputs > 1:
|
| 501 |
+
# Unpack multiple return values from submodule
|
| 502 |
+
output_val_proxy = torch.fx.proxy.Proxy(output_val)
|
| 503 |
+
for i, output_name in enumerate(partition.outputs):
|
| 504 |
+
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
|
| 505 |
+
elif num_outputs == 1:
|
| 506 |
+
base_mod_env[next(iter(partition.outputs))] = output_val
|
| 507 |
+
|
| 508 |
+
for node in m.graph.nodes:
|
| 509 |
+
if node.op == "output":
|
| 510 |
+
base_mod_graph.output(
|
| 511 |
+
torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
|
| 512 |
+
) # noqa: B950
|
| 513 |
+
|
| 514 |
+
return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/splitter_base.py
ADDED
|
@@ -0,0 +1,871 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.fx.passes.graph_manipulation import get_size_of_node
|
| 10 |
+
from torch.fx.node import map_arg
|
| 11 |
+
from torch.fx._compatibility import compatibility
|
| 12 |
+
|
| 13 |
+
from .operator_support import (
|
| 14 |
+
get_node_target,
|
| 15 |
+
OperatorSupportBase,
|
| 16 |
+
)
|
| 17 |
+
from .graph_drawer import FxGraphDrawer
|
| 18 |
+
from .shape_prop import ShapeProp
|
| 19 |
+
from .split_utils import split_by_tags
|
| 20 |
+
from .tools_common import (
|
| 21 |
+
FxNetAccFusionsFinder,
|
| 22 |
+
CALLABLE_NODE_OPS,
|
| 23 |
+
Tensors,
|
| 24 |
+
NodeList,
|
| 25 |
+
NodeSet,
|
| 26 |
+
is_node_output_tensor,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']
|
| 31 |
+
_LOGGER = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
DEFAULT_MIN_ACC_MODULE_SIZE = 1
|
| 34 |
+
DEFAULT_SKIP_FUSION = False
|
| 35 |
+
DEFAULT_ALLOW_NON_TENSOR = False
|
| 36 |
+
|
| 37 |
+
class _SplitterSettingBase:
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
|
| 41 |
+
skip_fusion=DEFAULT_SKIP_FUSION,
|
| 42 |
+
allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR
|
| 43 |
+
):
|
| 44 |
+
parser = argparse.ArgumentParser()
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--min-acc-module-size",
|
| 47 |
+
"--min_acc_module_size",
|
| 48 |
+
required=False,
|
| 49 |
+
type=int,
|
| 50 |
+
help="Minimum size limit of an accelerator subgraph.",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--skip-fusion",
|
| 54 |
+
"--skip_fusion",
|
| 55 |
+
default=False,
|
| 56 |
+
action="store_true",
|
| 57 |
+
help="If true then no fusion groups. Fusion group is used to "
|
| 58 |
+
"enforce no non-tensor data flow between submodules. If we don't "
|
| 59 |
+
"have this constrain, setting this to false is recommended as it "
|
| 60 |
+
"can reduce overhead.",
|
| 61 |
+
)
|
| 62 |
+
parser.add_argument(
|
| 63 |
+
"--allow-non-tensor",
|
| 64 |
+
"--allow_non_tensor",
|
| 65 |
+
default=False,
|
| 66 |
+
action="store_true",
|
| 67 |
+
help="For some backends non-tensor data flow between cpu and them "
|
| 68 |
+
"are not allowed. Therefore, if a node supported by accelerator but "
|
| 69 |
+
"it has non-tensor inputs or outputs to a cpu node we would want to "
|
| 70 |
+
"consider it as a cpu node during splitting. However, for some backends "
|
| 71 |
+
"we might not care about non-tensor data flow and we can set this option "
|
| 72 |
+
"to true to disable the functionality that prevent non-tensor data flow.",
|
| 73 |
+
)
|
| 74 |
+
args, unknown = parser.parse_known_args()
|
| 75 |
+
|
| 76 |
+
self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
|
| 77 |
+
self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
|
| 78 |
+
self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@compatibility(is_backward_compatible=False)
class FxNetAccNodesFinder:
    """
    Computes the set of nodes that can be supported on ACC, excluding nodes
    that have non-tensor input/output to cpu nodes, so that no non-tensor
    data ever flows between the accelerator backend and cpu.

    I.e. if we have a chain:

    ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1

    where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.

    This behavior can be turned off by passing allow_non_tensor=True.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        operator_support: OperatorSupportBase,
        allow_non_tensor: bool,
    ):
        self.module = module
        self.operator_support = operator_support
        self.allow_non_tensor = allow_non_tensor

    def reduce_acc_nodes_non_tensor_input_helper(
        self, cpu_worklist: NodeList
    ):
        """
        Transitively excludes nodes from ACC supported set.
        For every node in the worklist:
        - removes its downstream ACC nodes from ACC supported set,
        - if any downstream ACC node produces non-tensor output,
          then it gets added into the worklist.
        """
        while cpu_worklist:
            current = cpu_worklist.pop(0)
            for consumer in current.users:
                if consumer not in self.acc_nodes:
                    continue
                self.acc_nodes.remove(consumer)
                # Non-tensor producers taint their own downstream ACC nodes.
                if not is_node_output_tensor(consumer):
                    cpu_worklist.append(consumer)

    def reduce_acc_nodes_non_tensor_input(self):
        """
        Excludes nodes from ACC supported set that have direct
        upstream CPU nodes that produce non-tensor outputs.
        """
        seed_cpu_nodes: NodeList = [
            node
            for node in self.module.graph.nodes
            if node.op in CALLABLE_NODE_OPS
            and node not in self.acc_nodes
            and not is_node_output_tensor(node)
        ]
        self.reduce_acc_nodes_non_tensor_input_helper(seed_cpu_nodes)

    def reduce_acc_nodes_non_tensor_output(self):
        """
        Excludes nodes from ACC supported set that produce non-tensor
        outputs and have downstream CPU nodes.
        """
        # Iterate to a fixed point: demoting a node can expose new violations.
        while True:
            demoted: NodeList = [
                acc_node
                for acc_node in self.acc_nodes
                if not is_node_output_tensor(acc_node)
                and any(user not in self.acc_nodes for user in acc_node.users)
            ]
            if not demoted:
                break
            for node in demoted:
                self.acc_nodes.remove(node)
            self.reduce_acc_nodes_non_tensor_input_helper(demoted)

    def __call__(self) -> NodeSet:
        name_to_submodule = dict(self.module.named_modules())
        supported: NodeSet = set()
        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if self.operator_support.is_node_supported(name_to_submodule, node):
                supported.add(node)
        self.acc_nodes = supported

        if not self.allow_non_tensor:
            self.reduce_acc_nodes_non_tensor_input()
            self.reduce_acc_nodes_non_tensor_output()

        return self.acc_nodes
|
| 181 |
+
|
| 182 |
+
@compatibility(is_backward_compatible=False)
class FxNetSplitterInternalError(Exception):
    """Raised to signal an internal error in the splitter."""
    pass
|
| 185 |
+
|
| 186 |
+
@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
    """A group of nodes that the splitter will turn into one submodule."""
    is_acc: bool  # True when the nodes are meant to run on the accelerator
    nodes: NodeList  # the nodes belonging to this subgraph
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
    """
    Stores the results of the splitter.

    Attributes:
        split_module: root module after splitting.
        submodule_inputs: a dict that maps submodule name to its inputs.
        non_acc_submodule_prefix: the prefix for non acc submodules. For
            acc submodules the prefix is always "_run_on_acc_".
    """

    split_module: torch.fx.GraphModule
    submodule_inputs: Dict[str, Any]
    non_acc_submodule_prefix: str
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@compatibility(is_backward_compatible=False)
|
| 211 |
+
def generate_inputs_for_submodules(
    model: torch.nn.Module,
    inputs: Sequence[Any],
    target_submodules: Iterable[str],
    deepcopy: bool = False,
) -> Dict[str, Any]:
    """
    Generate inputs for targeted submodules in the given model. Note that if two
    submodules refer to the same obj, this function doesn't work.

    Args:
        model: root model.
        inputs: inputs to the root model.
        target_submodules: submodules that we want to generate inputs for.
        deepcopy: if True, store deep copies of the captured inputs so that
            later in-place mutation by the model does not affect the results.

    Returns:
        A dict that maps from submodule name to its inputs.
    """

    handles = []
    results = {}
    submodule_to_names = {mod: name for name, mod in model.named_modules()}

    def pre_forward(module, module_inputs):
        # Record the positional inputs this submodule is about to receive.
        results[submodule_to_names[module]] = (
            copy.deepcopy(module_inputs) if deepcopy else module_inputs
        )

    # Materialize once: `target_submodules` may be a one-shot iterator, and a
    # set makes repeated membership tests O(1).
    targets = set(target_submodules)
    for name, mod in model.named_modules():
        if name in targets:
            handles.append(mod.register_forward_pre_hook(pre_forward))

    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        # Always detach the hooks, even if the forward pass raises — including
        # BaseException (e.g. KeyboardInterrupt), which the previous
        # except-Exception cleanup would have leaked hooks on.
        for h in handles:
            h.remove()

    return results
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class _SplitterBase:
|
| 257 |
+
"""
|
| 258 |
+
Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
|
| 259 |
+
Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
|
| 260 |
+
Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.
|
| 261 |
+
|
| 262 |
+
Given the following graph:
|
| 263 |
+
==> b ==>
|
| 264 |
+
// \\
|
| 265 |
+
a d
|
| 266 |
+
\\ //
|
| 267 |
+
==> c ==>
|
| 268 |
+
|
| 269 |
+
class SimpleModule(torch.nn.Module):
|
| 270 |
+
def forward(self, a):
|
| 271 |
+
b = torch.sin(a)
|
| 272 |
+
c = torch.cos(a)
|
| 273 |
+
d = b + c
|
| 274 |
+
return d
|
| 275 |
+
|
| 276 |
+
and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
|
| 277 |
+
we will get the following split result:
|
| 278 |
+
|
| 279 |
+
main:
|
| 280 |
+
def forward(self, a):
|
| 281 |
+
run_on_acc_0_0 = self._run_on_acc_0_0(a)
|
| 282 |
+
getitem = run_on_acc_0_0[0]
|
| 283 |
+
getitem_1 = run_on_acc_0_0[1]
|
| 284 |
+
run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
|
| 285 |
+
return run_on_cpu_1_1
|
| 286 |
+
|
| 287 |
+
_run_on_acc_0_0:
|
| 288 |
+
def forward(self, a):
|
| 289 |
+
sin_1 = torch.sin(a)
|
| 290 |
+
cos_1 = torch.cos(a)
|
| 291 |
+
return (sin_1, cos_1)
|
| 292 |
+
|
| 293 |
+
_run_on_cpu_1_1:
|
| 294 |
+
def forward(self, sin_1, cos_1):
|
| 295 |
+
add_1 = sin_1 + cos_1
|
| 296 |
+
return add_1
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
# PCIe bandwidth for the backend, default to 100 GB/s
|
| 300 |
+
PCIe_BW = 100 * 2 ** 30
|
| 301 |
+
|
| 302 |
+
def __init__(
|
| 303 |
+
self,
|
| 304 |
+
module: torch.fx.GraphModule,
|
| 305 |
+
sample_input: Sequence[Any],
|
| 306 |
+
operator_support: OperatorSupportBase,
|
| 307 |
+
settings: _SplitterSettingBase,
|
| 308 |
+
non_acc_submodule_name: str = "_run_on_cpu_",
|
| 309 |
+
):
|
| 310 |
+
"""
|
| 311 |
+
Preprocesses graph before splitting:
|
| 312 |
+
- finds nodes supported by ACC,
|
| 313 |
+
- finds fusion groups for ACC nodes having non-tensor IO,
|
| 314 |
+
- builds a graph of direct dependencies,
|
| 315 |
+
- builds a map of fused nodes to their fusions.
|
| 316 |
+
As a result we get self.acc_nodes, self.deps and self.fusions.
|
| 317 |
+
"""
|
| 318 |
+
assert isinstance(module, torch.fx.GraphModule)
|
| 319 |
+
|
| 320 |
+
self.module = module
|
| 321 |
+
ShapeProp(self.module).propagate(*sample_input)
|
| 322 |
+
|
| 323 |
+
self.settings = settings
|
| 324 |
+
self.operator_support = operator_support
|
| 325 |
+
self.sample_input = sample_input
|
| 326 |
+
self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()
|
| 327 |
+
|
| 328 |
+
if self.settings.skip_fusion:
|
| 329 |
+
self.fusions = {}
|
| 330 |
+
else:
|
| 331 |
+
self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()
|
| 332 |
+
|
| 333 |
+
# Modify deps to add more deps for fused nodes
|
| 334 |
+
self.deps = self.find_deps()
|
| 335 |
+
self.update_deps_for_fusions()
|
| 336 |
+
|
| 337 |
+
self.non_acc_submodule_name = non_acc_submodule_name
|
| 338 |
+
self._node_submodule_map: Dict[str, str] = {}
|
| 339 |
+
|
| 340 |
+
# ===============================================================
|
| 341 |
+
# Helpers for ctor and initial state
|
| 342 |
+
# ===============================================================
|
| 343 |
+
|
| 344 |
+
def get_node_submodule_map(self) -> Dict[str, str]:
    """Return the mapping from node name to the submodule that node was
    assigned to, e.g.

        node: main_module_impl_impl_over_arch_unary_multiple_embedding
              _pooling_embedding_pooling_sparse_entity_equivalence_key
              _proxy_embedding_bag
        maps to submodule name of: _run_on_acc_1

    The map is filled in by ``tag()``; before tagging it is empty.
    """
    return self._node_submodule_map
|
| 352 |
+
|
| 353 |
+
def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
    """
    Build a graph of node dependencies. Leaf nodes have no entry and the
    "output" node is never recorded as depending on anything.

    The resulting map holds only direct dependencies, i.e. there are no
    transitive edges.
    """
    dependency_map: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
    producers = (n for n in self.module.graph.nodes if n.op in CALLABLE_NODE_OPS)
    for producer in producers:
        for consumer in producer.users:
            # The "output" node is a sink; it never becomes a key.
            if consumer.op == "output":
                continue
            dependency_map[consumer].add(producer)
    return dependency_map
|
| 370 |
+
|
| 371 |
+
def update_deps_for_fusions(self):
    """
    Update the dependency graph so that:
    - nodes from the same fusion depend on the same set of outer nodes,
    - outer nodes depending on a fusion depend on all nodes in that fusion.
    """
    for fused_node, group in self.fusions.items():
        for member in group:
            # Everything a group member depends on (outside the group)
            # becomes a dependency of fused_node as well.
            self.deps[fused_node].update(self.deps[member] - group)

            # Any consumer outside the group now also depends on fused_node.
            for consumer in member.users:
                if consumer not in group:
                    self.deps[consumer].add(fused_node)
|
| 385 |
+
|
| 386 |
+
# ===============================================================
|
| 387 |
+
# Helpers for preview
|
| 388 |
+
# ===============================================================
|
| 389 |
+
|
| 390 |
+
def _lower_model_to_backend(
    self, mod: torch.fx.GraphModule, inputs: Tensors
) -> torch.nn.Module:
    """
    Lower the model to a backend.

    The base implementation is an identity transform; backend-specific
    subclasses override this hook to perform real lowering.
    """
    return mod
|
| 398 |
+
|
| 399 |
+
def _find_culprit(
    self, mod: torch.fx.GraphModule, inputs: Tensors
) -> str:
    """
    When an error occurs during lowering or running the lowered mod, this
    hook is asked to pinpoint the part of `mod` that caused the failure.
    Subclasses override it; the base implementation cannot diagnose anything.
    """
    return "Unable to find a culprit because _find_culprit() function is not implemented."
|
| 408 |
+
|
| 409 |
+
def _draw_graph_based_on_node_support(
    self, mod: torch.fx.GraphModule, supported_nodes: NodeList
):
    """
    Write a graphviz dump of `mod` to "node_support.dot", coloring nodes
    green when supported by the accelerator, red when callable but
    unsupported, and the default color otherwise.
    """
    palette = {
        "default": "AliceBlue",
        "supported": "chartreuse1",
        "unsupported": "crimson",
    }

    class _SupportAwareDrawer(FxGraphDrawer):
        def _get_node_style(self, node):
            style = super()._get_node_style(node)
            if node in supported_nodes:
                style["fillcolor"] = palette["supported"]
            elif node.op in CALLABLE_NODE_OPS:
                style["fillcolor"] = palette["unsupported"]
            else:
                # placeholders / get_attr / output nodes
                style["fillcolor"] = palette["default"]
            return style

    drawer = _SupportAwareDrawer(mod, "node_support", ignore_getattr=True)
    drawer.get_main_dot_graph().write_raw("node_support.dot")
|
| 433 |
+
|
| 434 |
+
def node_support_preview(self, dump_graph: bool = False):
    """Print and return a report of which node types in the model are
    supported by the accelerator, keyed by target and argument dtypes.

    If `dump_graph` is True, also writes a color-coded graphviz dump
    ("node_support.dot") via _draw_graph_based_on_node_support.
    """
    submodules = dict(self.module.named_modules())

    supported_nodes: NodeList = []
    # target -> set of (arg_dtypes_tuple, kwarg_dtypes_tuple) signatures seen
    supported_node_types = defaultdict(set)
    unsupported_node_types = defaultdict(set)

    def get_dtype(arg):
        # Relies on ShapeProp having populated "tensor_meta"; returns None
        # for non-tensor arguments.
        tensor_meta = arg.meta.get("tensor_meta")
        return getattr(tensor_meta, "dtype", None)

    for node in self.module.graph.nodes:
        if node.op not in CALLABLE_NODE_OPS:
            continue

        target = get_node_target(submodules, node)

        # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None.
        arg_dtypes = [
            get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
            for arg in node.args
        ]

        # Find last non-None element. If all elements are None, return max_len.
        last_index = len(arg_dtypes) - next(
            (
                i
                for i, dtype in enumerate(reversed(arg_dtypes))
                if dtype is not None
            ),
            len(arg_dtypes),
        )

        # Strip None elements at the end.
        arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
        kwarg_dtypes_tuple = tuple(
            (k, get_dtype(arg))
            for k, arg in node.kwargs.items()
            if isinstance(arg, torch.fx.Node)
        )

        if self.operator_support.is_node_supported(submodules, node):
            supported_nodes.append(node)
            supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
        else:
            unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))

    if dump_graph:
        self._draw_graph_based_on_node_support(self.module, supported_nodes)

    reports = "\nSupported node types in the model:\n"
    for t, dtypes in supported_node_types.items():
        for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
            reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

    reports += "\nUnsupported node types in the model:\n"
    for t, dtypes in unsupported_node_types.items():
        for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
            reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

    print(reports)

    # Return reports for testing purpose
    return reports
|
| 498 |
+
|
| 499 |
+
def split_preview(self, dump_graph: bool = False):
    """Do a dry-run split and return a human-readable report.

    Splits the module (with tags removed afterwards so the module can be
    split again), estimates a PCIe-bandwidth-bound max QPS per acc
    submodule from input/output byte counts, and attempts to lower and
    run each acc submodule, recording any failures.

    If `dump_graph` is True, writes one graphviz .dot file per submodule.
    """
    reports = ""
    subgraphs = self.put_nodes_into_subgraphs()
    acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
    cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
    reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
    reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

    subgraphs = self.remove_small_acc_subgraphs(subgraphs)
    acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
    cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
    reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
    reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

    for i, subgraph in enumerate(subgraphs):
        reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
        reports += f"{len(subgraph.nodes)} node(s)\n"

    self.tag(subgraphs)
    # remove_tag=True so this preview doesn't prevent a later real split().
    split_mod = self.split(remove_tag=True)
    split_mod.eval()

    if dump_graph:
        drawer = FxGraphDrawer(
            split_mod, "preview", ignore_getattr=True
        )
        dot_graphs = drawer.get_all_dot_graphs()
        for name, dot_graph in dot_graphs.items():
            dot_graph.write_raw(f"{name}.dot")

    max_qps: float = self.PCIe_BW
    bottleneck_module = ""

    for node in split_mod.graph.nodes:
        if node.op == "call_module" and "acc" in node.target:
            reports += f"\nProcessing acc submodule {node.target}\n"

            submod = getattr(split_mod, node.target)

            def get_submod_inputs(main_mod, submod, example_inputs):
                # Capture the submodule's actual inputs by running the whole
                # model once with a forward-pre-hook installed.
                sub_inputs = None

                def get_inputs(self, inputs):
                    nonlocal sub_inputs
                    sub_inputs = inputs

                handle = submod.register_forward_pre_hook(get_inputs)
                main_mod(*example_inputs)
                handle.remove()
                return sub_inputs

            submod_inputs = get_submod_inputs(
                split_mod, submod, self.sample_input
            )
            # Re-propagate shapes so byte-size accounting below has metadata.
            ShapeProp(submod).propagate(*submod_inputs)

            total_input_bytes = 0
            total_output_bytes = 0

            reports += "Checking inputs...\n"
            for n in submod.graph.nodes:
                if n.op == "placeholder":
                    if not is_node_output_tensor(n):
                        reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                    else:
                        total_input_bytes += get_size_of_node(submod, n)[0]
                if n.op == "output":
                    output_node = n

            reports += "Checking outputs...\n"

            def get_bytes(node: torch.fx.Node):
                nonlocal total_output_bytes
                nonlocal reports
                if not is_node_output_tensor(node):
                    reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
                else:
                    total_output_bytes += get_size_of_node(submod, node)[0]

            map_arg(output_node.args, get_bytes)  # type: ignore[possibly-undefined]
            # QPS bound: the slower of moving inputs in / outputs out over PCIe.
            qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
            reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
            reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"

            if qps < max_qps:
                max_qps = qps
                bottleneck_module = node.target

            try:
                lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
            except RuntimeError:
                reports += "Run into an error during lowering!\n"
                reports += self._find_culprit(submod, submod_inputs)
                continue

            try:
                lowered_submod(*submod_inputs)
            except RuntimeError:
                reports += "Run into an error during inference!\n"
                reports += self._find_culprit(submod, submod_inputs)
            else:
                reports += "Lowering and running succeed!\n"

    reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
    reports += f" bottleneck is submodule {bottleneck_module}."
    print(reports)

    # return the reports for testing purposes
    return reports
|
| 608 |
+
|
| 609 |
+
# ===============================================================
|
| 610 |
+
# Helpers for extend_acc_subgraph() method
|
| 611 |
+
# ===============================================================
|
| 612 |
+
|
| 613 |
+
def find_reverse_deps(
    self, tag_id: Optional[int] = None
) -> Dict[torch.fx.Node, NodeSet]:
    """
    Build reversed topological node dependencies (node -> its callable
    users). When `tag_id` is given, users living in a later subgraph
    (i.e. whose tag carries a greater id) are ignored.
    """
    reverse_deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)

    producers = (n for n in self.module.graph.nodes if n.op in CALLABLE_NODE_OPS)
    for producer in producers:
        for consumer in producer.users:
            if consumer.op not in CALLABLE_NODE_OPS:
                continue
            # Tag names end with the subgraph id, e.g. "_run_on_acc_3".
            if tag_id is None or int(consumer.tag.split("_")[-1]) < tag_id:
                reverse_deps[producer].add(consumer)

    return reverse_deps
|
| 634 |
+
|
| 635 |
+
def update_reverse_deps_for_fusions(
    self, deps: Dict[torch.fx.Node, NodeSet]
):
    """Rewrite a reverse-dependency map (from find_reverse_deps) so every
    fusion group acts as one unit: all members share the same user set,
    and every producer feeding the group is marked as used by the whole
    group. Mutates `deps` in place.
    """
    processed_node = set()

    for node, fusion in self.fusions.items():
        if node in processed_node:
            continue

        new_dep = set()

        # Create a new dependency set which include all the
        # dependencies of the nodes in the fusion group
        for n in fusion:
            new_dep.update(deps[n])

        # Exclude nodes in the fusion
        new_dep.difference_update(fusion)

        # Update dependency
        for n in fusion:
            # NOTE: all members intentionally alias the SAME set object.
            deps[n] = new_dep

            for arg in n.all_input_nodes:
                if arg not in fusion:
                    deps[arg].update(fusion)

            processed_node.add(n)
|
| 663 |
+
|
| 664 |
+
def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
    """
    Find parent nodes of the `tag` subgraph.

    Walks the inputs of every node in the subgraph; an input that is a
    callable node outside the subgraph (so neither a placeholder nor a
    get_attr) counts as a parent node.
    """
    parents: NodeSet = set()

    members = (
        n
        for n in self.module.graph.nodes
        if n.op in CALLABLE_NODE_OPS and n.tag == tag
    )
    for member in members:
        for producer in member.all_input_nodes:
            if producer.op in CALLABLE_NODE_OPS and producer.tag != tag:
                parents.add(producer)

    return parents
|
| 680 |
+
|
| 681 |
+
def extend_acc_subgraph(self, tag: str):
    """
    Extend the acc subgraph with `tag` going the reversed topological direction.

    Greedily pulls acc-supported parent nodes into the subgraph, but only
    once all of a candidate's (reverse) dependencies have already been
    absorbed, so no cycle between subgraphs can be introduced.
    """
    # Dict that maps node to its users and ignore users that
    # are in the subgraph that has greater tag
    deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1]))
    self.update_reverse_deps_for_fusions(deps)

    # Parent nodes of the subgraph
    parent_nodes = self.find_parent_nodes_of_subgraph(tag)

    visited_nodes: NodeSet = set()

    while parent_nodes:
        node = None

        # Find a acc node that depends on visited nodes only
        for n in parent_nodes:
            if deps[n] <= visited_nodes and n in self.acc_nodes:
                node = n
                break

        # No candidate is safe to absorb -> stop extending.
        if node is None:
            break

        # Put the node into `tag` subgraph
        node.tag = tag  # type: ignore[attr-defined]
        parent_nodes.remove(node)
        visited_nodes.add(node)

        # If node is in a fusion group, add all fusion buddies to parent nodes
        if node in self.fusions:
            for fusion_node in self.fusions[node]:
                if fusion_node not in visited_nodes:
                    parent_nodes.add(fusion_node)

        # Add inputs of the node to parent nodes
        for arg in node.all_input_nodes:
            if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
                parent_nodes.add(arg)
|
| 722 |
+
|
| 723 |
+
# ===============================================================
|
| 724 |
+
# Helpers for split() method
|
| 725 |
+
# ===============================================================
|
| 726 |
+
|
| 727 |
+
def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
    """
    Find the nodes that consume module inputs (placeholders) or
    get_attr nodes, split into (cpu_starters, acc_starters).
    """
    cpu_starters: NodeSet = set()
    acc_starters: NodeSet = set()
    sources = (
        n for n in self.module.graph.nodes if n.op in {"placeholder", "get_attr"}
    )
    for source in sources:
        for consumer in source.users:
            bucket = acc_starters if consumer in self.acc_nodes else cpu_starters
            bucket.add(consumer)
    return cpu_starters, acc_starters
|
| 742 |
+
|
| 743 |
+
def put_nodes_into_subgraphs(self) -> List[Subgraph]:
    """Partition all callable nodes into alternating acc/cpu subgraphs.

    Performs a topological traversal that stays in the current mode
    (acc or cpu) as long as a ready node of that kind exists, then flips
    mode and starts a new subgraph.
    """
    # We start graph traversal from leaf nodes
    current_cpu_nodes, current_acc_nodes = self.starter_nodes()
    visited_nodes: NodeSet = set()

    # Determine which subgraph to start from based on which subgraph has
    # 0-dep node
    acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)

    current_subgraph_nodes: NodeList = []

    # Result accumulator
    subgraphs: List[Subgraph] = []
    while current_cpu_nodes or current_acc_nodes:
        # Find the first node that should belong to the current subgraph and has all dependencies resolved
        current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
        node = next(
            (n for n in current_nodes if self.deps[n] <= visited_nodes),
            None,
        )

        # If nothing was found, then it's time to flip the mode and start a new subgraph
        if node is None:
            if not current_subgraph_nodes:
                raise FxNetSplitterInternalError("Subgraph can't be empty")

            subgraphs.append(
                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
            )
            acc_subgraph = not acc_subgraph
            current_subgraph_nodes = []
            continue

        current_nodes.remove(node)
        visited_nodes.add(node)
        current_subgraph_nodes.append(node)

        # Add fusion buddies
        if node in self.fusions:
            if node in self.acc_nodes:
                current_acc_nodes.update(self.fusions[node] - visited_nodes)
            else:
                current_cpu_nodes.update(self.fusions[node] - visited_nodes)

        # Put depending nodes into the queue
        for user in node.users:
            if user.op not in CALLABLE_NODE_OPS:
                continue

            # Add downstream nodes
            if user in self.acc_nodes:
                current_acc_nodes.add(user)
            else:
                current_cpu_nodes.add(user)

    # Check if the last subgraph was not created
    if current_subgraph_nodes:
        subgraphs.append(
            Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
        )

    if not subgraphs:
        raise FxNetSplitterInternalError("Couldn't create subgraphs")

    return subgraphs
|
| 808 |
+
|
| 809 |
+
def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
    """
    Eliminate ACC subgraphs smaller than the configured
    `min_acc_module_size` by merging them into adjacent CPU subgraphs,
    and merge consecutive CPU subgraphs together.
    """
    merged: List[Subgraph] = []
    threshold = self.settings.min_acc_module_size
    for subgraph in subgraphs:
        if subgraph.is_acc:
            if len(subgraph.nodes) >= threshold:
                merged.append(subgraph)
                continue
            print(
                "Eliminating acc subgraph because it's smaller than the threshold: "
                f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
            )
            if merged:
                # Fold the tiny acc subgraph into whatever came before it.
                merged[-1].nodes.extend(subgraph.nodes)
            else:
                # Nothing to merge into yet; demote it to CPU instead.
                subgraph.is_acc = False
                merged.append(subgraph)
        elif merged and not merged[-1].is_acc:
            # Coalesce back-to-back CPU subgraphs.
            merged[-1].nodes.extend(subgraph.nodes)
        else:
            merged.append(subgraph)
    return merged
|
| 835 |
+
|
| 836 |
+
def tag(self, subgraphs: List[Subgraph]):
    """
    Assign a tag to every node of every subgraph ("_run_on_acc_<i>" or
    "<non_acc_submodule_name><i>"), record the tag list in self.tags, and
    fill self._node_submodule_map. Raises FxNetSplitterInternalError if a
    node was already tagged.
    """
    self.tags: List[str] = []
    for index, subgraph in enumerate(subgraphs):
        label = (
            f"_run_on_acc_{index}"
            if subgraph.is_acc
            else f"{self.non_acc_submodule_name}{index}"
        )
        self.tags.append(label)
        for node in subgraph.nodes:
            if hasattr(node, "tag"):
                raise FxNetSplitterInternalError(f"Node {node} was already tagged")

            node.tag = label  # type: ignore[attr-defined]
            self._node_submodule_map[node.name] = label
|
| 847 |
+
|
| 848 |
+
def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
    """
    Split self.module into submodules according to the tags assigned by
    tag(). With `remove_tag` set, the node tags are stripped afterwards
    so the original module can be tagged and split again later.
    """
    result = split_by_tags(self.module, self.tags)
    if remove_tag:
        for node in self.module.graph.nodes:
            if hasattr(node, "tag"):
                del node.tag
    return result
|
| 855 |
+
|
| 856 |
+
def __call__(self) -> torch.fx.GraphModule:
    """
    Run the full pipeline: partition nodes into subgraphs, drop
    undersized acc subgraphs, tag the result and split the module.
    """
    subgraphs = self.remove_small_acc_subgraphs(self.put_nodes_into_subgraphs())
    acc_count = sum(1 for s in subgraphs if s.is_acc)
    cpu_count = len(subgraphs) - acc_count
    print(f"Got {acc_count} acc subgraphs and {cpu_count} non-acc subgraphs")
    self.tag(subgraphs)
    return self.split()
|
| 864 |
+
|
| 865 |
+
def generate_split_results(self) -> SplitResult:
    """
    Run the full splitting pipeline and package the outcome.

    Returns a SplitResult holding the split GraphModule, the captured
    per-submodule sample inputs, and the non-acc submodule name prefix.
    """
    split_module = self()
    # Only the child names are needed; the modules themselves are not
    # (comprehension replaces the old manual append loop with an unused
    # `mod` loop variable).
    submodule_names = [name for name, _ in split_module.named_children()]
    submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
    return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/tools_common.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
|
| 2 |
+
import collections
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.fx
|
| 7 |
+
from torch.fx.node import _get_qualified_name
|
| 8 |
+
from torch.fx._compatibility import compatibility
|
| 9 |
+
|
| 10 |
+
__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph']
|
| 11 |
+
|
| 12 |
+
Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]]
|
| 13 |
+
TensorOrTensors = Union[torch.Tensor, Tensors]
|
| 14 |
+
NodeList = List[torch.fx.Node]
|
| 15 |
+
NodeSet = Set[torch.fx.Node]
|
| 16 |
+
Names = List[str]
|
| 17 |
+
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@compatibility(is_backward_compatible=False)
def get_acc_ops_name(k):
    """Return a printable name for an op identifier.

    Strings pass through unchanged; callables from an ``acc_ops`` module
    become "acc_ops.<name>"; anything else becomes "<module>.<name>".

    Fix: the original crashed with AttributeError when ``k.__module__``
    was None (possible for some builtins/extension callables) because it
    unconditionally called ``.replace`` on it; now guarded.
    """
    if isinstance(k, str):
        return k
    elif k.__module__ and "acc_ops" in k.__module__:
        return f"acc_ops.{k.__name__}"
    else:
        # WAR for bug in how torch.ops assigns module; guard against a
        # None __module__ so we degrade to ".<name>" instead of raising.
        module = k.__module__.replace('torch._ops', 'torch.ops') if k.__module__ else ''
        return f"{module}.{k.__name__}"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@compatibility(is_backward_compatible=False)
def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str:
    """
    Given a `node` returns its target typename.

    For "call_method" node, return node.target which is the name of that method being called.
    This could potential lead to conflict but should be okay because normally it's on a tensor.

    For "call_function" node, return typename of node.target.

    For "call_module" node, return typename of the module that node.target point to.

    If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by
    "torch". e.g. _VariableFunctionsClass.relu would become torch.relu.
    """
    assert node.op in CALLABLE_NODE_OPS, (
        "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
    )

    if node.op == "call_module":
        assert isinstance(node.target, str)
        submod = submodules[node.target]
        # Prefer the pre-transformation class when a rewriter recorded it.
        submod_type = getattr(submod, "_base_class_origin", type(submod))
        return get_acc_ops_name(submod_type)

    if node.op == "call_function":
        target: Any = node.target
        if target.__module__ is not None and "acc_ops" in target.__module__:
            return f"acc_ops.{target.__name__}"
        return _get_qualified_name(target)

    # call_method: the target already is the method name.
    assert isinstance(node.target, str)
    return node.target
|
| 66 |
+
|
| 67 |
+
@compatibility(is_backward_compatible=False)
def is_node_output_tensor(node: torch.fx.Node) -> bool:
    """Checks if the node output produces a Tensor or not.

    NOTE: This requires to run `ShapeProp` on the containing fx graph before
    calling this function. This is because it works by checking the `type`
    metadata on the node. This metadata is produced by the `ShapeProp`.
    """
    node_type = node.meta.get("type", None)
    if node_type is None:
        return False
    return issubclass(node_type, torch.Tensor)
|
| 77 |
+
|
| 78 |
+
@compatibility(is_backward_compatible=False)
class FxNetAccFusionsFinder:
    """
    Finds groups of connected ACC nodes that pass non-tensor data between each other.
    Such groups are called fusion groups.
    """

    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
        self.module = module
        # Topologically ordered snapshot of all graph nodes; used for
        # index comparisons below.
        self.nodes = list(module.graph.nodes)
        # NOTE: mutated by __call__ — groups that mix acc and non-acc
        # nodes are removed from this set.
        self.acc_nodes = acc_nodes

    @dataclass
    class FusionGroup:
        # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model.
        top_node_idx: int

        # Nodes in this fusion group.
        nodes: NodeSet

        # Inputs to this fusion group.
        inputs: NodeSet

        # Nodes that in the fusion group that haven't been processed yet.
        nodes_need_process: NodeSet

        def add_node(self, node):
            """
            Add a node to fusion group.
            """
            if node in self.nodes:
                return

            self.nodes_need_process.add(node)
            self.nodes.add(node)
            # The node is inside the group now, so it is no longer an input;
            # its own (external, callable) producers become inputs instead.
            self.inputs.discard(node)
            self.inputs.update(
                {
                    n
                    for n in node.all_input_nodes
                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
                }
            )

    def recursive_add_node(
        self,
        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
        inputs: Union[NodeSet, NodeList],
        visited: Optional[NodeSet] = None,
    ):
        """
        Start from inputs and going reverse topological order. If any upstream node
        is in the fusion group, add all the nodes in this path to fusion group.

        Returns True if some node on the explored path was (or became) part
        of the fusion group.
        """
        for arg in inputs:
            # skip the node if already seen
            if visited is not None:
                if arg in visited:
                    continue
                visited.add(arg)

            # Skip placeholder and get_attr because they won't be in the fusion group.
            if arg.op not in CALLABLE_NODE_OPS:
                continue

            # If the node has smaller idx, it's already an upstream node of the fusion
            # group. We don't need to check it anymore.
            if self.nodes.index(arg) < fusion_group.top_node_idx:
                continue

            # If the node is in the fusion group, return True.
            if arg in fusion_group.nodes:
                return True

            # Check the upstream nodes of the node, if any of them is in the fusion group
            # we'll add this node to fusion group and return True.
            if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
                fusion_group.add_node(arg)
                return True

        return False

    def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
        """Compute fusion groups; returns a map from each fused node to its
        (shared) group set. Side effect: prunes self.acc_nodes when a group
        cannot be kept entirely on the accelerator.
        """
        result: Dict[torch.fx.Node, NodeSet] = {}
        # Iterate over a snapshot since self.acc_nodes may shrink below.
        acc_nodes = list(self.acc_nodes)

        for node in acc_nodes:
            if node in result:
                continue
            if node.op not in CALLABLE_NODE_OPS:
                continue
            # Nodes with "tensor_meta" produce tensors; fusion is only
            # needed around non-tensor IO, so they never seed a group.
            if "tensor_meta" in node.meta:
                continue
            if node not in self.acc_nodes:
                continue

            fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
                top_node_idx=self.nodes.index(node),
                nodes={node},
                inputs=set(node.all_input_nodes),
                nodes_need_process={node},
            )
            # Fixpoint loop: grow the group until no unprocessed member remains.
            # NOTE(review): `node` is rebound here to the popped member, so the
            # meta checks below inspect that member, not the seed — appears
            # intentional upstream.
            while fusion_group.nodes_need_process:
                node = fusion_group.nodes_need_process.pop()
                self.recursive_add_node(
                    fusion_group,
                    fusion_group.inputs,
                    visited=set(),
                )

                # Optionally add downstream nodes
                if "tensor_meta" not in node.meta:
                    for user in node.users:
                        if user.op not in CALLABLE_NODE_OPS:
                            continue
                        if user in fusion_group.nodes:
                            continue

                        fusion_group.add_node(user)
                        self.recursive_add_node(
                            fusion_group,
                            fusion_group.inputs,
                            visited=set(),
                        )

                # Add some upstream nodes
                for arg in node.all_input_nodes:
                    if arg.op not in CALLABLE_NODE_OPS:
                        continue
                    if "tensor_meta" in arg.meta:
                        continue
                    if arg in fusion_group.nodes:
                        continue

                    fusion_group.add_node(arg)
                    fusion_group.top_node_idx = min(
                        fusion_group.top_node_idx, self.nodes.index(arg)
                    )
                    self.recursive_add_node(
                        fusion_group,
                        fusion_group.inputs,
                        visited=set(),
                    )

            # A fusion group is only usable if every member is an acc node;
            # otherwise drop its members from the acc set so the whole group
            # runs on CPU together.
            if not (set(fusion_group.nodes) <= self.acc_nodes):
                self.acc_nodes -= fusion_group.nodes
            else:
                for n in fusion_group.nodes:
                    result[n] = fusion_group.nodes

        return result
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@compatibility(is_backward_compatible=False)
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Replace the graph of the given GraphModule with one that contains the same
    nodes as the original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the
    topologically sorted order of its input GraphModule, so that this order is
    restored before further transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Returns:
        The graph module in-place sorted
    """
    # Kahn's algorithm: count the unfulfilled dependencies of every node.
    indeg = dict.fromkeys(gm.graph.nodes, 0)
    for producer in gm.graph.nodes:
        for consumer in producer.users:
            indeg[consumer] += 1

    # Seed the worklist with every node that has no pending dependencies.
    queue: collections.deque = collections.deque(
        n for n in gm.graph.nodes if indeg[n] == 0
    )

    sorted_graph = torch.fx.Graph()
    env: Dict[torch.fx.Node, torch.fx.Node] = {}
    # Copy ready nodes into the new graph; copying a node may release its
    # users, which are then appended to the worklist in turn.
    while queue:
        cur = queue.popleft()
        env[cur] = sorted_graph.node_copy(cur, lambda x: env[x])
        for consumer in cur.users:
            indeg[consumer] -= 1
            if indeg[consumer] == 0:
                queue.append(consumer)

    # If fewer nodes were copied than exist in the old graph, some node's
    # dependencies were never satisfied, i.e. the input graph has a cycle.
    if len(sorted_graph.nodes) < len(gm.graph.nodes):
        raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}")

    # Preserve the codegen configuration of the original graph.
    sorted_graph._codegen = gm.graph._codegen
    gm.graph = sorted_graph
    return gm
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/cpp.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Functionality for Python <-> C++ frontend inter-op."""
|
| 2 |
+
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class OrderedDictWrapper:
    """A live, read-only view over a C++ OrderedDict bound to a module.

    Every access re-evaluates the OrderedDict getter on the bound C++ module,
    so changes made on the C++ side are always picked up. Reading e.g.
    ``cpp_module._parameters`` just once would instead freeze the parameters
    at the time of access. ``torch.nn.Module`` looks up ``_parameters`` et al.
    through ``self.__dict__``, so exposing them as properties does not work.
    """

    def __init__(self, cpp_module, attr):
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        # Re-fetch on every access so C++-side mutations stay visible.
        return getattr(self.cpp_module, self.attr)

    # Magic methods cannot be assigned dynamically and bypass ``getattr``, so
    # each one is forwarded by hand below.

    def items(self):
        return self.cpp_dict.items()

    def keys(self):
        return self.cpp_dict.keys()

    def values(self):
        return self.cpp_dict.values()

    def __iter__(self):
        return iter(self.cpp_dict)

    def __len__(self):
        return len(self.cpp_dict)

    def __contains__(self, key):
        return key in self.cpp_dict

    def __getitem__(self, key):
        return self.cpp_dict[key]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ModuleWrapper(nn.Module):
    """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.

    Parameters, buffers, and submodules are exposed through
    ``OrderedDictWrapper`` live views, so state changed on the C++ side is
    always reflected on the Python side.
    """

    def __init__(self, cpp_module):
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor (nn.Module.__init__ sets
        # ``training``, which here routes through the property setter below and
        # therefore needs ``self.cpp_module`` to already exist).
        self.cpp_module = cpp_module
        super().__init__()
        # Replace the plain dicts installed by ``nn.Module.__init__`` with
        # live views onto the C++ module's containers.
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        for attr in dir(cpp_module):
            # Skip magic methods and the three attributes above.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` in-place to each parameter, its gradient, and each buffer.

        ``recurse`` is accepted for signature compatibility with
        ``nn.Module._apply``; traversal here relies on ``self.parameters()``
        and ``self.buffers()``.
        """
        for param in self.parameters():
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers():
            buf.data = fn(buf.data)

        return self

    # nn.Module defines training as a boolean
    @property  # type: ignore[override]
    def training(self):
        """Mirror the wrapped C++ module's ``training`` flag."""
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        # Delegate to the C++ side so both views stay in sync.
        self.cpp_module.train(mode)

    def __repr__(self):
        # Delegate the textual representation to the C++ module.
        return self.cpp_module.__repr__()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/functional.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-311.pyc
ADDED
|
Binary file (713 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/qat/modules/linear_relu.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401
r"""Intrinsic QAT Modules.

This file is in the process of migration to `torch/ao/nn/intrinsic/qat`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/intrinsic/qat/modules`,
while adding an import statement here.
"""

from torch.ao.nn.intrinsic.qat import LinearReLU

__all__ = [
    'LinearReLU',
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (435 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .modules import * # noqa: F403
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .linear_relu import LinearReLU
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
'LinearReLU',
|
| 5 |
+
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-311.pyc
ADDED
|
Binary file (366 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/bn_relu.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.ao.nn.intrinsic.quantized import BNReLU2d
|
| 2 |
+
from torch.ao.nn.intrinsic.quantized import BNReLU3d
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
'BNReLU2d',
|
| 6 |
+
'BNReLU3d',
|
| 7 |
+
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/intrinsic/quantized/modules/linear_relu.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.ao.nn.intrinsic.quantized import LinearReLU
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
'LinearReLU',
|
| 5 |
+
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/activation.cpython-311.pyc
ADDED
|
Binary file (73.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/adaptive.cpython-311.pyc
ADDED
|
Binary file (16.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/channelshuffle.cpython-311.pyc
ADDED
|
Binary file (2.56 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/distance.cpython-311.pyc
ADDED
|
Binary file (5.04 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/instancenorm.cpython-311.pyc
ADDED
|
Binary file (24.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/linear.cpython-311.pyc
ADDED
|
Binary file (14.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/modules/__pycache__/normalization.cpython-311.pyc
ADDED
|
Binary file (15.6 kB). View file
|
|
|