Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
See raw diff
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/gen_example.py +28 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/exported_program.py +50 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py +258 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/verifier.py +416 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/closure.py +134 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/computation.py +26 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/debug.py +21 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py +6 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/error.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/graphs.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__init__.py +11 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py +557 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +1279 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py +1040 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py +348 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py +125 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py +44 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/cse_pass.py +112 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py +73 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_drawer.py +421 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py +329 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/operator_support.py +217 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/param_fetch.py +66 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/shape_prop.py +195 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_utils.py +302 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/common.py +95 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_utils.py +400 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/pool.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-311.pyc +0 -0
- tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__init__.py +87 -0
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/gen_example.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generator for new example cases under torch/_export/db/examples.
#
# Usage: python gen_example.py <case_name>
# Writes examples/<case_name>.py pre-filled with export_case boilerplate.
import os
import sys

import torch._export.db.examples as examples

# Skeleton written into the new example file. Double braces escape the
# literal ``tags={}`` through str.format; only {case_name} is substituted.
TEMPLATE = '''import torch

from torch._export.db.case import export_case


@export_case(
    example_inputs=(torch.randn(3, 2),),
    tags={{}},
)
def {case_name}(x):
    """
    """

    return
'''

if __name__ == "__main__":
    assert len(sys.argv) == 2
    # The examples package's dotted name doubles as its on-disk path relative
    # to the current working directory, so this script is expected to be run
    # from the source/site-packages root — TODO confirm intended invocation dir.
    root_dir = examples.__name__.replace(".", "/")
    assert os.path.exists(root_dir)
    with open(os.path.join(root_dir, sys.argv[1] + ".py"), "w") as f:
        print("Writing to", f.name, "...")
        f.write(TEMPLATE.format(case_name=sys.argv[1]))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/exported_program.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.fx
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# TODO(ycao): This is added to avoid breaking existing code temporarily.
|
| 9 |
+
# Remove when migration is done.
|
| 10 |
+
from torch.export.graph_signature import (
|
| 11 |
+
ExportBackwardSignature,
|
| 12 |
+
ExportGraphSignature,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from torch.export.exported_program import (
|
| 16 |
+
ExportedProgram,
|
| 17 |
+
ModuleCallEntry,
|
| 18 |
+
ModuleCallSignature,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
__all__ = [
|
| 24 |
+
"ExportBackwardSignature",
|
| 25 |
+
"ExportGraphSignature",
|
| 26 |
+
"ExportedProgram",
|
| 27 |
+
"ModuleCallEntry",
|
| 28 |
+
"ModuleCallSignature",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _create_graph_module_for_export(root, graph):
|
| 33 |
+
try:
|
| 34 |
+
gm = torch.fx.GraphModule(root, graph)
|
| 35 |
+
except SyntaxError:
|
| 36 |
+
# If custom objects stored in memory are being used in the graph,
|
| 37 |
+
# the generated python code will result in a syntax error on the custom
|
| 38 |
+
# object, since it is unable to parse the in-memory object. However
|
| 39 |
+
# we can still run the graph eagerly through torch.fx.Interpreter,
|
| 40 |
+
# so we will bypass this error.
|
| 41 |
+
warnings.warn(
|
| 42 |
+
"Unable to execute the generated python source code from "
|
| 43 |
+
"the graph. The graph module will no longer be directly callable, "
|
| 44 |
+
"but you can still run the ExportedProgram, and if needed, you can "
|
| 45 |
+
"run the graph module eagerly using torch.fx.Interpreter."
|
| 46 |
+
)
|
| 47 |
+
gm = torch.fx.GraphModule(root, torch.fx.Graph())
|
| 48 |
+
gm._graph = graph
|
| 49 |
+
|
| 50 |
+
return gm
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from typing import Any, Callable, Dict, List, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch._dynamo.source import (
|
| 7 |
+
AttrSource,
|
| 8 |
+
GetItemSource,
|
| 9 |
+
LocalSource,
|
| 10 |
+
TensorProperty,
|
| 11 |
+
TensorPropertySource,
|
| 12 |
+
)
|
| 13 |
+
from torch._dynamo.variables.builder import TrackedFake
|
| 14 |
+
from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim
|
| 15 |
+
from torch._guards import Source
|
| 16 |
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
| 17 |
+
from torch.export import Constraint
|
| 18 |
+
from torch.export.graph_signature import CustomObjArgument
|
| 19 |
+
from torch.fx.experimental.symbolic_shapes import (
|
| 20 |
+
ConstraintViolationError,
|
| 21 |
+
DimDynamic,
|
| 22 |
+
EqualityConstraint,
|
| 23 |
+
ShapeEnv,
|
| 24 |
+
StatelessSymbolicContext,
|
| 25 |
+
)
|
| 26 |
+
from torch.utils._pytree import (
|
| 27 |
+
GetAttrKey,
|
| 28 |
+
KeyPath,
|
| 29 |
+
MappingKey,
|
| 30 |
+
SequenceKey,
|
| 31 |
+
tree_map_with_path,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def key_path_to_source(kp: KeyPath) -> Source:
|
| 36 |
+
"""
|
| 37 |
+
Given a key path, return the source for the key path.
|
| 38 |
+
"""
|
| 39 |
+
source: Source = LocalSource("args")
|
| 40 |
+
for k in kp:
|
| 41 |
+
if isinstance(k, SequenceKey):
|
| 42 |
+
source = GetItemSource(source, k.idx)
|
| 43 |
+
elif isinstance(k, MappingKey):
|
| 44 |
+
source = GetItemSource(source, k.key)
|
| 45 |
+
elif isinstance(k, GetAttrKey):
|
| 46 |
+
source = AttrSource(source, k.name)
|
| 47 |
+
else:
|
| 48 |
+
raise ValueError(f"Unknown KeyEntry {k}")
|
| 49 |
+
|
| 50 |
+
return source
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _is_constant_argument(t):
|
| 54 |
+
return t is None or isinstance(t, (int, float, bool, str))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def fakify(
    mode: FakeTensorMode,
    kp: KeyPath,
    t: Any,
    t_constraints: Dict[int, Dict[int, Constraint]],
    sources: Dict[Tuple[int, int], List[Source]],
):
    """Convert one user input *t* (located at pytree key path *kp*) into a fake value.

    Constants (None/int/float/bool/str) and ScriptObjects pass through
    unchanged; tensors are converted to fake tensors under *mode*, with any
    per-dimension constraints from *t_constraints* marking those dims dynamic.
    Newly created (tensor-id, dim) -> Source entries are appended to *sources*.
    Assumes ``mode.shape_env`` is populated (see the construction in
    make_fake_inputs).
    """
    source = key_path_to_source(kp)
    if _is_constant_argument(t) or isinstance(t, torch.ScriptObject):
        # Constants and custom script objects need no fakification.
        return t
    if not isinstance(t, torch.Tensor):
        raise ValueError(f"Unsupported input type {type(t)}")
    n_dims = len(t.shape)
    # Default: every dimension is static unless a constraint says otherwise.
    symbolic_context = StatelessSymbolicContext(
        dynamic_sizes=[DimDynamic.STATIC] * n_dims,
        constraint_sizes=[None] * n_dims,
    )
    t_id = id(t)
    if t_id in t_constraints:
        for i, constraint in t_constraints[t_id].items():
            # Mark this dim dynamic and carry the user's range constraint.
            symbolic_context.constraint_sizes[i] = constraint.constraint_range
            symbolic_context.dynamic_sizes[i] = DimDynamic.DYNAMIC
            src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i)
            sources[(t_id, i)].append(src)
            # Map the internal source name to the user-facing dim name so
            # later error messages can refer to the user's Dim().
            mode.shape_env.source_name_to_debug_name[src.name()] = constraint.debug_name
    fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
    # Record the fake for later guard production (see make_constraints).
    mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context))
    return fake
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def make_fake_params_buffers(
    fake_mode: FakeTensorMode,
    params_buffers: Dict[str, torch.Tensor],
) -> Dict[str, Union[torch.Tensor, torch.nn.Parameter]]:
    """Fakify every parameter/buffer tensor, pinning their shapes as static."""
    return {
        name: fake_mode.from_tensor(tensor, static_shapes=True)
        for name, tensor in params_buffers.items()
    }
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def make_fake_inputs(nn_module, args, kwargs, constraints):
    """
    Given an nn module, example inputs, and constraints, return a new fake mode,
    fake inputs created in that mode whose dynamic shape dimensions are constrained
    by the given ranges, and sources for pairs of dynamic shape dimensions that are
    constrained to be equal.

    Returns ``(fake_mode, fake_args, fake_kwargs, equalities_inputs,
    original_signature)``.
    """
    # TODO(avik): refactor Dynamo to avoid duplication of the following code
    # between non-strict and strict.
    # Specifically, here (non-strict) we do the following pre-tracing steps:
    # - Fakify inputs.
    # - Process input shape equalities.
    # In strict, these steps are spread across multiple files:
    # - output_graph.py fakifies inputs.
    # - [post-tracing] guards.py processes input shape equalities.

    # Index constraints by (tensor id, dim) for O(1) lookup during fakification.
    t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict)
    for constraint in constraints:
        t_constraints[constraint.t_id][constraint.dim] = constraint
        if constraint.shared is not None:
            # A shared constraint applies to its partner (tensor, dim) as well.
            t_constraints[constraint.shared.t_id][constraint.shared.dim] = constraint

    # Attach the user's forward() code location to the ShapeEnv for diagnostics.
    code = nn_module.forward.__code__
    co_fields = {
        "co_name": code.co_name,
        "co_filename": code.co_filename,
        "co_firstlineno": code.co_firstlineno,
    }

    fake_mode = FakeTensorMode(
        shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields),
        allow_non_fake_inputs=True,
    )
    if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None:
        raise ValueError(
            "Detected fake_mode does not have a shape_env with tracked fakes. "
            "If you constructed the module under a FakeTensorMode, "
            "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))"
        )

    with fake_mode:
        original_signature = inspect.signature(nn_module.forward)
        sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list)
        # fakify also populates `sources` and shape_env.tracked_fakes as a
        # side effect of mapping over (args, kwargs).
        fake_args, fake_kwargs = tree_map_with_path(
            lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
            (args, kwargs),
        )

        # Deferred import: sympy is only needed once we process equalities.
        from sympy import Symbol

        source_pairs: List[Tuple[Source, Source]] = []
        derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = []
        phantom_symbols: Dict[str, Symbol] = {}
        for constraint in constraints:
            # Accumulates into source_pairs / derived_equalities / phantom_symbols.
            torch.export.dynamic_shapes._process_equalities(
                constraint,
                lambda t_id, dim: sources[(t_id, dim)],
                fake_mode.shape_env,
                source_pairs,
                derived_equalities,
                phantom_symbols,
            )

        equalities_inputs = EqualityConstraint(
            source_pairs=source_pairs,
            derived_equalities=derived_equalities,
            phantom_symbols=list(phantom_symbols.values()),
            warn_only=False,
        )
        return fake_mode, fake_args, fake_kwargs, equalities_inputs, original_signature
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def make_constraints(
    fake_mode,
    equalities_inputs,
    original_signature,
    gm,
):
    """
    Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions,
    and a graph module, produce guards on the fake mode's shape env (raising constraint
    violations if any), solve (to suggest simplifications or fixes), and return the
    resulting range constraints and equality constraints.

    Raises ConstraintViolationError (with suggested fixes appended) when
    guards contradict the user's constraints or force specializations.
    """
    # TODO(avik): refactor Dynamo to avoid duplication of the following code
    # between non-strict and strict.
    # Specifically, here (non-strict) we do the following post-tracing steps:
    # - Produce guards.
    # - Solve constraints.
    # - Install shape metadata in IR.
    # In strict, these steps are spread across multiple files:
    # - guards.py produces guards.
    # - eval_frame.py solves constraints
    # - _trace.py installs shape metadata in IR.

    shape_env = fake_mode.shape_env
    # tracked_fakes was populated by fakify() during make_fake_inputs.
    placeholders = [tf.fake for tf in shape_env.tracked_fakes]
    sources = [tf.source for tf in shape_env.tracked_fakes]
    input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes]
    constraint_violation_error = None
    try:
        shape_env.produce_guards(
            placeholders,
            sources,
            input_contexts=input_contexts,
            equalities_inputs=equalities_inputs,
            ignore_static=False,
        )
    except ConstraintViolationError as e:
        # Defer raising: we may be able to append suggested fixes below.
        constraint_violation_error = e

    # Freeze the shape env so nothing after this point can add new guards.
    shape_env.frozen = True
    dim_constraints = shape_env.dim_constraints
    if dim_constraints is None:
        # Expected when shape_env.produce_guards throws an early constraint violation error.
        # There is nothing to solve for in this case.
        # TODO(avik): Maybe record the constraint violation error instead and replay later?
        assert constraint_violation_error
        raise constraint_violation_error
    dim_constraints.solve()
    dim_constraints.remove_redundant_dynamic_results()
    forced_specializations = dim_constraints.forced_specializations()
    msg = dim_constraints.prettify_results(
        original_signature, constraint_violation_error, forced_specializations
    )
    if constraint_violation_error:
        # Append the human-readable suggested fixes to the original error.
        constraint_violation_error.args = (constraint_violation_error.args[0] + msg,)
    elif forced_specializations:
        constraint_violation_error = ConstraintViolationError(msg)
    if constraint_violation_error:
        raise constraint_violation_error

    # Walk placeholders and collect a range constraint per symbolic dimension.
    range_constraints = {}
    input_dims = defaultdict(list)
    free_symbols = set()
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue
        if _is_constant_argument(node.meta["val"]) or isinstance(
            node.meta["val"], CustomObjArgument
        ):
            continue
        for i, d in enumerate(node.meta["val"].shape):
            if isinstance(d, torch.SymInt):
                # Look up the range constraint for the symbol corresponding to this shape dimension
                # and store it indexed by the symbolic expression corresponding to it.
                # NOTE(avik): Use node._expr instead of node.expr for the lookup here because
                # we want the symbol, not its replacement, which could be an expression. Maybe
                # there's a better way to do this, e.g., by (re)computing value ranges for expressions?
                range_constraints[d.node.expr] = shape_env.var_to_range[d.node._expr]
                input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i))
                free_symbols.update(d.node.expr.free_symbols)

    for symbol in free_symbols:
        if symbol not in range_constraints:
            # Placeholders can have symbolic shapes that are derived expressions.
            # The above code will record direct range constraints for them
            # so that we can do runtime assertions. In addition, for serde checks
            # we want to record range constraints for their root symbols.
            range_constraints[symbol] = shape_env.var_to_range[symbol]

    return range_constraints
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/verifier.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
import operator
|
| 4 |
+
from collections.abc import Iterable
|
| 5 |
+
from typing import Any, Dict, final, List, Optional, Tuple, Type
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch._ops import HigherOrderOperator, OpOverload
|
| 9 |
+
from torch._subclasses.fake_tensor import FakeTensor
|
| 10 |
+
from torch.export.exported_program import ExportedProgram
|
| 11 |
+
from torch.export.graph_signature import (
|
| 12 |
+
CustomObjArgument,
|
| 13 |
+
InputKind,
|
| 14 |
+
SymIntArgument,
|
| 15 |
+
TensorArgument,
|
| 16 |
+
)
|
| 17 |
+
from torch.fx import GraphModule
|
| 18 |
+
from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class SpecViolationError(Exception):
    """Raised when an exported program violates its dialect's specification."""
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def is_functional(op: OpOverload) -> bool:
    """An operator is functional iff its schema declares no mutation."""
    mutates = op._schema.is_mutable
    return not mutates
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _check_has_fake_tensor(node: torch.fx.Node) -> None:
    """Deprecated alias that simply delegates to ``_check_val``."""
    # TODO(angelayi): remove this in favor of _check_val
    return _check_val(node)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _check_val(node: torch.fx.Node) -> None:
    """Verify that ``node.meta["val"]`` is present and holds an allowed type.

    Raises SpecViolationError when the field is missing (unless the node is a
    call_function whose op schema declares zero returns) or holds a value
    outside the allowed set (constants, tensors/fake tensors, symbolic
    scalars, dtype-like objects, CustomObjArgument, or iterables thereof).
    """
    def _check_correct_val(val):
        # Enumerates the value kinds accepted in node.meta["val"].
        if val is None:
            return True
        elif isinstance(val, (int, bool, str, float)):
            return True
        elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)):
            return True
        elif isinstance(val, (FakeTensor, torch.Tensor)):  # TODO(zhxchen17) Remove Tensor.
            return True
        elif isinstance(val, (SymInt, SymFloat, SymBool)):
            return True
        elif isinstance(val, CustomObjArgument):
            return True
        elif isinstance(val, Iterable):
            # Containers are valid iff every element is valid (recursively).
            return all(_check_correct_val(x) for x in val)
        return False

    def _no_returns(op):
        # True only for OpOverloads whose schema declares zero return values.
        if not isinstance(op, OpOverload):
            return False
        return len(op._schema.returns) == 0

    if "val" not in node.meta:
        if node.op == "call_function" and _no_returns(node.target):
            # Ops with no returns legitimately carry no "val" metadata.
            return
        raise SpecViolationError(f"Node.meta {node.name} is missing val field.")

    val = node.meta["val"]
    if not _check_correct_val(val):
        raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class _VerifierMeta(type):
    """Metaclass that registers every Verifier class by its ``dialect`` string.

    The base Verifier (dialect "ATEN", created with no bases) must define
    ``check`` and ``_check_graph_module``; subclasses must NOT override those
    final methods and must declare a non-"ATEN" dialect of their own.
    """
    # Maps dialect name -> registered Verifier subclass.
    _registry: Dict[str, Type['Verifier']] = {}

    def __new__(metacls, name, bases, attrs):
        if bases:
            # A subclass of Verifier: the final verification entry points may
            # not be overridden, and it must bring its own dialect.
            if "check" in attrs or "_check_graph_module" in attrs:
                raise SyntaxError("Overriding method check is not allowed.")
            assert "dialect" in attrs and attrs["dialect"] != "ATEN"
        else:
            # The base Verifier class itself must supply the entry points.
            assert "check" in attrs
            assert "_check_graph_module" in attrs
            assert attrs["dialect"] == "ATEN"

        assert isinstance(attrs["dialect"], str)
        ret = type.__new__(metacls, name, bases, attrs)
        metacls._registry[attrs["dialect"]] = ret  # type: ignore[assignment]
        return ret
|
| 84 |
+
|
| 85 |
+
def getattr_recursive(obj: Any, target: str) -> Any:
    """Resolve a dotted attribute path (e.g. ``"sub.weight"``) starting at *obj*.

    Raises RuntimeError naming the resolved prefix if any segment is missing.
    """
    atoms = target.split('.')
    current = obj
    for idx, atom in enumerate(atoms):
        if not hasattr(current, atom):
            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(atoms[:idx])}")
        current = getattr(current, atom)
    return current
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Verifier(metaclass=_VerifierMeta):
    """Validates that an ExportedProgram's graph conforms to this dialect.

    Subclasses customize the dialect via the ``allowed_*`` hooks and
    ``check_valid_op``/``check_additional``; ``check`` and
    ``_check_graph_module`` are ``@final`` and drive the traversal.
    """

    # Dialect tag; used as the registration key in _VerifierMeta._registry.
    dialect = "ATEN"

    def allowed_builtin_ops(self) -> List:
        """Python builtins permitted as ``call_function`` targets
        (sym-int arithmetic, pytree indexing, etc.)."""
        return [
            operator.getitem,
            operator.add,
            operator.mul,
            operator.sub,
            operator.truediv,
            operator.ge,
            operator.le,
            operator.gt,
            operator.lt,
            operator.eq,
            operator.ne,
            operator.floordiv,
            operator.mod,
            operator.and_,
            operator.or_,
            operator.not_,
            operator.pow,
            operator.neg,
            operator.abs,
            math.ceil,
            math.floor,
        ]

    def allowed_op_types(self) -> Tuple[Type[Any], ...]:
        """Operator classes permitted as ``call_function`` targets."""
        return (OpOverload, HigherOrderOperator)

    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
        """Types a ``get_attr`` node may legally resolve to."""
        return (torch.fx.GraphModule,)

    def check_valid_op(self, op):
        # Hook for dialect-specific per-operator validation; no-op here.
        pass

    def check_additional(self, gm: GraphModule) -> None:
        """
        Additional checks that are specific to some dialects.
        """
        pass

    @final
    def check(self, ep: ExportedProgram) -> None:
        # Entry point: validate the graph itself, then the program signature.
        self._check_graph_module(ep.graph_module)
        _verify_exported_program_signature(ep)

    @final
    def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
        def _allowed_getattr_types() -> Tuple[Type[Any], ...]:
            ret = self.allowed_getattr_types()
            # Guard against an over-permissive override: `object` would
            # make the isinstance check below vacuous.
            assert not any(t is object for t in ret)
            return ret

        def _check_valid_op(op) -> None:
            def _allowed_builtin_ops() -> List:
                ret = self.allowed_builtin_ops()
                assert all(inspect.isbuiltin(op) for op in ret)
                return ret

            def _allowed_op_types() -> Tuple[Type[Any], ...]:
                ret = self.allowed_op_types()
                assert not any(t is object for t in ret)
                return ret

            # TODO Remove this allowlist.
            _allowed_torch_functions = (
                torch.autograd.grad_mode.set_grad_enabled,
                torch.sym_int,
                torch.sym_ite,
                torch.sym_max,
                torch.sym_min,
                torch.sym_not,
                torch.sym_sqrt,
                # TODO (tmanlaibaatar)
                # Predispatch export is able to contain autograd ops.
                # These will be modeled as HOO later
                torch._C._set_grad_enabled

            )

            if not isinstance(op, _allowed_op_types()):
                if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
                    raise SpecViolationError(
                        f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"
                        f"Valid builtin ops: {_allowed_builtin_ops()}"
                        f"Valid torch functions: {_allowed_torch_functions}"
                    )

            if isinstance(op, OpOverload):
                # All ops functional
                if not is_functional(op):
                    raise SpecViolationError(
                        f"operator '{op}' is not functional"
                    )
            self.check_valid_op(op)

        for mod in gm.modules():
            # Only graph-carrying submodules are verified; plain nn.Modules
            # reachable from the root are skipped here (get_attr handles them).
            if not isinstance(mod, torch.fx.GraphModule):
                continue

            mod.graph.lint()
            for node in mod.graph.nodes:
                # TODO(T140410192): should have fake tensor for all dialects
                if node.op in {"call_module", "call_method"}:
                    raise SpecViolationError(
                        f"call_module is not valid: got a class '{node.target}' ",
                    )

                elif node.op == "call_function":
                    _check_val(node)

                    _check_valid_op(node.target)

                elif node.op == "get_attr":
                    if not isinstance(node.target, str):
                        raise SpecViolationError(
                            f"Expected get_attr target to be string, but got {type(node.target)}"
                        )

                    attr = getattr_recursive(mod, node.target)
                    if isinstance(attr, torch.nn.Module):
                        def _is_type(name, ty):
                            return isinstance(getattr(attr, name, None), ty)
                        # Duck-typed acceptance of a "LoweredBackendModule":
                        # matched by class name plus expected field types.
                        if type(attr).__name__ == "LoweredBackendModule":
                            if _is_type("backend_id", str) \
                                    and _is_type("processed_bytes", bytes) \
                                    and _is_type("compile_specs", list) \
                                    and hasattr(attr, "original_module"):
                                continue
                            else:
                                backend_id = getattr(attr, "backend_id", None)
                                processed_bytes = getattr(attr, "processed_bytes", None)
                                compile_specs = getattr(attr, "compile_specs", None)
                                raise SpecViolationError(
                                    f"Invalid get_attr type {type(attr)}. \n"
                                    f"LoweredBackendModule fields: "
                                    f"backend_id(str) : {type(backend_id)}, "
                                    f"processed_bytes(bytes) : {type(processed_bytes)}, "
                                    f"compile_specs(list) : {type(compile_specs)}"
                                )

                    if not isinstance(attr, _allowed_getattr_types()):
                        raise SpecViolationError(
                            f"Invalid get_attr type {type(attr)}. \n"
                            f"Valid get_attr types: {_allowed_getattr_types()}"
                        )


                elif node.op == "placeholder":
                    _check_val(node)
                # TODO(zhxchen17)
                # elif node.op == "output":
                #     _check_flattened_outputs()

        self.check_additional(gm)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def _verify_exported_program_signature(exported_program) -> None:
    """Check that the graph's placeholders and outputs line up, in order,
    with the program's graph signature, raising SpecViolationError on any
    mismatch."""
    # Check ExportedProgram signature matches
    gs = exported_program.graph_signature

    # Check every node in the signature exists in the graph
    input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"]

    if len(input_node_names) != len(gs.input_specs):
        # NOTE(review): the message reports len(gs.user_inputs) while the
        # check compares len(gs.input_specs) -- confirm intended.
        raise SpecViolationError(
            f"Number of graph inputs ({len(input_node_names)}) "
            f"does not match number of inputs in the graph signature ({len(gs.user_inputs)})"
        )

    # Input specs and placeholders must correspond positionally.
    for input_spec, node in zip(gs.input_specs, input_node_names):
        if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)):
            if input_spec.arg.name != node:
                raise SpecViolationError(
                    f"Input spec name {input_spec.arg.name} does not match node name {node}"
                )

        if input_spec.kind == InputKind.USER_INPUT:
            continue

        elif input_spec.kind == InputKind.PARAMETER:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            param = input_spec.target
            if param not in exported_program.state_dict:
                raise SpecViolationError(
                    f"Parameter {param} is not in the state dict."
                )

            if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
                raise SpecViolationError(
                    f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter."
                )

        elif input_spec.kind == InputKind.BUFFER:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            buffer = input_spec.target
            if input_spec.persistent is None:
                raise SpecViolationError(
                    f"Buffer {buffer} is missing a persistence flag"
                )

            # Persistent buffers live in the state dict; non-persistent
            # buffers must NOT appear there.
            if input_spec.persistent is True and buffer not in exported_program.state_dict:
                raise SpecViolationError(
                    f"Buffer {buffer} is not in the state dict."
                )

            if input_spec.persistent is False and buffer in exported_program.state_dict:
                raise SpecViolationError(
                    f"Non-persistent buffer {buffer} is in the state dict, it should not be."
                )
        elif input_spec.kind == InputKind.CONSTANT_TENSOR:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            tensor_const = input_spec.target
            if tensor_const not in exported_program.constants:
                raise SpecViolationError(
                    f"Constant tensor {tensor_const} is not in the constants dictionary."
                )
        elif input_spec.kind == InputKind.CUSTOM_OBJ:
            if not isinstance(input_spec.arg, CustomObjArgument):
                raise SpecViolationError(
                    f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            custom_obj = input_spec.target
            if custom_obj not in exported_program.constants:
                raise SpecViolationError(
                    f"Custom object {custom_obj} is not in the constants dictionary."
                )
        elif input_spec.kind == InputKind.TOKEN:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
        else:
            raise SpecViolationError(
                f"Unknown InputKind {input_spec.kind}."
            )

    # Check outputs
    output_node = list(exported_program.graph.nodes)[-1]
    assert output_node.op == "output"
    output_nodes = [
        arg.name if isinstance(arg, torch.fx.Node) else arg
        for arg in output_node.args[0]
    ]

    if len(output_nodes) != len(gs.output_specs):
        raise SpecViolationError(
            f"Number of output nodes {len(output_nodes)} is different "
            "Than the number of outputs specified by the graph signature: \n"
            f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n"
            f"Number of user outputs: {len(gs.user_outputs)}. \n"
        )

    # Output layout (per the slicing below): [tokens][mutations][user outputs].
    num_tokens = len(gs.output_tokens)
    end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
    mutate_nodes: List[str] = output_nodes[num_tokens:end]
    user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]

    for mutation_node in mutate_nodes:
        if mutation_node in gs.buffers_to_mutate:
            if gs.buffers_to_mutate[mutation_node] not in gs.buffers:
                raise SpecViolationError(
                    f"Buffer output {mutation_node} does not point to a buffer that exists. \n"
                    f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
                    f"Buffer nodes available: {gs.buffers} \n"
                )
        elif mutation_node in gs.user_inputs_to_mutate:
            if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
                raise SpecViolationError(
                    f"User input output {mutation_node} does not point to a user input that exists. \n"
                    f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
                    f"User input nodes available: {gs.user_inputs} \n")
        else:
            raise SpecViolationError(
                f"Mutation node {mutation_node} is neither a buffer nor a user input. "
                f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}"
            )

    # User outputs must match the signature's list exactly, in order.
    for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
        if user_output_node != user_output_name:
            raise SpecViolationError(
                f"User output {user_output_node} is not in the correct "
                "order or is not found in the "
                f"exported program's user_output list: {gs.user_outputs}. "
            )
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def load_verifier(dialect: str) -> Optional[Type[Verifier]]:
    """Look up the Verifier subclass registered for *dialect*.

    The default "ATEN" dialect resolves leniently (returns None when no
    verifier is registered); any other dialect raises KeyError when absent.
    """
    registry = _VerifierMeta._registry
    if dialect != "ATEN":
        return registry[dialect]
    return registry.get(dialect)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/closure.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import threading
|
| 3 |
+
from queue import Empty as EmptyQueue, Queue
|
| 4 |
+
|
| 5 |
+
from torch._lazy.device_context import get_device_context
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ClosureHandler:
    """Runs step closures synchronously, in submission order."""

    def __init__(self):
        pass

    def run(self, closure):
        """Execute a single closure.

        Args:
            closure: callable to invoke (takes no arguments)
        """
        closure()

    def __call__(self, closures):
        """Execute every closure in *closures*, in order."""
        for fn in closures:
            self.run(fn)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AsyncClosureHandler(ClosureHandler):
    """Handler for Asynchronous Step Closures

    Args:
        max_queue_size: The maximum length of the closure queue after which
            the training loop will block until closures are evaluated.
            By default, a reasonable limit of a maximum of 100 on the queue.
            This value can be set using the `LTC_MAX_ASYNC_QUEUE` environment
            variable.
    """

    def __init__(self, max_queue_size=100):
        super().__init__()
        # Bounded queue: put() blocks when full, throttling the producer.
        self._closure_queue: Queue = Queue(
            int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size))
        )
        # Exceptions raised by closures are parked here and re-raised on
        # the next run() call.
        self._closure_exception: Queue = Queue()
        self._closure_lock = threading.Lock()
        self._closure_event_loop_finished = threading.Event()
        self._closure_event_loop = None

    def start_event_loop(self):
        """Start closure event loop if not started"""
        if self._closure_event_loop is None:

            def event_loop():
                # Run loop until closure event is set and closure queue is empty
                while True:
                    try:
                        closure = self._closure_queue.get(block=True, timeout=3)
                        closure()
                        self._closure_queue.task_done()
                    except EmptyQueue:
                        # Idle for 3s with nothing queued: worker shuts down;
                        # run() will start a fresh thread when needed.
                        with self._closure_lock:
                            if self._closure_queue.empty():
                                self._closure_event_loop_finished.set()
                                return
                    except Exception as e:
                        # Park the failure and stop; surfaced on next run().
                        self._closure_exception.put(e)
                        return

            self._closure_event_loop = threading.Thread(target=event_loop)
            self._closure_event_loop.start()

    def run(self, closure):
        # Enqueue (may block when the queue is full), then lazily (re)start
        # the worker thread, first surfacing any earlier closure exception.
        with self._closure_lock:
            self._closure_queue.put(closure, block=True)
            if (
                self._closure_event_loop is None
                or not self._closure_event_loop.is_alive()
            ):
                try:
                    e = self._closure_exception.get(block=False)
                    raise RuntimeError(
                        "Cannot run asynchronous closure due to previously raised exception"
                    ) from e
                except EmptyQueue:
                    # No pending error: restart the worker.
                    self._closure_event_loop = None
                    self.start_event_loop()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def add_step_closure(closure, args=(), run_async=False):
    """Adds a closure to the list of the ones to be run at the end of the step.

    Many times during model training there is the need to print/report (print
    to console, post to tensorboard, etc...) information which requires the
    content of intermediary tensors to be inspected. Inspecting different
    tensors in different points of the model code normally forces many
    executions and causes performance problems. Queueing a step closure
    instead ensures it runs after the barrier, when all live tensors (which
    include the ones captured by the closure arguments) are already
    materialized to device data, so a single execution suffices even with
    several closures queued. Closures run sequentially in queueing order;
    it is still advisable to throttle reporting to once every N steps.

    Args:
        closure (callable): The function to be called.
        args (tuple): The arguments to be passed to the closure.
        run_async: If True, run the closure asynchronously.
    """
    devctx = get_device_context()
    if run_async:
        attr_name = "async_step_closures"
    else:
        attr_name = "step_closures"
    queue = getattr(devctx, attr_name, None)
    if queue is None:
        queue = []
        setattr(devctx, attr_name, queue)
    # Bind args as a default so the values captured now are used later.
    queue.append(lambda a=args: closure(*a))
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def run_step_closures():
    """Drain and execute the closures queued on the device context.

    Async closures are handed to a (lazily created) AsyncClosureHandler;
    synchronous closures run in-line via a ClosureHandler. Returns the
    device context.
    """
    devctx = get_device_context()

    pending_async = getattr(devctx, "async_step_closures", None)
    if pending_async is not None:
        # Reset the queue before running so new closures can accumulate.
        devctx.async_step_closures = []
        handler = getattr(devctx, "async_closure_handler", None)
        if handler is None:
            handler = AsyncClosureHandler()
            devctx.async_closure_handler = handler
        handler(pending_async)

    pending_sync = getattr(devctx, "step_closures", None)
    if pending_sync is not None:
        devctx.step_closures = []
        handler = getattr(devctx, "closure_handler", None)
        if handler is None:
            handler = ClosureHandler()
            devctx.closure_handler = handler
        handler(pending_sync)
    return devctx
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/computation.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch._C._lazy
|
| 2 |
+
import torch._C._lazy_ts_backend
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_tensors_ts_device_data_node(tensors):
    """Return tensor ids and eager tensors for DeviceData nodes in the
    IR for the passed in lazy tensors.

    TODO: This API is currently ts backend specific. We are working on
    generalizing it to all backends including XLA.

    Args:
        tensors: lazy tensors whose IR is inspected.
    """
    # Thin delegation to the TorchScript lazy backend C extension.
    return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_graph_hash(tensors):
    """Return the graph hash for the passed in lazy tensors.

    The returned hash can be passed to ``run_cached_graph`` in this module.
    """
    return torch._C._lazy._get_graph_hash(tensors)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def run_cached_graph(hash_str, graph_inputs):
    """Running the cached computation graph with the given inputs

    TODO: This API is currently ts backend specific. We are working on
    generalizing it to all backends including XLA.

    Args:
        hash_str: graph hash (see ``get_graph_hash``) identifying the
            cached computation.
        graph_inputs: inputs to feed the cached graph.
    """
    return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/debug.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch._C._lazy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def render_ir_graph(tensors):
    """Return a text dump of the LTC IR graph in dot format for the tensors.

    The text can be processed by tools like dot to be rendered in pdf, png etc.

    Args:
        tensors: lazy tensors whose IR graph is rendered.
    """
    return torch._C._lazy._get_tensors_dot(tensors)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def dump_ir(tensors, ir_format):
    """Return a dump of *tensors* in the requested format.

    Valid formats:
      - "text": the LTC IR
      - "backend": the active backend's IR

    Raises:
        RuntimeError: for any other format string.
    """
    if ir_format == "text":
        return torch._C._lazy._get_tensors_text(tensors)
    if ir_format == "backend":
        return torch._C._lazy._get_tensors_backend(tensors)
    raise RuntimeError(f"Unrecognized IR format: {ir_format}")
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch._C._lazy_ts_backend
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def init():
    """Initializes the lazy Torchscript backend"""
    # One-shot, side-effecting call into the C extension.
    torch._C._lazy_ts_backend._init()
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/error.cpython-311.pyc
ADDED
|
Binary file (208 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/graphs.cpython-311.pyc
ADDED
|
Binary file (29 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .autocast_mode import autocast, custom_bwd, custom_fwd
from .common import amp_definitely_not_available
from .grad_scaler import GradScaler

# Explicit public API re-exported by torch.cuda.amp.
__all__ = [
    "amp_definitely_not_available",
    "autocast",
    "custom_bwd",
    "custom_fwd",
    "GradScaler",
]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (525 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (244 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc
ADDED
|
Binary file (28.8 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc
ADDED
|
Binary file (72.3 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc
ADDED
|
Binary file (52.2 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc
ADDED
|
Binary file (549 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc
ADDED
|
Binary file (2.46 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py
ADDED
|
@@ -0,0 +1,557 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \
|
| 2 |
+
op_mod, op_gt, op_lt, op_neq, op_eq
|
| 3 |
+
from torch.fx.tensor_type import TensorType, Dyn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Constraint:
    """Common base class (marker type) for all gradual-typing constraints."""
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Conj(Constraint):
    """Conjunction (logical AND) of a list of constraints."""

    def __init__(self, conjuncts):
        """
        :param conjuncts: Conjunction of constraints
        """
        # The attribute keeps the original (misspelled) public name
        # "conjucts" for backward compatibility with existing callers.
        self.conjucts = conjuncts

    def __eq__(self, other):
        # Fix: the original compared the conjunct list against itself twice
        # ("X == Y and X == Y"); a single comparison is equivalent.
        if isinstance(other, Conj):
            return self.conjucts == other.conjucts
        else:
            return False

    def __repr__(self):
        return f'And({self.conjucts})'
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Disj(Constraint):
    """Disjunction (logical OR) of a list of constraints."""

    def __init__(self, disjuncts):
        """
        :param disjuncts: Disjunction of constraints
        """
        self.disjuncts = disjuncts

    def __eq__(self, other):
        # Fix: the original duplicated the same comparison
        # ("X == Y and X == Y"); one comparison is equivalent.
        if isinstance(other, Disj):
            return self.disjuncts == other.disjuncts
        else:
            return False

    def __repr__(self):
        return f'Or({self.disjuncts})'
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Prod(Constraint):
    """Product of a list of dimensions."""

    def __init__(self, products):
        """
        :param products: lists of dimensions to multiply
        """
        self.products = products

    def __eq__(self, other):
        # Fix: the original duplicated the same comparison
        # ("X == Y and X == Y"); one comparison is equivalent.
        if isinstance(other, Prod):
            return self.products == other.products
        else:
            return False

    def __repr__(self):
        return f'Product({self.products})'
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class T(Constraint):
    """The trivially-true constraint."""

    def __eq__(self, other):
        # Every T instance is interchangeable with any other.
        return isinstance(other, T)

    def __repr__(self):
        return "True"
|
| 73 |
+
|
| 74 |
+
class F(Constraint):
    """The trivially-false constraint."""

    def __eq__(self, other):
        # Every F instance is interchangeable with any other.
        return isinstance(other, F)

    def __repr__(self):
        return "False"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class BinaryConstraint(Constraint):
    """Base class for every binary ``lhs <op> rhs`` constraint."""

    def __init__(self, lhs, rhs, op):
        """
        :param lhs: left-hand side of the constraint
        :param rhs: right-hand side of the constraint
        :param op: string naming the operation
        """
        self.lhs = lhs
        self.rhs = rhs
        self.op = op

    def __eq__(self, other):
        # Guard-clause form: equal iff both sides and the operator match.
        if not isinstance(other, BinaryConstraint):
            return False
        return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op

    def __repr__(self):
        return f'({self.lhs} {self.op} {self.rhs})'
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class BinConstraintT(BinaryConstraint):
    """Binary constraint whose operands are tensor-typed."""

    def __init__(self, lhs, rhs, op):
        # Each operand must be a tensor variable, a TensorType, an int,
        # or the dynamic type Dyn.
        for operand in (lhs, rhs):
            assert isinstance(operand, (TVar, TensorType, int)) or operand == Dyn
        super().__init__(lhs, rhs, op)

    def __eq__(self, other):
        # Same equality semantics as the generic binary constraint.
        return super().__eq__(other)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class BinConstraintD(BinaryConstraint):
    """Binary constraint whose operands are dimensions or expressions."""

    def __init__(self, lhs, rhs, op):
        # Each operand must be an algebraic expression, a dimension, or a
        # boolean expression over dimensions.
        for operand in (lhs, rhs):
            assert is_algebraic_expression(operand) or is_dim(operand) or is_bool_expr(operand)
        super().__init__(lhs, rhs, op)

    def __eq__(self, other):
        # Same equality semantics as the generic binary constraint.
        return super().__eq__(other)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class TGreatestUpperBound(Constraint):
    """Greatest upper bound for tensors with dynamic type."""

    def __init__(self, res, rhs1, rhs2):
        """
        :param res: tensor variable that stores the result of the output
        :param rhs1: tensor or tensor variable
        :param rhs2: tensor or tensor variable
        """
        self.res = res
        self.rhs1 = rhs1
        self.rhs2 = rhs2

    def __repr__(self):
        return f'{self.res} = {self.rhs1}⊔*{self.rhs2}'

    def __eq__(self, other):
        if not isinstance(other, TGreatestUpperBound):
            return False
        return (
            self.res == other.res
            and self.rhs1 == other.rhs1
            and self.rhs2 == other.rhs2
        )
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class DGreatestUpperBound(Constraint):
    """
    Greatest upper bound for dimensions
    """
    def __init__(self, res, rhs1, rhs2):
        """
        :param res: Dimension variable to store the result
        :param rhs1: dimension variable 1
        :param rhs2: dimension variable 2
        """
        # All three operands must be dimensions (DVar, int, or Dyn).
        for operand in (res, rhs1, rhs2):
            assert is_dim(operand)

        self.res = res
        self.rhs1 = rhs1
        self.rhs2 = rhs2

    def __repr__(self):
        return '{} = {}⊔{}'.format(self.res, self.rhs1, self.rhs2)

    def __eq__(self, other):
        if not isinstance(other, DGreatestUpperBound):
            return False
        return (self.res == other.res
                and self.rhs1 == other.rhs1
                and self.rhs2 == other.rhs2)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class CanReshape(Constraint):
    """
    can_reshape constraint
    """
    def __init__(self, src, target):
        """
        :param src: tensor variable
        :param target: tensor
        """
        self.src = src
        self.target = target

    def __repr__(self):
        return 'can-reshape({}, {})'.format(self.src, self.target)

    def __eq__(self, other):
        if not isinstance(other, CanReshape):
            return False
        return self.src == other.src and self.target == other.target
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class IndexSelect(Constraint):
    """Constraint modelling ``index_select`` over a tensor of a given rank."""

    def __init__(self, tensor_size, input_var, dim_replace, index, output):
        """
        Args:
            input_var: input to index_select
            tensor_size: tensor size we are considering
            dim_replace: the dimension of the output at "index"
            index: location of the dimensions to replace in the input
            output: variable to store the result
        """
        assert isinstance(input_var, TVar)
        assert isinstance(output, TVar)
        assert isinstance(dim_replace, DVar) or dim_replace == Dyn
        assert isinstance(index, int)

        self.input_var = input_var
        self.tensor_size = tensor_size
        self.dim_replace = dim_replace
        self.index = index
        self.output = output

    def __repr__(self):
        return (' {} = IndexSelect({}, tensor_size: {}, {}, {})'
                .format(self.output, self.input_var, self.tensor_size,
                        self.dim_replace, self.index))

    def __eq__(self, other):
        if not isinstance(other, IndexSelect):
            return False
        return (self.tensor_size == other.tensor_size
                and self.dim_replace == other.dim_replace
                and self.index == other.index
                and self.output == other.output
                and self.input_var == other.input_var)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class Transpose(Constraint):
    """Constraint modelling the swap of two dimensions at a given rank."""

    def __init__(self, tensor_size, input_var, index1, index2, output):
        """
        Args:
            tensor_size: current tensor size
            input_var: variable to hold input
            index1: dimension 1
            index2: dimension 2
            output: output that stores result
        """
        assert isinstance(input_var, TVar)
        assert isinstance(output, TVar)
        assert isinstance(index1, int)
        assert isinstance(index2, int)

        self.input_var = input_var
        self.tensor_size = tensor_size
        self.index1 = index1
        self.index2 = index2
        self.output = output

    def __repr__(self):
        return (' {} = Transpose({}, tensor_size: {}, {}, {})'
                .format(self.output, self.input_var, self.tensor_size,
                        self.index1, self.index2))

    def __eq__(self, other):
        if not isinstance(other, Transpose):
            return False
        return (self.tensor_size == other.tensor_size
                and self.index1 == other.index1
                and self.index2 == other.index2
                and self.output == other.output
                and self.input_var == other.input_var)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class GetItem(Constraint):
    """Constraint for reading a single dimension out of a tensor variable."""

    def __init__(self, tensor_size, index, res, input_var):
        """
        Constraint for getting item given a tensor size
        :param tensor_size: actual number
        :param index: actual number representing the index
        :param res: dimension variable to carry the item we get
        :param input_var: a tensor variable from which we will get item
        """
        assert isinstance(res, DVar)

        self.res = res
        self.tensor_size = tensor_size
        self.index = index
        self.input_var = input_var

    def __repr__(self):
        return ' {} = GetItem({}, tensor_size: {}, {})'.format(
            self.res, self.input_var, self.tensor_size, self.index)

    def __eq__(self, other):
        if not isinstance(other, GetItem):
            return False
        return (self.res == other.res
                and self.tensor_size == other.tensor_size
                and self.index == other.index
                and self.input_var == other.input_var)
|
| 324 |
+
|
| 325 |
+
class GetItemTensor(Constraint):
    """Constraint for tuple-indexing a tensor variable, producing a tensor."""

    def __init__(self, tensor_size, index_tuple, res, input_var):
        """
        Constraint for getting item given a tensor size.
        However, when the argument is a tuple, we will
        expect a tensor
        :param tensor_size: actual number representing the rank
        :param index_tuple: tuple for indexing
        :param res: tensor variable to carry the item we get
        :param input_var: a tensor variable from which we will get item
        """
        assert isinstance(res, TVar)

        self.res = res
        self.tensor_size = tensor_size
        self.index_tuple = index_tuple
        self.input_var = input_var

    def __repr__(self):
        return ' {} = GetItemT({}, tensor_size: {}, {})'.format(
            self.res, self.input_var, self.tensor_size, self.index_tuple)

    def __eq__(self, other):
        if not isinstance(other, GetItemTensor):
            return False
        return (self.res == other.res
                and self.tensor_size == other.tensor_size
                and self.index_tuple == other.index_tuple
                and self.input_var == other.input_var)
|
| 355 |
+
|
| 356 |
+
class CalcConv(Constraint):
    """Constraint relating a convolution's output type to its input and hyper-parameters."""

    def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars):
        """
        :param conv_result: the convolution result
        :param input_var: input to convolution
        :param c_out: output channel type
        :param kernel: kernel tuple
        :param padding: padding tuple
        :param stride: stride tuple
        :param dilation: dilation tuple
        :param matching_constraint_vars: variables used by the matching constraint
        """
        self.conv_result = conv_result
        self.input_var = input_var
        self.c_out = c_out
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.matching_constraint = matching_constraint_vars

    def __repr__(self):
        return '{} = calc-conv({}, {}, {}, {}, {}, {})'.format(
            self.conv_result, self.input_var, self.c_out, self.kernel,
            self.padding, self.stride, self.dilation)

    def __eq__(self, other):
        if not isinstance(other, CalcConv):
            return False
        return (self.conv_result == other.conv_result
                and self.input_var == other.input_var
                and self.c_out == other.c_out
                and self.kernel == other.kernel
                and self.padding == other.padding
                and self.stride == other.stride
                and self.dilation == other.dilation
                and self.matching_constraint == other.matching_constraint)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class CalcMaxPool(Constraint):
    """Constraint relating a maxpool's output type to its input and hyper-parameters."""

    def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars):
        """
        :param maxpool_result: the result of maxpool
        :param input_var: input to convolution
        :param kernel: kernel tuple
        :param padding: padding tuple
        :param stride: stride tuple
        :param dilation: dilation tuple
        :param matching_constraint_vars: variables used by the matching constraint
        """
        self.maxpool_result = maxpool_result
        self.input_var = input_var
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.matching_constraint = matching_constraint_vars

    def __repr__(self):
        return '{} = calc-maxpool({}, {}, {}, {}, {})'.format(
            self.maxpool_result, self.input_var, self.kernel,
            self.padding, self.stride, self.dilation)

    def __eq__(self, other):
        if not isinstance(other, CalcMaxPool):
            return False
        return (self.maxpool_result == other.maxpool_result
                and self.input_var == other.input_var
                and self.kernel == other.kernel
                and self.padding == other.padding
                and self.stride == other.stride
                and self.dilation == other.dilation
                and self.matching_constraint == other.matching_constraint)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
class ApplyBroadcasting(Constraint):
    """Constraint that broadcasts two tensor variables against each other."""

    def __init__(self, res1, res2, input1, input2):
        """
        :param res1: resulting tensor 1
        :param res2: resulting tensor 2
        :param input1: tensor variable 1
        :param input2: tensor variable 2
        """
        self.res1 = res1
        self.res2 = res2
        self.input1 = input1
        self.input2 = input2

    def __eq__(self, other):
        if not isinstance(other, ApplyBroadcasting):
            return False
        return (self.res1 == other.res1
                and self.res2 == other.res2
                and self.input1 == other.input1
                and self.input2 == other.input2)

    def __repr__(self):
        return '{}, {} = apply-broadcasting({}, {})'.format(
            self.res1, self.res2, self.input1, self.input2)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
class CalcProduct(Constraint):
    """
    Given correct dimensions, calculate the product for flatten accounting for Dyn
    """
    def __init__(self, start, end, flattened, dims_to_flatten):
        """
        :param start: start index
        :param end: end index
        :param flattened: variable to store the product
        :param dims_to_flatten: the type which we will flatten
        """
        assert isinstance(dims_to_flatten, list)
        assert isinstance(flattened, TVar)
        assert isinstance(start, int)
        assert isinstance(end, int)

        self.start = start
        self.end = end
        self.dims_to_flatten = dims_to_flatten
        self.flattened = flattened

    def __eq__(self, other):
        if not isinstance(other, CalcProduct):
            return False
        return (self.start == other.start
                and self.end == other.end
                and self.dims_to_flatten == other.dims_to_flatten
                and self.flattened == other.flattened)

    def __repr__(self):
        return '{} = CalcProduct({}, {}, {})'.format(
            self.flattened, self.start, self.end, self.dims_to_flatten)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class TVar:
    """
    Tensor variable with no tensor constructor
    """
    def __init__(self, tvar):
        """
        :param tvar: tensor variable
        """
        self.tvar = tvar

    def __repr__(self):
        return 'TV({})'.format(self.tvar)

    def __eq__(self, other):
        if not isinstance(other, TVar):
            return False
        return self.tvar == other.tvar
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class DVar:
    """
    Dimension variable
    """
    def __init__(self, c):
        """
        :param c: character or number
        """
        self.c = c

    def __repr__(self):
        return 'DV({})'.format(self.c)

    def __eq__(self, other):
        if not isinstance(other, DVar):
            return False
        return self.c == other.c
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
class BVar:
    """
    Boolean variable
    """
    def __init__(self, c):
        """
        :param c: character or number
        """
        self.c = c

    def __repr__(self):
        return 'BV({})'.format(self.c)

    def __eq__(self, other):
        if not isinstance(other, BVar):
            return False
        return self.c == other.c
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def is_algebraic_expression(constraint):
    """Return True if *constraint* is an arithmetic expression over dimensions."""
    if isinstance(constraint, BinConstraintD):
        # A dimension constraint is algebraic when its operator is arithmetic.
        return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod]
    return isinstance(constraint, Prod)
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def is_bool_expr(constraint):
    """Return True if *constraint* is a boolean expression."""
    if isinstance(constraint, BinConstraintD):
        # A dimension constraint is boolean when its operator is a comparison.
        return constraint.op in [op_gt, op_lt, op_neq, op_eq]
    return isinstance(constraint, (BVar, Conj, Disj))
|
| 555 |
+
|
| 556 |
+
def is_dim(d):
    """Return True if *d* is a dimension: a DVar, an int, or the dynamic type Dyn."""
    if isinstance(d, (DVar, int)):
        return True
    return d == Dyn
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
ADDED
|
@@ -0,0 +1,1279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import operator
|
| 3 |
+
import warnings
|
| 4 |
+
from typing import Callable, Dict, Iterable
|
| 5 |
+
|
| 6 |
+
from torch.fx._symbolic_trace import _assert_is_none
|
| 7 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \
|
| 8 |
+
Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \
|
| 9 |
+
TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound
|
| 10 |
+
from torch.fx.experimental.migrate_gradual_types.operation import \
|
| 11 |
+
op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul
|
| 12 |
+
from torch.fx.node import Target, Node
|
| 13 |
+
from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \
|
| 14 |
+
gen_bvar
|
| 15 |
+
|
| 16 |
+
from torch.fx.tensor_type import Dyn, TensorType
|
| 17 |
+
from torch.nn.modules.conv import Conv2d
|
| 18 |
+
from torch.nn.modules.batchnorm import BatchNorm2d
|
| 19 |
+
|
| 20 |
+
_INFERENCE_RULES: Dict[Target, Callable] = {}
|
| 21 |
+
|
| 22 |
+
MAX_TENSOR_RANK = 4
|
| 23 |
+
|
| 24 |
+
def register_inference_rule(call_target):
    """Decorator factory: register the decorated function as the constraint
    inference rule for *call_target* in the module-level rule table."""
    def _wrap(fn):
        # Each target may have at most one rule; a duplicate registration
        # is a programming error.
        if call_target in _INFERENCE_RULES:
            raise RuntimeError(f'Inference rule already registered for {call_target}!')
        _INFERENCE_RULES[call_target] = fn
        return fn
    return _wrap
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter):
    """
    Generate the constraints needed to type a flatten of *input* into *flattened*,
    assuming the input has rank *n*.

    :param start_dim: first dimension to flatten (may be negative)
    :param end_dim: last dimension to flatten (may be negative)
    :param input: tensor variable for the flatten input
    :param flattened: tensor variable that receives the flattened type
    :param n: rank assumed for the input tensor
    :param counter: fresh-variable counter, threaded through and returned
    :return: (constraint, counter) where the constraint forces *input* to be an
        n-rank tensor of natural dimensions whose [start_dim, end_dim] slice
        multiplies into *flattened*
    """
    d, counter = gen_tensor_dims(n, counter)
    # The input must be an n-rank tensor made of these fresh dimensions.
    c1 = BinConstraintT(input, TensorType(d), op_eq)
    # Normalize the dimension range for CalcProduct.
    # NOTE(review): abs(start_dim) equals the usual "n + start_dim" wrap-around
    # only for start_dim == -1; other negative start dims look mishandled —
    # confirm against CalcProduct's intended indexing before relying on them.
    start_dim = n if start_dim == -1 else abs(start_dim)
    end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1
    c2 = CalcProduct(start_dim, end_dim, flattened, d)
    # Every generated dimension must be a natural number.
    nat_constraints = gen_nat_constraints(d)
    return Conj([c1, c2, *nat_constraints]), counter
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@register_inference_rule(getattr)
def get_attr_inference_rule(n: Node, symbols, constraints, counter):
    """
    If the attribute is "device" then the tensor shape is preserved
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], str)

    output, counter = gen_tvar(counter)
    symbols[n] = output

    input = symbols[n.args[0]]
    attr = n.args[1]

    if attr != 'device':
        raise NotImplementedError('Not yet implemented')
    # Accessing .device leaves the tensor unchanged, so the output type
    # simply equals the input type.
    return [BinConstraintT(input, output, op_eq)], counter
|
| 60 |
+
|
| 61 |
+
@register_inference_rule(torch.bmm)
def bmm_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints that match the input to a size 3 tensor
    and switch the dimensions according to the rules
    of batch multiplication
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    # Fresh tensor variable for the bmm result.
    bmm_output, counter = gen_tvar(counter)
    symbols[n] = bmm_output

    bmm_input1 = symbols[n.args[0]]
    bmm_input2 = symbols[n.args[1]]

    # Fresh rank-3 dimension lists for each input.
    dims_input1, counter = gen_tensor_dims(3, counter)
    dims_input2, counter = gen_tensor_dims(3, counter)

    # Case 1: both inputs are fully dynamic, so the output is too.
    inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
                       BinConstraintT(bmm_input2, Dyn, op_eq),
                       BinConstraintT(bmm_output, Dyn, op_eq)])

    # Case 2: only input1 is dynamic; the output takes input2's batch and last
    # dimensions, with a dynamic middle dimension.
    input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
                       BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
                       BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)])

    # Case 3: only input2 is dynamic (mirror of case 2).
    input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq),
                       BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
                       BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)])

    # The batch dimensions of the two inputs must be consistent with each other.
    consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)]

    batch_size, counter = gen_dvar(counter)

    # Case 4: both inputs are rank-3 tensors. The output is
    # (batch, rows of input1, cols of input2), where batch is the greatest
    # upper bound of the two (possibly dynamic) batch dimensions.
    inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
                               BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
                               BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq),
                               *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])])

    # Exactly one of the four cases must hold.
    return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@register_inference_rule("index_select")
|
| 105 |
+
def index_select_inference_rule(n: Node, symbols, constraints, counter):
|
| 106 |
+
"""
|
| 107 |
+
We constrain the second argument to a vector or Dyn.
|
| 108 |
+
The output replaces the input with the shape of the vector
|
| 109 |
+
at the position given by the index (first argument)
|
| 110 |
+
"""
|
| 111 |
+
# print(n.args)
|
| 112 |
+
assert isinstance(n.args[0], Node)
|
| 113 |
+
assert isinstance(n.args[1], int)
|
| 114 |
+
assert isinstance(n.args[2], Node)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
index_select, counter = gen_tvar(counter)
|
| 119 |
+
symbols[n] = index_select
|
| 120 |
+
|
| 121 |
+
dims, counter = gen_tensor_dims(1, counter)
|
| 122 |
+
|
| 123 |
+
# equality constraint
|
| 124 |
+
is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq)
|
| 125 |
+
is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq)
|
| 126 |
+
|
| 127 |
+
c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select)
|
| 128 |
+
for i in range(MAX_TENSOR_RANK)])])
|
| 129 |
+
c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
|
| 130 |
+
for i in range(MAX_TENSOR_RANK)])])
|
| 131 |
+
|
| 132 |
+
return [Disj([c2, c3])], counter
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@register_inference_rule("expand")
|
| 136 |
+
def expand_inference_rule(n: Node, symbols, constraints, counter):
|
| 137 |
+
"""
|
| 138 |
+
We generate the exact constraints as we do for tensor additions but we constraint
|
| 139 |
+
the rank of this expression to be equal to len(n.args[1:]) so that only
|
| 140 |
+
those cases get considered for the output
|
| 141 |
+
"""
|
| 142 |
+
assert isinstance(n.args[0], Node)
|
| 143 |
+
|
| 144 |
+
# define the output for expand
|
| 145 |
+
expand, counter = gen_tvar(counter)
|
| 146 |
+
symbols[n] = expand
|
| 147 |
+
|
| 148 |
+
# since we do not have two nodes here, we will construct an argument variable
|
| 149 |
+
e1 = symbols[n.args[0]]
|
| 150 |
+
e2, counter = gen_tvar(counter)
|
| 151 |
+
|
| 152 |
+
e2_nat_constraints = []
|
| 153 |
+
for arg in n.args[1:]:
|
| 154 |
+
assert isinstance(arg, (Node, int))
|
| 155 |
+
if isinstance(arg, Node):
|
| 156 |
+
assert isinstance(symbols[arg], DVar)
|
| 157 |
+
e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq))
|
| 158 |
+
|
| 159 |
+
e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq)
|
| 160 |
+
|
| 161 |
+
constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand)
|
| 162 |
+
|
| 163 |
+
# constraint the output size
|
| 164 |
+
dims, counter = gen_tensor_dims(len(n.args[1:]), counter)
|
| 165 |
+
nat_constraints = gen_nat_constraints(dims)
|
| 166 |
+
c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints]
|
| 167 |
+
constraints += c
|
| 168 |
+
|
| 169 |
+
return constraints, counter
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
@register_inference_rule(torch.nn.functional.gelu)
@register_inference_rule(torch.nn.functional.dropout)
@register_inference_rule(torch.nn.functional.softmax)
@register_inference_rule("detach")
@register_inference_rule("to")
@register_inference_rule("int")
@register_inference_rule("long")
@register_inference_rule("contiguous")
@register_inference_rule(torch.ones)
@register_inference_rule(torch.zeros)
def equality_inference_rule(n: Node, symbols, constraints, counter):
    """
    We generate the constraint: input = output
    """
    output, counter = gen_tvar(counter)
    symbols[n] = output

    if isinstance(n.args[0], Node):
        input = symbols[n.args[0]]
        if isinstance(input, TVar):
            # shape-preserving op: the output type equals the input type
            return [BinConstraintT(input, output, op_eq)], counter

        # then we have dimension variables
        else:
            # e.g. torch.ones(d0, d1, ...): every argument is a dimension
            # variable, and together they form the output tensor type
            for arg in n.args:
                assert isinstance(symbols[arg], DVar)
            my_size = [symbols[arg] for arg in n.args]
            return [BinConstraintT(output, TensorType(my_size), op_eq)], counter

    elif isinstance(n.args[0], tuple):
        # then the tuple is the size
        assert len(n.args[0]) <= 4
        my_size = [symbols[arg] for arg in n.args[0]]
        return [BinConstraintT(output, TensorType(my_size), op_eq)], counter
    else:
        raise NotImplementedError('Method not yet implemented')
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@register_inference_rule("transpose")
|
| 211 |
+
def transpose_inference_rule(n: Node, symbols, constraints, counter):
|
| 212 |
+
"""
|
| 213 |
+
Can be considered as a sequence of two index selects, so we generate constraints accordingly
|
| 214 |
+
"""
|
| 215 |
+
assert isinstance(n.args[0], Node)
|
| 216 |
+
assert isinstance(n.args[1], int)
|
| 217 |
+
assert isinstance(n.args[2], int)
|
| 218 |
+
|
| 219 |
+
output, counter = gen_tvar(counter)
|
| 220 |
+
symbols[n] = output
|
| 221 |
+
|
| 222 |
+
from_arg = symbols[n.args[0]]
|
| 223 |
+
assert isinstance(from_arg, TVar)
|
| 224 |
+
|
| 225 |
+
# input and output are dyn
|
| 226 |
+
is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)])
|
| 227 |
+
|
| 228 |
+
# or input is a tensor and we actually do the replacement
|
| 229 |
+
c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)])
|
| 230 |
+
|
| 231 |
+
return [Disj([is_dyn, c3])], counter
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@register_inference_rule("type_as")
|
| 235 |
+
def type_inference_rule(n: Node, symbols, constraints, counter):
|
| 236 |
+
"""
|
| 237 |
+
We generate the constraint: input = output
|
| 238 |
+
"""
|
| 239 |
+
assert isinstance(n.args[0], Node)
|
| 240 |
+
assert isinstance(n.args[1], Node)
|
| 241 |
+
|
| 242 |
+
output, counter = gen_tvar(counter)
|
| 243 |
+
symbols[n] = output
|
| 244 |
+
|
| 245 |
+
from_arg = symbols[n.args[0]]
|
| 246 |
+
to_arg = symbols[n.args[1]]
|
| 247 |
+
|
| 248 |
+
assert isinstance(from_arg, TVar)
|
| 249 |
+
assert isinstance(to_arg, TVar)
|
| 250 |
+
|
| 251 |
+
return [BinConstraintT(from_arg, to_arg, op_consistency),
|
| 252 |
+
BinConstraintT(output, to_arg, op_eq)], counter
|
| 253 |
+
|
| 254 |
+
@register_inference_rule("masked_fill_")
|
| 255 |
+
def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
|
| 256 |
+
"""
|
| 257 |
+
Similar to addition. For now we implement the constraints when
|
| 258 |
+
the argument is a boolean tensor. There is also a case for when
|
| 259 |
+
it is a condition. We will leave this out for now.
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
assert isinstance(n.args[0], Node)
|
| 263 |
+
assert isinstance(n.args[1], Node)
|
| 264 |
+
|
| 265 |
+
# We will retrieve the type variables from the symbol table
|
| 266 |
+
# and confirm they are tensor variables
|
| 267 |
+
|
| 268 |
+
e1 = symbols[n.args[0]]
|
| 269 |
+
e2 = symbols[n.args[1]]
|
| 270 |
+
|
| 271 |
+
if isinstance(e1, TVar) and isinstance(e2, TVar):
|
| 272 |
+
masked_fill_tensor, counter = gen_tvar(counter)
|
| 273 |
+
symbols[n] = masked_fill_tensor
|
| 274 |
+
return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor)
|
| 275 |
+
else:
|
| 276 |
+
raise NotImplementedError('Not yet implemented')
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
@register_inference_rule(torch.nn.functional.embedding)
def embedding_inference_rule_functional(n: Node, symbols, constraints, counter):
    """
    Functional embedding: the weight (second argument) is treated as a
    static rank-2 tensor whose second dimension is the embedding size.
    """
    assert isinstance(n.args[0], Node)

    weight_sym = symbols[n.args[1]]

    # treat the weight as a static shape, so no matching is used
    weight_dims, counter = gen_tensor_dims(2, counter)
    weight_eq = BinConstraintT(weight_sym, TensorType(weight_dims), op_eq)

    rules, counter = gen_embedding_rules(n, symbols, weight_dims[1], counter)
    return [weight_eq] + rules, counter
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
@register_inference_rule(torch.nn.modules.sparse.Embedding)
def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    nn.Embedding: the output shape is the input shape with the module's
    embedding_dim appended as the last dimension.
    """
    assert isinstance(n.args[0], Node)
    return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
    """
    Shared constraint generation for embedding lookups: either everything
    is dynamic, or — for each candidate input rank — the output is the
    input shape with embedding_dim appended at the end.
    """
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    # fully dynamic case
    all_dyn = Conj([BinConstraintT(in_var, Dyn, op_eq),
                    BinConstraintT(out_var, Dyn, op_eq)])

    ranked = []
    for rank in range(1, MAX_TENSOR_RANK):
        dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(dims)

        # input has this rank; output gets embedding_dim appended
        ranked.append(Conj([BinConstraintT(in_var, TensorType(dims), op_eq),
                            BinConstraintT(out_var, TensorType(dims + [embedding_dim]), op_eq)] +
                           nat))

    return [Disj([all_dyn, Disj(ranked)])], counter
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
@register_inference_rule(torch.tensor)
def tensor_inference_rule(n: Node, symbols, constraints, counter):
    """
    torch.tensor literals are skipped for now: scalars are not supported
    yet and none of the current examples need them, so no constraints
    are generated.
    """
    return [], counter
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
@register_inference_rule("reshape")
|
| 338 |
+
@register_inference_rule("view")
|
| 339 |
+
def view_inference_rule(n: Node, symbols, constraints, counter):
|
| 340 |
+
"""
|
| 341 |
+
Similar to reshape but with an extra condition on the strides
|
| 342 |
+
"""
|
| 343 |
+
assert isinstance(n.args[0], Node)
|
| 344 |
+
|
| 345 |
+
# generate the new variable
|
| 346 |
+
my_view, counter = gen_tvar(counter)
|
| 347 |
+
symbols[n] = my_view
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
src_var = symbols[n.args[0]]
|
| 351 |
+
t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]] # target shape
|
| 352 |
+
t2_type = []
|
| 353 |
+
num_constraints = []
|
| 354 |
+
|
| 355 |
+
for t in t2:
|
| 356 |
+
if t == -1:
|
| 357 |
+
var, counter = gen_dvar(counter)
|
| 358 |
+
t2_type.append(var)
|
| 359 |
+
num_constraints.append(BinConstraintD(var, Dyn, op_neq))
|
| 360 |
+
|
| 361 |
+
else:
|
| 362 |
+
num_constraints.append(BinConstraintD(t, Dyn, op_neq))
|
| 363 |
+
t2_type.append(t)
|
| 364 |
+
|
| 365 |
+
t2_type = TensorType(t2_type) # type: ignore[assignment]
|
| 366 |
+
|
| 367 |
+
c1 = BinConstraintT(my_view, t2_type, op_eq)
|
| 368 |
+
c2 = CanReshape(src_var, t2_type)
|
| 369 |
+
|
| 370 |
+
# TODO: add the extra check mentioned here:
|
| 371 |
+
# https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
|
| 372 |
+
|
| 373 |
+
return [c1, c2] + num_constraints, counter # type: ignore[operator]
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
@register_inference_rule("size")
|
| 377 |
+
def size_inference_rule(n: Node, symbols, constraints, counter):
|
| 378 |
+
"""
|
| 379 |
+
The constraint is just lhs = rhs.
|
| 380 |
+
Ex: size = input_ids.size()
|
| 381 |
+
"""
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
if len(n.args) == 1:
|
| 385 |
+
# generate the new variable
|
| 386 |
+
size, counter = gen_tvar(counter)
|
| 387 |
+
symbols[n] = size
|
| 388 |
+
input = symbols[n.args[0]]
|
| 389 |
+
c = BinConstraintT(input, size, op_eq)
|
| 390 |
+
return [c], counter
|
| 391 |
+
|
| 392 |
+
elif len(n.args) == 2:
|
| 393 |
+
# TODO: review this rule; should input = dyn; output = dyn be included here?
|
| 394 |
+
if isinstance(n.args[1], int):
|
| 395 |
+
# generate the new variable
|
| 396 |
+
size_index, counter = gen_dvar(counter)
|
| 397 |
+
symbols[n] = size_index
|
| 398 |
+
input = symbols[n.args[0]]
|
| 399 |
+
c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)]
|
| 400 |
+
c3 = BinConstraintD(0, size_index, op_leq)
|
| 401 |
+
|
| 402 |
+
input_dyn = BinConstraintT(input, Dyn, op_eq)
|
| 403 |
+
output_dyn = BinConstraintD(size_index, Dyn, op_eq)
|
| 404 |
+
c1 = Conj([input_dyn, output_dyn])
|
| 405 |
+
|
| 406 |
+
return [Disj([c1, Conj([Disj(c2), c3])])], counter
|
| 407 |
+
|
| 408 |
+
else:
|
| 409 |
+
raise NotImplementedError
|
| 410 |
+
|
| 411 |
+
else:
|
| 412 |
+
raise NotImplementedError
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def range_check(i, n):
    """
    Check that index i is valid for a list of size n, allowing negative
    (from-the-end) indices.

    Args:
        i: index (may be negative, Python-style)
        n: list size

    Returns: the constraint T() when the index is in range, F() otherwise
    """
    if i >= 0:
        # non-negative indices must be strictly below the size
        return T() if i < n else F()
    else:
        # fix: negative indices are valid down to -n (the original compared
        # against n, which is never satisfied by a negative i)
        return T() if i >= -n else F()
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
@register_inference_rule(torch.cumsum)
def cumsum_inference_rule(n: Node, symbols, constraints, counter):
    """
    cumsum: input and output shapes must be equal, and the dim argument
    must be a valid index for the input's rank.
    """
    assert isinstance(n.args[0], Node)
    dim_arg = n.args[1] if len(n.args) > 1 else n.kwargs["dim"]
    assert isinstance(dim_arg, int)

    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    # both fully dynamic
    all_dyn = Conj([BinConstraintT(in_var, Dyn, op_eq),
                    BinConstraintT(out_var, Dyn, op_eq)])

    ranked = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(dims)

        # same shape in and out, and dim must fit this candidate rank
        ranked.append(Conj([BinConstraintT(in_var, TensorType(dims), op_eq),
                            BinConstraintT(out_var, TensorType(dims), op_eq)] +
                           [range_check(dim_arg, rank)] + nat))

    return [Disj([all_dyn, Disj(ranked)])], counter
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
@register_inference_rule(_assert_is_none)
def assert_inference_rule(n: Node, symbols, constraints, counter):
    """
    Assertions produce no value, so they contribute no constraints;
    the node must not have any users.
    """
    assert len(n.users) == 0
    return [], counter
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
@register_inference_rule(operator.getitem)
def getitem_inference_rule(n: Node, symbols, constraints, counter):
    """
    getitem: indexing with an int yields a dimension; indexing with a
    tuple yields a tensor.
    """
    assert isinstance(n.args[0], Node)

    if isinstance(n.args[1], int):
        # dimension-output case: fresh dimension variable for the result
        out_dim, counter = gen_dvar(counter)
        symbols[n] = out_dim

        src = symbols[n.args[0]]
        assert isinstance(src, TVar)

        # a dynamic input accepts any index and yields a dynamic dimension
        both_dyn = Conj([BinConstraintT(src, Dyn, op_eq),
                         BinConstraintD(out_dim, Dyn, op_eq)])

        # otherwise generate a GetItem constraint, expanded per candidate rank
        per_rank = [GetItem(rank + 1, n.args[1], out_dim, src) for rank in range(MAX_TENSOR_RANK)]

        # the output is a dimension, so it must be a natural number;
        # conjoined with the disjunction of the per-rank candidates
        nonneg = BinConstraintD(0, out_dim, op_leq)
        return [Disj([both_dyn, Conj([Disj(per_rank), nonneg])])], counter

    elif isinstance(n.args[1], tuple):
        # tensor-output case
        out_tensor, counter = gen_tvar(counter)
        symbols[n] = out_tensor

        if n.args[0] in symbols:
            src = symbols[n.args[0]]
            assert isinstance(src, TVar)

            both_dyn = Conj([BinConstraintT(src, Dyn, op_eq),
                             BinConstraintT(out_tensor, Dyn, op_eq)])  # type: ignore[assignment]

            per_rank = [GetItemTensor(rank + 1, n.args[1], out_tensor, src)  # type: ignore[misc]
                        for rank in range(MAX_TENSOR_RANK)]
        else:
            # TODO: we should figure out why there is a key-error here.
            return [], counter

        return [Disj([both_dyn, *per_rank])], counter

    else:
        raise RuntimeError('Method not yet implemented')
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
@register_inference_rule(operator.gt)
def gt_inference_rule(n: Node, symbols, constraints, counter):
    """
    Greater-than on tensors, dimensions, or a dimension vs. a constant.

    Tensor > tensor becomes broadcasting constraints. Dimension
    comparisons bind a fresh boolean variable to the comparison
    (meant for flow analysis only): aside from the tensor case, the
    node itself gets no symbol — only the operands are constrained.
    """
    assert isinstance(n.args[0], (Node, int))
    assert isinstance(n.args[1], (Node, int))

    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]

    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        if isinstance(e1, TVar) and isinstance(e2, TVar):
            gt_tensor, counter = gen_tvar(counter)
            symbols[n] = gt_tensor
            return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor)

        elif isinstance(e1, DVar) and isinstance(e2, DVar):
            # This is meant to be used for flow analysis only
            gt_constraint = BinConstraintD(e1, e2, op_gt)

            my_gt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
            return [equality_constraint], counter

        else:
            raise RuntimeError('Sort Mismatch')

    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
        if isinstance(e1, DVar):
            # This is meant to be used for flow analysis only
            gt_constraint = BinConstraintD(e1, e2, op_gt)

            my_gt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
            return [equality_constraint], counter

        elif isinstance(e1, TVar) and isinstance(e2, int):
            # the earlier assumption that the argument is a tensor was wrong;
            # replace its symbol with a fresh dimension variable
            warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.')

            new_e1, counter = gen_dvar(counter)
            symbols[n.args[0]] = new_e1
            # fix: removed a stray no-op expression statement
            # (`symbols[n.args[0]]`) that followed the assignment

            gt_constraint = BinConstraintD(new_e1, e2, op_gt)

            my_gt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
            return [equality_constraint], counter

        else:
            raise NotImplementedError('Method not yet implemented')

    else:
        raise NotImplementedError('Method not yet implemented')
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
@register_inference_rule(operator.eq)
def eq_inference_rule(n: Node, symbols, constraints, counter):
    """
    Equality on tensors (broadcasting) or dimensions (a boolean variable
    bound to the comparison, for flow analysis only).
    """
    assert isinstance(n.args[0], (Node, int))
    assert isinstance(n.args[1], (Node, int))

    lhs = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
    rhs = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]

    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        if isinstance(lhs, TVar) and isinstance(rhs, TVar):
            eq_tensor, counter = gen_tvar(counter)
            symbols[n] = eq_tensor
            return gen_broadcasting_constraints(lhs, rhs, symbols, counter, eq_tensor)

        elif isinstance(lhs, DVar) and isinstance(rhs, DVar):
            # flow analysis only: bind a boolean variable to the comparison
            bool_var, counter = gen_bvar(counter)
            binding = BinConstraintD(bool_var, BinConstraintD(lhs, rhs, op_eq), op_eq)
            return [binding], counter

        else:
            raise RuntimeError('Sort Mismatch')

    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
        if isinstance(lhs, DVar):
            # flow analysis only: bind a boolean variable to the comparison
            bool_var, counter = gen_bvar(counter)
            binding = BinConstraintD(bool_var, BinConstraintD(lhs, rhs, op_eq), op_eq)
            return [binding], counter
        else:
            raise NotImplementedError('Method not yet implemented')

    else:
        raise NotImplementedError('Method not yet implemented')
|
| 624 |
+
|
| 625 |
+
@register_inference_rule(operator.ne)
def neq_inference_rule(n: Node, symbols, constraints, counter):
    """
    Translates to inconsistency in gradual types. To prove inequality we
    show that the tensors disagree on at least one dimension where neither
    side is Dyn.

    This is a WIP (works when the condition is false; making it work when
    the condition is true as well is in progress). Only rhs tuples of
    size 3 and 4 are implemented.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], tuple)

    rank = len(n.args[1])
    if rank not in (3, 4):
        raise NotImplementedError('Method not yet implemented')

    for elem in n.args[1]:
        assert isinstance(elem, (Node, int))

    lhs = symbols[n.args[0]]

    # fresh dimension variables for the lhs tensor at this rank
    # (fix: the size-3 case previously generated 4 dims and used only 3)
    b, counter = gen_tensor_dims(rank, counter)
    input_is_sized = BinConstraintT(lhs, TensorType(b), op_eq)

    # rhs dimensions: concrete ints or the symbols of the nodes
    d = [elem if isinstance(elem, int) else symbols[elem] for elem in n.args[1]]

    # a dimension pair is "inconsistent" when both are concrete (not Dyn)
    # and unequal
    # (fix: the rank-4 case previously compared d4 against b3 instead of b4)
    dims_inconsistent = []
    for di, bi in zip(d, b):
        dims_inconsistent.append(Conj([BinConstraintD(di, Dyn, op_neq),
                                       BinConstraintD(bi, Dyn, op_neq),
                                       BinConstraintD(di, bi, op_neq)]))

    ne_constraint = Conj([input_is_sized, Disj(dims_inconsistent)])

    # bind a single boolean variable to the whole inequality
    # (fix: a dead duplicate gen_bvar call per branch was removed)
    my_ne, counter = gen_bvar(counter)
    equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)

    return [equality_constraint], counter
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
@register_inference_rule(operator.lt)
def lt_inference_rule(n: Node, symbols, constraints, counter):
    """
    Less-than on tensors (broadcasting) or dimensions (a boolean variable
    bound to the comparison, for flow analysis only). Apart from the
    tensor case, only the operands are constrained, not the node itself.
    """
    assert isinstance(n.args[0], (Node, int))
    assert isinstance(n.args[1], (Node, int))

    lhs = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
    rhs = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]

    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        if isinstance(lhs, TVar) and isinstance(rhs, TVar):
            lt_tensor, counter = gen_tvar(counter)
            symbols[n] = lt_tensor
            return gen_broadcasting_constraints(lhs, rhs, symbols, counter, lt_tensor)

        elif isinstance(lhs, DVar) and isinstance(rhs, DVar):
            # flow analysis only: bind a boolean variable to the comparison
            bool_var, counter = gen_bvar(counter)
            binding = BinConstraintD(bool_var, BinConstraintD(lhs, rhs, op_lt), op_eq)
            return [binding], counter

        else:
            raise RuntimeError('Sort Mismatch')

    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
        if isinstance(lhs, DVar):
            # flow analysis only: bind a boolean variable to the comparison
            bool_var, counter = gen_bvar(counter)
            binding = BinConstraintD(bool_var, BinConstraintD(lhs, rhs, op_lt), op_eq)
            return [binding], counter
        else:
            raise NotImplementedError('Method not yet implemented')

    else:
        raise NotImplementedError('Method not yet implemented')
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
@register_inference_rule(torch.full)
def full_inference_rule(n: Node, symbols, constraints, counter):
    """
    torch.full: the output type is exactly the requested size.
    """
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var

    assert isinstance(n.args[0], Iterable)
    # each size entry is either a concrete int or the symbol of a node
    dims = [entry if isinstance(entry, int) else symbols[entry] for entry in n.args[0]]
    size_eq = BinConstraintT(out_var, TensorType(dims), op_eq)  # type: ignore[arg-type]
    return [size_eq], counter
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
# TODO normalize index
@register_inference_rule(torch.arange)
def arange_inference_rule(n: Node, symbols, constraints, counter):
    """
    torch.arange(end): the output is a rank-1 tensor whose length is
    int((end - start) / step). Only the single-argument form
    (start=0, step=1) is supported.
    """
    start = 0
    step = 1

    if len(n.args) == 1:
        end = symbols[n.args[0]]
    else:
        raise NotImplementedError('Not yet implemented')

    # length = int((end - start) / step)
    length, counter = gen_dvar(counter)
    size_constraint = BinConstraintD(length, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq)
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var

    # either some parameter is Dyn, in which case the length is Dyn ...
    any_dyn = Disj([BinConstraintD(end, Dyn, op_eq),
                    BinConstraintD(start, Dyn, op_eq),
                    BinConstraintD(step, Dyn, op_eq)])
    dyn_case = Conj([any_dyn, BinConstraintD(length, Dyn, op_eq)])

    # ... or all parameters are numbers and the size equation holds
    none_dyn = Conj([BinConstraintD(end, Dyn, op_neq),
                     BinConstraintD(start, Dyn, op_neq),
                     BinConstraintD(step, Dyn, op_neq)])
    numeric_case = Conj([none_dyn, BinConstraintD(length, Dyn, op_neq), size_constraint])

    return [BinConstraintT(out_var, TensorType([length]), op_eq), Disj([dyn_case, numeric_case])], counter
|
| 811 |
+
|
| 812 |
+
def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var):
    """
    Broadcast e1 and e2: two auxiliary variables hold the broadcast
    operands, which must be consistent with each other, and the output
    is their greatest upper bound.
    """
    # auxiliary vars that don't correspond to any expression in the graph
    lhs_aux, counter = gen_tvar(counter)
    rhs_aux, counter = gen_tvar(counter)

    gub = TGreatestUpperBound(output_var, lhs_aux, rhs_aux)
    broadcast = ApplyBroadcasting(lhs_aux, rhs_aux, e1, e2)
    consistent = BinConstraintT(lhs_aux, rhs_aux, op_consistency)
    return [gub, broadcast, consistent], counter
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
@register_inference_rule(operator.mul)
@register_inference_rule(torch.ne)
@register_inference_rule("ne")
@register_inference_rule(torch.add)
@register_inference_rule(operator.add)
def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
    """
    Binary broadcasting operations. Tensor/tensor operands get
    broadcasting constraints; tensor/scalar keeps the tensor's type;
    dimension/scalar propagates the runtime arithmetic value.
    """
    # NOTE(review): op_code stays None for ne/"ne"; the DVar branches below
    # would then build a constraint with a None op — presumably those
    # branches are never reached for ne. TODO confirm.
    op_code = None
    if n.target == operator.add or n.target == torch.add:
        op_code = op_add
    elif n.target == operator.mul:
        op_code = op_mul

    lhs, rhs = n.args[0], n.args[1]

    if isinstance(lhs, Node) and isinstance(rhs, Node):
        if isinstance(symbols[lhs], TVar) and isinstance(symbols[rhs], TVar):
            result, counter = gen_tvar(counter)
            symbols[n] = result
            return gen_broadcasting_constraints(symbols[lhs], symbols[rhs], symbols, counter, result)
        else:
            raise NotImplementedError('Method not yet implemented')

    elif isinstance(lhs, Node) and isinstance(rhs, (int, float)):
        if isinstance(symbols[lhs], TVar):
            result, counter = gen_tvar(counter)
            symbols[n] = result
            return [BinConstraintT(result, symbols[lhs], op_eq)], counter
        elif isinstance(symbols[lhs], DVar):
            result, counter = gen_dvar(counter)
            symbols[n] = result

            # propagate the runtime value since this is plain arithmetic:
            # result = lhs <op> rhs, and result is a natural number
            c = Conj([BinConstraintD(result, BinConstraintD(symbols[lhs], rhs, op_code), op_eq),
                      BinConstraintD(0, result, op_leq)])
            return [c], counter

    elif isinstance(rhs, Node) and isinstance(lhs, (int, float)):
        if isinstance(symbols[rhs], TVar):
            result, counter = gen_tvar(counter)
            symbols[n] = result
            return [BinConstraintT(result, symbols[rhs], op_eq)], counter
        elif isinstance(symbols[rhs], DVar):
            result, counter = gen_dvar(counter)
            symbols[n] = result

            # propagate the runtime value since this is plain arithmetic
            c = Conj([BinConstraintD(result, BinConstraintD(symbols[rhs], lhs, op_code), op_eq),
                      BinConstraintD(0, result, op_leq)])
            return [c], counter

        else:
            raise NotImplementedError('Method not yet implemented')

    else:
        # TODO generate add constraints for scalar addition
        raise NotImplementedError('Addition not yet implemented')
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
@register_inference_rule(torch.flatten)
def flatten_inference_rule(n: Node, symbols, constraints, counter):
    """
    torch.flatten with optional positional start_dim/end_dim
    (defaults 1 and -1). Either everything is dynamic, or flatten
    constraints are generated for each candidate input rank.
    """
    assert isinstance(n.args[0], Node)

    # fresh variable for the flattened result
    flat_var, counter = gen_tvar(counter)
    symbols[n] = flat_var

    src = symbols[n.args[0]]

    # default dims, possibly overridden by positional arguments
    start_dim = 1
    end_dim = -1
    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]
    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    both_dyn = Conj([BinConstraintT(src, Dyn, op_eq),
                     BinConstraintT(flat_var, Dyn, op_eq)])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        c, counter = generate_flatten_constraints(start_dim, end_dim, src, flat_var, rank, counter)
        per_rank.append(c)

    return [Disj([both_dyn, *per_rank])], counter
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
@register_inference_rule(torch.nn.functional.layer_norm)
def layer_norm_functional(n: Node, symbols, constraints, counter):
    """
    Functional layer_norm: normalized_shape is the second argument;
    the shared layer-norm constraint generator does the rest.
    """
    assert isinstance(n.args[0], Node)
    return gen_layer_norm_constraints(n, n.args[1], symbols, counter)
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
@register_inference_rule(torch.nn.LayerNorm)
def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    nn.LayerNorm: input and output shapes are equal, and the input must
    be consistent with the module's normalized_shape.
    """
    assert isinstance(n.args[0], Node)
    return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter)
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
    """
    Shared constraint generation for layer norm (module and functional forms).

    Either input and output are both fully dynamic, or — for each candidate
    rank i in 1..MAX_TENSOR_RANK — input and output share the same fresh
    dimension variables and the trailing dimensions are consistent with
    *normalized_shape*.

    Returns:
        ([Disj(...)], counter): one disjunctive constraint and the updated
        fresh-variable counter.
    """
    output, counter = gen_tvar(counter)
    symbols[n] = output
    input = symbols[n.args[0]]

    input_dyn = BinConstraintT(input, Dyn, op_eq)
    output_dyn = BinConstraintT(output, Dyn, op_eq)

    # case 1: fully dynamic input and output
    c1 = Conj([input_dyn, output_dyn])

    # case 2: one conjunct per candidate rank; input and output use the
    # SAME fresh dims, i.e. layer norm preserves the shape
    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(new_dims_rhs)

        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
                           BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] +
                          add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) +
                          nat_constraints)
        c2.append(c_tensor_i)
    return [Disj([c1, Disj(c2)])], counter
|
| 962 |
+
|
| 963 |
+
@register_inference_rule(torch.nn.Dropout)
@register_inference_rule(torch.nn.ReLU)
def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Shape-preserving modules (ReLU, Dropout): the output type equals the
    input type, so a single tensor-equality constraint suffices.
    """
    assert isinstance(n.args[0], Node)
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]
    assert isinstance(in_var, TVar)
    return [BinConstraintT(in_var, out_var, op_eq)], counter
|
| 975 |
+
|
| 976 |
+
|
| 977 |
+
@register_inference_rule(torch.nn.Linear)
def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    nn.Linear: all dimensions except the last match between input and
    output; a fully dynamic input forces a fully dynamic output.
    """
    assert isinstance(n.args[0], Node)
    in_f, out_f = module_instance.in_features, module_instance.out_features
    return linear_constraints(n, in_f, out_f, symbols, counter)
|
| 985 |
+
|
| 986 |
+
|
| 987 |
+
@register_inference_rule("dim")  # type: ignore[attr-defined]
def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
    """
    Tensor.dim(): bind a fresh dimension variable to the rank of the input.

    If the input is fully dynamic the result is dynamic too; otherwise, for
    each candidate rank i, the input is a rank-i tensor and the result is i.
    """
    assert isinstance(n.args[0], Node)
    my_dim, counter = gen_dvar(counter)
    symbols[n] = my_dim
    input = symbols[n.args[0]]

    input_dyn = BinConstraintT(input, Dyn, op_eq)
    output_dyn = BinConstraintD(my_dim, Dyn, op_eq)

    # one disjunct per candidate rank: "input is rank i AND result == i"
    c1 = []

    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)

        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq),
                           BinConstraintD(my_dim, i, op_eq)])
        c1.append(c_tensor_i)

    return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter
|
| 1007 |
+
|
| 1008 |
+
|
| 1009 |
+
@register_inference_rule(torch.nn.linear)  # type: ignore[attr-defined]
def torch_linear_inference_rule(n: Node, symbols, constraints, counter):
    """
    Functional linear: the weight (second argument) is a rank-2 tensor
    [out_features, in_features]; delegate the input/output shape relation
    to linear_constraints using those fresh weight dims.
    """
    assert isinstance(n.args[0], Node)
    weight_dims, counter = gen_tensor_dims(2, counter)
    # the weight argument must be a 2D tensor with these dims
    equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq)
    # NOTE(review): this rebinding shadows the (unused) `constraints` parameter
    constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter)
    return [equality_constraint] + constraints, counter
|
| 1016 |
+
|
| 1017 |
+
|
| 1018 |
+
def linear_constraints(n: Node, in_features, out_features, symbols, counter):
    """
    Generate constraints relating the input and output of a linear layer.

    Either both are fully dynamic, or — for each candidate rank i — input
    and output are rank-i tensors whose dims agree except for the last,
    which is consistent with in_features (input) / equal to out_features
    (output); see add_linear_constraints.
    """
    linear_output, counter = gen_tvar(counter)
    symbols[n] = linear_output
    linear_input = symbols[n.args[0]]

    input_dyn = BinConstraintT(linear_input, Dyn, op_eq)
    output_dyn = BinConstraintT(linear_output, Dyn, op_eq)

    # case 1: fully dynamic input and output
    c1 = Conj([input_dyn, output_dyn])

    # case 2: one conjunct per candidate rank, with fresh dims for both sides
    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)

        c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
                           BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] +
                          add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) +
                          nat_constraints)
        c2.append(c_tensor_i)
    return [Disj([c1, Disj(c2)])], counter
|
| 1041 |
+
|
| 1042 |
+
def add_layer_norm_constraints(input_dim, normalized_dim):
    """
    Constrain the trailing dimensions of *input_dim* to be consistent with
    *normalized_dim*, e.g. an input of the form [*, 1024, 1024] against a
    normalized shape [1024, 1024].

    Args:
        input_dim: dimension variables/values of the layer-norm input
        normalized_dim: normalized_shape of the module instance

    Returns:
        A list of pairwise consistency constraints, or [F()] when
        normalized_dim is longer than input_dim (a pattern mismatch, hence
        unsatisfiable).
    """
    if len(normalized_dim) > len(input_dim):
        # pattern mismatch: normalization cannot cover more dims than exist
        return [F()]
    return [BinConstraintD(d, nd, op_consistency)
            for d, nd in zip(reversed(input_dim), reversed(normalized_dim))]
|
| 1061 |
+
|
| 1062 |
+
|
| 1063 |
+
def add_linear_constraints(dims1, dims2, in_features, out_features):
    """
    Relate input dims (*dims1*) to output dims (*dims2*) for a linear
    layer: every dimension but the last is equal across input and output;
    the last input dim must be consistent with in_features and the last
    output dim equals out_features.
    """
    assert len(dims1) == len(dims2)
    constraints = []
    last = len(dims1) - 1
    for idx, (d_in, d_out) in enumerate(zip(dims1, dims2)):
        if idx == last:
            constraints.append(BinConstraintD(d_in, in_features, op_consistency))
            constraints.append(BinConstraintD(d_out, out_features, op_eq))
        else:
            constraints.append(BinConstraintD(d_in, d_out, op_eq))
    return constraints
|
| 1074 |
+
|
| 1075 |
+
|
| 1076 |
+
@register_inference_rule(torch.reshape)
def reshape_inference_rule(n: Node, symbols, constraints, counter):
    """
    torch.reshape: the output has exactly the target shape (with -1 mapped
    to Dyn) and the input must be reshapeable to that target (CanReshape).
    """
    assert isinstance(n.args[0], Node)

    # generate the new variable
    my_reshape, counter = gen_tvar(counter)
    symbols[n] = my_reshape

    src_var = symbols[n.args[0]]
    t2 = n.args[1]  # the requested target shape
    # -1 in a reshape target means "inferred", i.e. dynamic here
    t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2])  # type: ignore[union-attr]
    c1 = BinConstraintT(my_reshape, t2_type, op_eq)  # type: ignore[union-attr]
    c2 = CanReshape(src_var, t2_type)

    return [c1, c2], counter
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
@register_inference_rule(BatchNorm2d)
def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    BatchNorm2d: the input must match a rank-4 tensor shape and the output
    type equals the input type (batch norm preserves the shape).
    """
    assert isinstance(n.args[0], Node)

    # generate the new variable
    batchnorm_output, counter = gen_tvar(counter)
    symbols[n] = batchnorm_output
    batchnorm_input = symbols[n.args[0]]

    # dim vars for the four dimensions of the expected input
    d1, counter = gen_dvar(counter)
    d2, counter = gen_dvar(counter)
    d3, counter = gen_dvar(counter)
    d4, counter = gen_dvar(counter)

    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    # input matches a 4D tensor; output equals input
    c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching)
    c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq)
    return [c1, c2, *nat_constraints], counter
|
| 1113 |
+
|
| 1114 |
+
|
| 1115 |
+
@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    AdaptiveAvgPool2d: the input matches a rank-4 tensor; the output keeps
    the first two dims and replaces the spatial dims with the module's
    output_size.
    """
    assert isinstance(n.args[0], Node)

    avg_pool, counter = gen_tvar(counter)

    symbols[n] = avg_pool
    input_var = symbols[n.args[0]]

    # dim vars for the four dimensions of the expected input
    d1, counter = gen_dvar(counter)
    d2, counter = gen_dvar(counter)
    d3, counter = gen_dvar(counter)
    d4, counter = gen_dvar(counter)
    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
    # spatial dims of the output are fixed by the module's output_size
    c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq)

    return [c1, c2, *nat_constraints], counter
|
| 1134 |
+
|
| 1135 |
+
|
| 1136 |
+
@register_inference_rule(Conv2d)
def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Conv2d: the input matches a rank-4 tensor, its channel dim is
    consistent with in_channels, and the output is given by the
    convolution arithmetic (CalcConv) over kernel/padding/stride/dilation.
    """
    assert isinstance(n.args[0], Node)

    my_conv, counter = gen_tvar(counter)
    symbols[n] = my_conv
    input_var = symbols[n.args[0]]

    # dim vars — the destructuring relies on MAX_TENSOR_RANK being 4
    [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)

    # c1 = Matching(input_var, TensorType([d1, d2, d3, d4]))
    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)

    # c2 = DConsistency(module_instance.in_channels, d2)
    c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency)

    # the output shape is computed from the input dims and the conv parameters
    c3 = CalcConv(my_conv, input_var,
                  module_instance.out_channels,
                  module_instance.kernel_size,
                  module_instance.padding,
                  module_instance.stride,
                  module_instance.dilation, [d1, d2, d3, d4])

    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    return [c1, c2, c3, *nat_constraints], counter
|
| 1163 |
+
|
| 1164 |
+
|
| 1165 |
+
@register_inference_rule(torch.nn.MaxPool2d)
def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    MaxPool2d: the input matches a rank-4 tensor and the output shape is
    given by the pooling arithmetic (CalcMaxPool) over
    kernel/padding/stride/dilation.
    """
    assert isinstance(n.args[0], Node)
    maxpool, counter = gen_tvar(counter)
    symbols[n] = maxpool
    input_var = symbols[n.args[0]]

    # dim vars — the destructuring relies on MAX_TENSOR_RANK being 4
    [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)

    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)

    c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding,
                     module_instance.stride, module_instance.dilation, [d1, d2, d3, d4])

    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    return [c1, c2, *nat_constraints], counter
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
class ConstraintGenerator:
    """
    Walks an FX graph and emits type constraints for every node, using the
    registered inference rules (_INFERENCE_RULES) keyed by call target or
    module class.
    """

    def __init__(self, traced, graph=None):
        self.traced = traced  # traced or tracer.root
        self.traced_params = dict(self.traced.named_parameters())
        self.constraints = []
        self.symbol_dict = {}
        # prefer the traced module's own graph when it has one
        self.graph = traced.graph if hasattr(traced, 'graph') else graph


    def generate_constraints(self, counter=0):
        """
        Iterate through every node and generate constraints
        Effect: self.constraints will be populated with the final constraints
        """
        graph = self.graph

        all_constraints = []

        for n in graph.nodes:
            (constraints, counter) = self.generate_constraints_node(n, counter)
            all_constraints += constraints

        return Conj(all_constraints), counter

    def generate_constraints_node(self, n: Node, counter):
        """
        Generate constraints the given node:
        Currently supported operations:
        - Reshape
        - Add
        - conv2d
        """

        if n.op == 'placeholder':
            x, counter = gen_tvar(counter)
            self.symbol_dict[n] = x

            my_type = n.type

            if n.type != Dyn and (not isinstance(n.type, TensorType)):
                if n.type == torch.nn.parameter.Parameter:
                    # since we have a parameter, the shape must be static
                    assert 'example_value' in n.meta
                    my_type = TensorType(n.meta['example_value'].size())
                else:
                    # any other annotation is treated as fully dynamic
                    my_type = Dyn

            # declared type bounds the variable's precision; rank is bounded
            # by MAX_TENSOR_RANK
            c1 = BinConstraintT(my_type, x, op_precision)
            c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq)
            return [c1, c2], counter

        elif n.op == 'call_function':
            if n.target in _INFERENCE_RULES:
                return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
            else:
                raise RuntimeError(f'No inference rule registered for target {n.target}!')

        elif n.op == 'call_module':

            module_instance = self.traced.get_submodule(n.target)
            if type(module_instance) in _INFERENCE_RULES:
                return _INFERENCE_RULES[type(module_instance)](n,
                                                               module_instance,
                                                               self.symbol_dict,
                                                               self.constraints, counter)
            else:
                raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')

        elif n.op == 'call_method':
            if n.target in _INFERENCE_RULES:
                return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
            else:
                raise RuntimeError(f'No inference rule registered for target {n.target}!')

        elif n.op == 'get_attr':
            t = self.traced_params.get(n.target, None)

            if isinstance(t, torch.Tensor):
                if len(t.shape) > 0:
                    # parameter/buffer with a concrete shape: bind a fresh
                    # variable equal to that static tensor type
                    res = list(t.shape)
                    attr_type = TensorType(res)
                    output, counter = gen_tvar(counter)
                    self.symbol_dict[n] = output
                    return [BinConstraintT(output, attr_type, op_eq)], counter
                else:
                    # scalar?
                    return [], counter
            else:
                return [], counter

        elif n.op == 'output':
            return [], counter

        else:
            raise NotImplementedError(f"Method {n.op} not yet implemented")
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
ADDED
|
@@ -0,0 +1,1040 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
import copy
|
| 3 |
+
import itertools
|
| 4 |
+
from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK
|
| 5 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \
|
| 6 |
+
Transpose
|
| 7 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound
|
| 8 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound
|
| 9 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool
|
| 10 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape
|
| 11 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect
|
| 12 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching
|
| 13 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq
|
| 14 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod
|
| 15 |
+
from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar
|
| 16 |
+
from torch.fx.tensor_type import TensorType, Dyn
|
| 17 |
+
from typing import Callable, Dict, List
|
| 18 |
+
|
| 19 |
+
_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def register_transformation_rule(call_target):
    """
    Decorator factory: register the decorated function as the
    transformation rule for *call_target* in the module-level
    _TRANSFORMATION_RULES table.

    Raises:
        RuntimeError: if a rule is already registered for call_target.
    """
    def _wrap(fn):
        if call_target in _TRANSFORMATION_RULES:
            raise RuntimeError(f'Transformation rule already registered for {call_target}!')
        _TRANSFORMATION_RULES[call_target] = fn
        return fn
    return _wrap
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def valid_index(index, dims):
    """
    Check whether *index* is a valid position in the list *dims*.

    Returns the constraint T() when indexing succeeds and F() when it
    raises IndexError, so the result can be embedded in a conjunction.
    """
    try:
        dims[index]
    except IndexError:
        return F()
    return T()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@register_transformation_rule(Transpose)
def transform_transpose(constraint, counter):
    """
    Similar to a sequence of two index-selects: the output tensor has the
    same fresh dims as the input, with the two indexed dims swapped when
    both indices are valid for the assumed tensor size.
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    is_valid_index1 = valid_index(constraint.index1, dims)
    is_valid_index2 = valid_index(constraint.index2, dims)
    # new_dims starts as a copy so it is always defined; the swap below is
    # only applied when both indices are valid
    new_dims = copy.deepcopy(dims)
    nat_constraints = gen_nat_constraints(dims)

    if is_valid_index1 == T() and is_valid_index2 == T():
        new_dims[constraint.index1] = dims[constraint.index2]
        new_dims[constraint.index2] = dims[constraint.index1]

    # if either index was invalid, is_valid_index1/2 contributes F() and the
    # conjunction is unsatisfiable
    transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                                   *nat_constraints,
                                   is_valid_index1, is_valid_index2,
                                   BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
    return transformed_constraint, counter
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@register_transformation_rule(IndexSelect)
def transform_index_select(constraint, counter):
    """
    The constraints consider the given tensor size, check if the index is
    valid and, if so, generate a constraint replacing the indexed input
    dimension with the required dimension.

    Fix: *new_dims* is now bound unconditionally. Previously it was only
    assigned inside the `is_valid_index == T()` branch but used below it,
    so an invalid index raised NameError instead of producing an
    unsatisfiable clause.
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    is_valid_index = valid_index(constraint.index, dims)
    nat_constraints = gen_nat_constraints(dims)

    # always defined; the replacement is only applied for a valid index
    new_dims = copy.deepcopy(dims)

    # if the index is valid then replace the input dimension with the new dimension
    # otherwise the dimension will not be replaced and the clause will contain False
    if is_valid_index == T():
        new_dims[constraint.index] = constraint.dim_replace

    # an invalid index contributes F() here, making the clause unsatisfiable
    transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                                   *nat_constraints,
                                   is_valid_index,
                                   BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])

    return transformed_constraint, counter
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@register_transformation_rule(GetItem)
def transform_get_item(constraint, counter):
    """
    generate an equality of the form:
    t = [a1, ..., an]
    then generate constraints that check if the given index is valid
    given this particular tensor size.
    If the index is valid, generate a constraint to get the item
    Note that we already handled the Dyn input case in the previous
    step.
    Args:
        constraint: GetItem which assumes we are getting an item from a tensor (not Dyn)
        counter: variable tracking
    Returns: simplified constraints for GetItem

    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)


    is_valid_index = valid_index(constraint.index, dims)

    all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                       *nat_constraints,
                       is_valid_index]

    # if the index is valid, we generate a constraint for getting an item
    # otherwise this clause will have been UNSAT due to the wrong index
    if is_valid_index == T():
        all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq))

    return Conj(all_constraints), counter
|
| 122 |
+
|
| 123 |
+
def valid_index_tensor(index, dims):
    """
    Validate a tuple-index against *dims*: if the number of slice objects
    in *index* exceeds the number of dimensions, the indexing is a type
    error, so return F(); otherwise return T().
    """
    n_slices = sum(1 for item in index if isinstance(item, slice))
    return F() if n_slices > len(dims) else T()
|
| 136 |
+
|
| 137 |
+
@register_transformation_rule(GetItemTensor)
def transform_get_item_tensor(constraint, counter):
    """
    When the index is a tuple, then the output will be a tensor
    TODO: we have to check if this is the case for all HF models

    The cases we are covering here are a tuple with one of:
     - slice with default argument
     - None

    None appends 1 to the input tensor dimensions
    so each occurrence of 'None' increases the rank by 1

    slice with default arguments does not change the rank

    (Cleanup: removed a dead `dim_index = 0` assignment that preceded the
    first loop; the variable is only used by the fill loop below.)
    """
    assert isinstance(constraint.index_tuple, tuple)


    # generate a result tensor of the expected size
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    # generate a place-holder list of the right rank
    # where "slice" does not contribute to the rank and "None" does
    none_c = constraint.index_tuple.count(None)
    resulting_tensor_dims = (none_c + len(dims)) * [None]

    for i in range(len(constraint.index_tuple)):

        # append 1 to the right location of the resulting tensor
        if constraint.index_tuple[i] is None:
            resulting_tensor_dims[i] = 1

        elif constraint.index_tuple[i] == slice(None, None, None):
            pass

        else:
            raise NotImplementedError('Method not yet implemented')

    # append the remaining dimensions to the right location
    dim_index = 0
    for i in range(len(resulting_tensor_dims)):
        if resulting_tensor_dims[i] is None:
            resulting_tensor_dims[i] = dims[dim_index]
            dim_index += 1

    # check if the index is valid
    is_valid_index = valid_index_tensor(constraint.index_tuple, dims)

    # check if the resulting tensor is within bounds (max rank 4)
    if len(resulting_tensor_dims) > 4:
        return F(), counter

    else:
        constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                       BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
                       *nat_constraints,
                       is_valid_index]
        return Conj(constraints), counter
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@register_transformation_rule(BinConstraintT)
def generate_binconstraint_t(constraint, counter):
    """
    Transform binary constraints for tensors: precision, matching,
    consistency and rank-bound (op_leq) constraints are rewritten into
    simpler conjunctions/disjunctions; anything else passes through.

    NOTE(review): an op_precision constraint whose lhs is neither Dyn nor
    a TensorType falls through every branch and implicitly returns None —
    presumably such constraints are never generated; confirm upstream.
    """

    # precision constraints
    if constraint.op == op_precision:
        if constraint.lhs == Dyn:
            # Dyn is less precise than anything: trivially satisfied
            return T(), counter
        elif isinstance(constraint.lhs, TensorType):
            is_fully_static = all(d != Dyn for d in constraint.lhs.__args__)
            if is_fully_static:
                # no dynamic dims: precision collapses to equality
                return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter
            else:
                # introduce a fresh dim per lhs dim, relate them pointwise by
                # precision, and require each fresh dim to be at least 1
                new_dims = []

                for _ in range(len(constraint.lhs.__args__)):
                    dim, counter = gen_dvar(counter)
                    new_dims.append(dim)

                new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for
                                       new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \
                                      [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \
                                      [BinConstraintD(1, new_dim, op_leq) for
                                       new_dim in new_dims]
                return Conj(new_dim_constraints), counter

    # matching
    elif constraint.op == op_matching:
        assert isinstance(constraint.rhs, TensorType)
        d1 = constraint.rhs.__args__[0]
        d2 = constraint.rhs.__args__[1]
        d3 = constraint.rhs.__args__[2]
        d4 = constraint.rhs.__args__[3]

        # either the lhs is fully dynamic (and so is every rhs dim), or the
        # lhs is exactly the rank-4 tensor [d1, d2, d3, d4]
        conj = [BinConstraintT(constraint.lhs, Dyn, op_eq),
                BinConstraintD(d1, Dyn, op_eq),
                BinConstraintD(d2, Dyn, op_eq),
                BinConstraintD(d3, Dyn, op_eq),
                BinConstraintD(d4, Dyn, op_eq)]
        return Disj([Conj(conj),
                     BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter

    elif constraint.op == op_consistency:
        # consistent when either side is Dyn, or both are tensors of equal
        # rank with pairwise-consistent dims (one case per candidate rank)
        c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)])
        [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter)

        return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter

    elif constraint.op == op_leq:
        # rank bound: lhs is Dyn or a tensor of some rank 1..rhs
        assert isinstance(constraint.rhs, int)
        disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
        for i in range(1, constraint.rhs + 1):
            dims = []
            for j in range(1, i + 1):
                dim_var, counter = gen_dvar(counter)
                dims.append(dim_var)
            disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))
        return Disj(disj), counter
    else:
        return constraint, counter
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@register_transformation_rule(BinConstraintD)
def generate_binconstraint_d(constraint, counter):
    """
    Transform binary constraints for dimensions.

    - precision: a fully known (int) dimension is as precise as possible, so the
      constraint collapses to equality; a Dyn lhs is trivially satisfied (T()).
    - consistency: two dimensions are consistent if they are equal or either is Dyn.
    - anything else is returned unchanged.

    Returns a (constraint, counter) pair.

    NOTE(review): when op is op_precision and lhs is neither an int nor Dyn this
    function falls off the end and implicitly returns None — presumably that case
    is unreachable by construction; confirm with callers before relying on it.
    """
    if constraint.op == op_precision:
        if isinstance(constraint.lhs, int):
            return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter
        elif constraint.lhs == Dyn:
            return T(), counter

    elif constraint.op == op_consistency:
        return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq),
                     BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter

    else:
        return constraint, counter
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
@register_transformation_rule(Conj)
def generate_conj(constraint, counter):
    """
    Transform a conjunction by recursively transforming each conjunct,
    threading the variable counter through every transformation.
    """
    transformed = []
    # NOTE: the attribute name "conjucts" (sic) is how the Conj class spells it.
    for conjunct in constraint.conjucts:
        simplified, counter = transform_constraint(conjunct, counter)
        transformed.append(simplified)
    return Conj(transformed), counter
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@register_transformation_rule(Disj)
def generate_disj(constraint, counter):
    """
    Transform a disjunction by recursively transforming each disjunct,
    threading the variable counter through every transformation.
    """
    transformed = []
    for disjunct in constraint.disjuncts:
        simplified, counter = transform_constraint(disjunct, counter)
        transformed.append(simplified)
    return Disj(transformed), counter
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
@register_transformation_rule(TGreatestUpperBound)
def generate_gub(constraint, counter):
    """
    Transform a greatest-upper-bound constraint on tensors into equality and
    greatest-upper-bound constraints on dimensions.
    """
    # If either operand is Dyn, the result must be Dyn as well.
    rhs1_is_dyn = BinConstraintT(constraint.rhs1, Dyn, op_eq)
    rhs2_is_dyn = BinConstraintT(constraint.rhs2, Dyn, op_eq)
    res_is_dyn = BinConstraintT(constraint.res, Dyn, op_eq)
    dyn_case = Conj([Disj([rhs1_is_dyn, rhs2_is_dyn]), res_is_dyn])

    # Otherwise both operands are tensors of some rank 1..4; one case per rank.
    [rank1_case, rank2_case, rank3_case, rank4_case], counter = \
        gen_greatest_upper_bound(constraint, counter)

    return Disj([dyn_case, rank1_case, rank2_case, rank3_case, rank4_case]), counter
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@register_transformation_rule(DGreatestUpperBound)
def generate_d_gub(constraint, counter):
    """
    Transform a greatest-upper-bound constraint on dimensions into a
    disjunction of plain equality constraints.
    """
    rhs1, rhs2, res = constraint.rhs1, constraint.rhs2, constraint.res

    # rhs1 is Dyn: the result follows rhs2.
    first_is_dyn = Conj([BinConstraintD(rhs1, Dyn, op_eq), BinConstraintD(res, rhs2, op_eq)])
    # rhs2 is Dyn: the result follows rhs1.
    second_is_dyn = Conj([BinConstraintD(rhs2, Dyn, op_eq), BinConstraintD(res, rhs1, op_eq)])
    # Both operands agree: the result equals either one of them.
    operands_equal = Conj([BinConstraintD(rhs2, rhs1, op_eq), BinConstraintD(res, rhs1, op_eq)])

    return Disj([first_is_dyn, second_is_dyn, operands_equal]), counter
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
@register_transformation_rule(CalcConv)
def generate_calc_conv(constraint, counter):
    """
    Transform convolution constraints: generate four fresh dimension variables
    for the conv output, constrain the batch and channel dimensions, and
    delegate the spatial (last two) dimensions to calc_last_two_dims.

    Returns a (constraint, counter) pair.
    """
    d, counter = gen_tensor_dims(4, counter)
    conv_result = TensorType([d[0], d[1], d[2], d[3]])

    # the convolution result is a tensor of size 4
    c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq)

    # the second dimension of the output is equal to the output channels
    # (and is always statically known, hence the d[1] != Dyn conjunct)
    c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)])

    # the input corresponds to the output in the first dimension of the convolution
    c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)

    # spatial output dimensions from kernel/stride/padding/dilation
    c4, c5 = calc_last_two_dims(constraint, d)

    # all output dimensions are natural numbers
    leq_constraints = Conj([BinConstraintD(0, d[0], op_leq),
                            BinConstraintD(0, d[1], op_leq),
                            BinConstraintD(0, d[2], op_leq),
                            BinConstraintD(0, d[3], op_leq)])

    return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
@register_transformation_rule(CalcMaxPool)
def generate_calc_maxpool(constraint, counter):
    """
    Transform maxpool constraints: maxpool preserves the batch and channel
    dimensions of its input and recomputes the two spatial dimensions.
    """
    dims, counter = gen_tensor_dims(4, counter)

    # the maxpool result is a tensor of size 4
    result_is_rank4 = BinConstraintT(constraint.maxpool_result,
                                     TensorType([dims[0], dims[1], dims[2], dims[3]]),
                                     op_eq)

    # the input corresponds to the output in the first and second dimension of maxpool
    channels_match = BinConstraintD(constraint.matching_constraint[1], dims[1], op_eq)
    batch_matches = BinConstraintD(constraint.matching_constraint[0], dims[0], op_eq)

    # spatial output dimensions from kernel/stride/padding/dilation
    spatial_h, spatial_w = calc_last_two_dims(constraint, dims)

    # all output dimensions are natural numbers
    nat_dims = Conj([BinConstraintD(0, dim, op_leq) for dim in dims])

    return Conj([result_is_rank4, channels_match, batch_matches,
                 spatial_h, spatial_w, nat_dims]), counter
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
@register_transformation_rule(CalcProduct)
def generate_calc_product(constraint, counter):
    """
    Transform flatten constraints.

    The dimensions in [start, end) are collapsed into a single dimension: Dyn
    if any of them could be dynamic, otherwise a fresh variable equal to their
    product. One case is generated per eq/neq-Dyn assignment of the middle
    dimensions; results whose total rank would exceed 4 are rejected with F().

    Args:
        constraint: a CalcProduct constraint
        counter: for variable tracking

    Returns: a (constraint, counter) pair
    """
    start = constraint.start
    end = constraint.end
    dims = constraint.dims_to_flatten
    flattened = constraint.flattened
    n = len(constraint.dims_to_flatten)

    # evaluated right here: the flatten range must be a valid, non-empty slice
    boundary_check = 0 <= start < end <= n
    c_boundary = T() if boundary_check else F()

    lhs = dims[0:start]
    rhs = dims[end:]
    mid = dims[start:end]

    all_possibilities = generate_all_int_dyn_dim_possibilities(mid)

    all_constraints = []

    for p in all_possibilities:
        p = list(p)
        # any constraint other than "!= Dyn" means a middle dim may be dynamic
        contains_dyn = any(c.op != op_neq for c in p)

        extra_constraints = []
        if contains_dyn:
            # a dynamic middle collapses to Dyn
            mid_var = [Dyn]
        else:
            # all middle dims are static: the collapsed dim is their product
            new_var, counter = gen_dvar(counter)
            extra_constraints = [Conj([BinConstraintD(new_var, Prod(mid), op_eq),
                                       BinConstraintD(new_var, Dyn, op_neq)])]
            mid_var = [new_var]

        total_dims = lhs + mid_var + rhs
        if len(total_dims) > 4:
            # resulting rank is unrepresentable (max tensor rank is 4)
            all_constraints.append(F())
        else:
            all_constraints.append(
                Conj([BinConstraintT(flattened, TensorType(total_dims), op_eq)]
                     + extra_constraints + p))

    return Conj([Disj(all_constraints), c_boundary]), counter
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@register_transformation_rule(CanReshape)
def generate_reshape(constraint, counter):
    """
    Transform reshape constraints.

    Generates four fresh dimension variables for the source tensor and, for
    every possible source rank 1..4, constrains the source to be reshapable
    into the target: if the target is fully static, the (non-Dyn) source
    dimensions must multiply to the target's product; if the target contains
    Dyn, divisibility by the known target dimensions is required instead.

    Returns: a (constraint, counter) pair
    """
    d, counter = gen_tensor_dims(4, counter)

    d1 = d[0]
    d2 = d[1]
    d3 = d[2]
    d4 = d[3]

    target = constraint.target.__args__

    is_fully_static = all(dim != Dyn for dim in target)

    # dynamic tensor
    c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)
    c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq)
    c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq)
    c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq)
    c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq)

    d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq)
    d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq)

    d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq)
    d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq)

    d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
    d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq)

    # BUG FIX: these previously referenced d3, so the rank-4 cases below never
    # actually constrained the fourth dimension.
    d4_eq_dyn = BinConstraintD(d4, Dyn, op_eq)
    d4_neq_dyn = BinConstraintD(d4, Dyn, op_neq)

    nat_d1 = BinConstraintD(0, d1, op_leq)
    nat_d2 = BinConstraintD(0, d2, op_leq)
    nat_d3 = BinConstraintD(0, d3, op_leq)
    nat_d4 = BinConstraintD(0, d4, op_leq)

    if is_fully_static:
        # size 1 tensor
        c3_tensor1 = Disj([d1_eq_dyn,
                           (Conj([d1_neq_dyn,
                                  BinConstraintD(d1, Prod(target), op_eq)]))])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # size 2 tensor
        all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)])

        # size 3 tensor
        all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)])

        # size 4 tensor
        all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)])

        return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
                     nat_d1, nat_d2, nat_d3, nat_d4]), counter

    # then there must be exactly one occurrence of dyn
    else:
        # keep only the statically-known target dimensions
        new_target = [n for n in target if n != Dyn]

        # tensor 1
        c3_tensor1 = Disj([d1_eq_dyn,
                           (Conj([d1_neq_dyn,
                                  is_dim_div_by_target(new_target, d1)]))])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # tensor 2
        c21 = Disj([d1_eq_dyn, d2_eq_dyn])
        c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))])
        all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])])

        # tensor 3
        c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn])
        c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))])
        all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])])

        # tensor 4
        c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn])
        c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))])
        all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])])

        return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
                     nat_d1, nat_d2, nat_d3, nat_d4]), counter
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
@register_transformation_rule(ApplyBroadcasting)
def generate_broadcasting(constraint, counter):
    """
    Transform broadcasting constraints.

    Considers every way the two inputs e1 and e2 can combine: either input
    being Dyn (in which case the broadcasted results just equal the inputs),
    or both being tensors of rank 1..4, with and without padding (rank
    mismatch). Returns the disjunction of all cases, conjoined with
    naturality constraints on every generated dimension variable.
    """
    e11, e12 = constraint.res1, constraint.res2
    e1, e2 = constraint.input1, constraint.input2

    e1_dyn = BinConstraintT(e1, Dyn, op_eq)
    e2_dyn = BinConstraintT(e2, Dyn, op_eq)

    # Introduce dimensions
    e1_equal_e11 = BinConstraintT(e1, e11, op_eq)
    e2_equal_e12 = BinConstraintT(e2, e12, op_eq)

    # dyn possibility: a Dyn input broadcasts to anything, so results mirror inputs
    e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12])
    e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12])

    # tensor possibility
    # generate dimensions to create tensors of size 1
    # (rank 1 cannot need padding, so the two padding results are discarded)
    final_tensor_1_constraint, _, _, nat_dims_1, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter)

    # generate dimensions to create tensors of size 2
    final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \
        final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter)

    # generate dimensions to create tensors of size 3
    final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \
        final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter)

    # generate dimensions to create tensors of size 4
    final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \
        final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter)

    final_result = Disj([
        e1_dyn_constraint,
        e2_dyn_constraint,
        final_tensor_1_constraint,
        final_tensor_2_constraint_no_padding,
        final_tensor_2_constraint_padding_arg1,
        final_tensor_2_constraint_padding_arg2,
        final_tensor_3_constraint_no_padding,
        final_tensor_3_constraint_padding_arg1,
        final_tensor_3_constraint_padding_arg2,
        final_tensor_4_constraint_no_padding,
        final_tensor_4_constraint_padding_arg1,
        final_tensor_4_constraint_padding_arg2
    ])

    return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def transform_constraint(constraint: Constraint, counter: int):
    """
    Transforms a constraint into a simpler constraint.
    Ex: precision and consistency are transformed to equality.

    Args:
        constraint: constraint to be transformed
        counter: for variable tracking

    Returns: the (possibly transformed) constraint and the updated counter
    """
    # Dispatch on the constraint's concrete type; unknown types pass through.
    rule = _TRANSFORMATION_RULES.get(type(constraint))
    if rule is not None:
        return rule(constraint, counter)
    return constraint, counter
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def calc_last_two_dims(constraint, d: List[DVar]):
    """
    Generates constraints for the last two dimensions of a convolution or a maxpool output.

    For each spatial dimension, either both the input dimension and the output
    dimension are Dyn, or neither is and the output is fixed by the standard
    formula:  out = (in + 2*padding - dilation*(kernel - 1) - 1) // stride + 1

    Args:
        constraint: CalcConv or CalcMaxPool
        d: The list of output dimensions

    Returns: Constraints for calculating the last two dimensions of the output
    """

    assert isinstance(constraint, (CalcConv, CalcMaxPool))

    # input spatial dimensions (height, width)
    b3 = constraint.matching_constraint[2]
    b4 = constraint.matching_constraint[3]

    b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)])
    b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)])

    d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)])
    d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)])

    # transform parameters into tuples in case they are not already
    padding = (constraint.padding, constraint.padding) \
        if isinstance(constraint.padding, int) else constraint.padding
    kernel = (constraint.kernel, constraint.kernel) \
        if isinstance(constraint.kernel, int) else constraint.kernel
    stride = (constraint.stride, constraint.stride) \
        if isinstance(constraint.stride, int) else constraint.stride
    dilation = (constraint.dilation, constraint.dilation) \
        if isinstance(constraint.dilation, int) else constraint.dilation

    # height: f4 = (b3 + 2*padding[0] - dilation[0]*(kernel[0] - 1) - 1) / stride[0] + 1
    f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add)
    f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul)
    f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div)
    f4 = BinConstraintD(f3, 1, op_add)

    c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])])

    # width: same formula with the second element of each parameter tuple
    f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add)
    f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul)
    f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div)
    f44 = BinConstraintD(f33, 1, op_add)

    c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])])

    return c4, c5
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
    """
    Generate all possibilities of being equal or not equal to Dyn for my_list.

    Args:
        my_list: List of tensor dimensions

    Returns: A list of tuples of constraints. Each tuple corresponds to one
        possibility about the values of the dimension variables (each variable
        is either constrained equal to Dyn or not equal to Dyn).
    """
    # For every dimension there are exactly two choices: == Dyn or != Dyn.
    per_dim_choices = [
        [BinConstraintD(dim, Dyn, op_eq), BinConstraintD(dim, Dyn, op_neq)]
        for dim in my_list
    ]
    # The cross product of the per-dimension choices enumerates every case.
    return list(itertools.product(*per_dim_choices))
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def is_target_div_by_dim(target: List[int], dim: List[DVar]):
    """
    Generate a constraint checking that the product of the target dimensions
    is divisible by the input dimension.

    Args:
        target: Target dimensions
        dim: Input dimensions

    Returns: The divisibility constraint (product(target) mod dim == 0)
    """
    remainder = BinConstraintD(Prod(target), dim, op_mod)
    return BinConstraintD(remainder, 0, op_eq)
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
def is_dim_div_by_target(target: List[int], dim: List[DVar]):
    """
    Generate a constraint checking that the input dimension is divisible by
    the product of the target dimensions.

    Args:
        target: Target dimensions
        dim: Input dimensions

    Returns: The divisibility constraint (dim mod product(target) == 0)
    """
    remainder = BinConstraintD(dim, Prod(target), op_mod)
    return BinConstraintD(remainder, 0, op_eq)
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def gen_all_reshape_possibilities(list_of_dims, target):
    """
    Consider every way the input dimensions could be static or dynamic, and
    generate the matching reshape constraint for each case.

    The cases form the cross product of (== Dyn / != Dyn) over the input
    dimensions. The target is fixed because at most one of its dimensions can
    be Dyn.

    Args:
        list_of_dims: The input list of dimensions
        target: The tensor we want to reshape to

    Returns: A disjunction of transformed reshape constraints
    """
    possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims)

    cases = []

    for possibility in possibilities:
        possibility = list(possibility)

        # collect the dimensions this case declares to be static (!= Dyn)
        static_dims = []
        for c in possibility:
            assert isinstance(c, BinConstraintD)
            if c.op == op_neq:
                static_dims.append(c.lhs)

        if not static_dims:
            # everything is dynamic: no numeric relationship can be stated
            cases.append(Conj(possibility))
        elif len(static_dims) < len(list_of_dims):
            # mixed case: the static part must divide the target's product
            cases.append(Conj(possibility + [is_target_div_by_dim(target, Prod(static_dims))]))
        else:
            # fully static: the products must match exactly
            cases.append(Conj(possibility + [BinConstraintD(Prod(list_of_dims),
                                                            Prod(target), op_eq)]))

    return Disj(cases)
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False):
    """
    Apply broadcasting to the 'index' dimension of tensor_input1.

    Args:
        tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1
        tensor_input2: represents the second input
        res1: broadcasted result 1
        res2: broadcasted result 2
        index: the index to broadcast
        padding: If padding was used, then tensor_input1[index] does not exist

    Returns: the conjunction of broadcasting constraints for this dimension
    """
    if tensor_input1[index] is None:
        assert padding

    # In both cases the results agree with each other and with the second input.
    propagate_second_input = [BinConstraintD(res1[index], res2[index], op_eq),
                              BinConstraintD(res2[index], tensor_input2[index], op_eq)]

    if padding:
        # we don't set the input dimension to 1, since it doesn't exist.
        return Conj(propagate_second_input)

    # the inputs are the same length so they all have dimensions at "index";
    # broadcasting requires the first input's dimension to be 1 here.
    return Conj([BinConstraintD(tensor_input1[index], 1, op_eq)] + propagate_second_input)
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
def apply_padding(e1_var: TVar,
                  e11: BinConstraintT,
                  e2: BinConstraintT,
                  e12: BinConstraintT,
                  d2: List[DVar],
                  d11: List[DVar],
                  d12: List[DVar],
                  counter: int):
    """
    We are considering the possibility where one input has less dimensions than
    another input, so we apply padding to the broadcasted results.

    One case is generated per possible (shorter) rank of the first input, from
    1 up to len(d2) - 1; the final result is the disjunction of all cases.

    Args:
        e1_var: Variable representing the first input where padding will be
        e11: constraint of the form e11 = Tensortype[d1, ..., dn]
        e2: constraint of the form e2 = Tensortype[d1, ..., dn]
        e12: constraint of the form e12 = Tensortype[d1, ..., dn]
        d2: Tensor variables for the second input
        d11: Tensor variables for the broadcasted first input
        d12: Tensor variables for the broadcasted second input
        counter: variable tracking

    Returns: A new constraint whose goal is to apply padding to the broadcasted
        result, and the updated counter
    """

    res = []

    # pad the shorter input with None so we can pass it to the broadcasting helper function
    for i in range(1, len(d2)):

        # fresh dimensions for a rank-i first input
        d1, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)

        e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)

        # leading positions of the shorter input that only exist after padding
        simulate_padding = [None] * (len(d2) - i)

        assert len(simulate_padding + d1) == len(d2)

        broadcast_padding = []

        # for every padding size, we also consider broadcasting
        for j in range(len(d2) - i):
            broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True))

        # we consider the possibilities for broadcasting for every dimension. Since we already
        # padded d1, we do not consider it while broadcasting
        all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1,
                                                                                           d2[(len(d2) - i):],
                                                                                           d11[(len(d2) - i):],
                                                                                           d12[(len(d2) - i):])
        # combine all constraints into a conjunction
        c = Conj([e1, e11, e2, e12,
                  *broadcast_padding,
                  all_broadcasting_possibilities,
                  *nat_constraints
                  ])
        res.append(c)

    return Disj(res), counter
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
def no_broadcast_dim_with_index(d1: List[DVar],
                                d2: List[DVar],
                                d3: List[DVar],
                                d4: List[DVar],
                                i: int):
    """
    Generate constraints for the case where no broadcasting occurs at index i.

    Args:
        d1: input 1
        d2: input 2
        d3: simulated broadcasting for input 1
        d4: simulated broadcasting for input 2
        i: the rank of the resulting tensor addition

    Returns: Constraints for when no broadcasting occurs
    """
    # No broadcasting happens when both dimensions are 1, or neither is.
    both_are_one = Conj([BinConstraintD(d1[i], 1, op_eq),
                         BinConstraintD(d2[i], 1, op_eq)])
    neither_is_one = Conj([BinConstraintD(d1[i], 1, op_neq),
                           BinConstraintD(d2[i], 1, op_neq)])

    # In either case the broadcasted results just mirror the inputs.
    return Conj([Disj([both_are_one, neither_is_one]),
                 BinConstraintD(d1[i], d3[i], op_eq),
                 BinConstraintD(d2[i], d4[i], op_eq)])
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int):
    """
    Generate lists of fresh dimension variables (DVar) to represent tensors.

    Args:
        num_tensors: the required number of tensors
        dim_size: the number of dimensions for each tensor
        counter: variable tracking

    Returns: A list of lists of tensor dimensions, and the updated counter
    """
    all_tensor_dims = []

    for _ in range(num_tensors):
        tensor_dims, counter = gen_tensor_dims(dim_size, counter)
        all_tensor_dims.append(tensor_dims)

    return all_tensor_dims, counter
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
def create_equality_constraints_for_broadcasting(e1: TVar,
                                                 e2: TVar,
                                                 e11: TVar,
                                                 e12: TVar,
                                                 d1: List[DVar],
                                                 d2: List[DVar],
                                                 d11: List[DVar],
                                                 d12: List[DVar]):
    """
    Create equality constraints binding each tensor variable to a tensor type
    built from its dimension variables (the no-broadcasting baseline).

    Args:
        e1: Input 1
        e2: Input 2
        e11: Broadcasted input 1
        e12: Broadcasted input 2
        d1: Variables that store dimensions for e1
        d2: Variables that store dimensions for e2
        d11: Variables that store dimensions for e11
        d12: Variables that store dimensions for e12

    Returns: Four equality constraints, in the order [e1, e11, e2, e12]
    """
    return [BinConstraintT(e1, TensorType(d1), op_eq),
            BinConstraintT(e11, TensorType(d11), op_eq),
            BinConstraintT(e2, TensorType(d2), op_eq),
            BinConstraintT(e12, TensorType(d12), op_eq)]
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
def gen_consistency_constraints(constraint: Constraint, counter: int):
    """
    Expand a tensor-level consistency constraint into per-rank cases.

    For every rank 1..MAX_TENSOR_RANK, binds both sides of the constraint to
    fresh tensor types of that rank and requires the corresponding dimensions
    to be pairwise consistent (plus naturality of every fresh dimension).

    Args:
        constraint: Consistency constraint on tensors
        counter: for variable tracking

    Returns: Equality and consistency constraints on dimensions (one Conj per
        rank), and the updated counter
    """

    all_constraints = []

    for i in range(1, MAX_TENSOR_RANK + 1):
        # fresh dimensions for the left- and right-hand tensors of rank i
        new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
        new_dims_rhs_2, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)

        c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq),
                           BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] +
                          [BinConstraintD(d1, d2, op_consistency) for
                           d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints)

        all_constraints.append(c_tensor_i)

    return all_constraints, counter
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
    """
    Expand a tensor-level greatest-upper-bound constraint into per-rank cases.

    For every rank 1..MAX_TENSOR_RANK, binds both operands and the result to
    fresh tensor types of that rank, and requires each result dimension to be
    the greatest upper bound of the corresponding operand dimensions.

    Args:
        constraint: Greatest upper bound on tensors
        counter: variable tracking

    Returns: A list of Conj constraints (one per rank) of equality and
        DGreatestUpperBound constraints, and the updated counter
    """

    all_constraints = []

    for i in range(1, MAX_TENSOR_RANK + 1):
        c = []
        dims1, counter = gen_tensor_dims(i, counter)
        c1tensor = TensorType(dims1)

        dims2, counter = gen_tensor_dims(i, counter)
        c2tensor = TensorType(dims2)

        dims3, counter = gen_tensor_dims(i, counter)
        c3tensor = TensorType(dims3)

        c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq),
              BinConstraintT(constraint.rhs2, c2tensor, op_eq),
              BinConstraintT(constraint.res, c3tensor, op_eq)] + \
            gen_nat_constraints(dims1 + dims2 + dims3)

        assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__)
        # FIX: the inner loop previously reused `i`, shadowing the rank loop
        # variable; renamed to `j` for clarity (behavior was unaffected only
        # because `for` rebinds `i` at the top of each outer iteration).
        for j in range(len(c3tensor.__args__)):
            c.append(DGreatestUpperBound(c3tensor.__args__[j],
                                         c1tensor.__args__[j],
                                         c2tensor.__args__[j]))

        all_constraints.append(Conj(c))
    return all_constraints, counter
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]):
    """
    Generate broadcasting constraints assuming no padding. Broadcasting can
    happen at any dimension, so for each position we allow: input 1 broadcasts,
    input 2 broadcasts, or no broadcasting occurs.

    Args:
        d1: input1 dimensions
        d2: input2 dimensions
        d11: broadcasted input1 dimensions
        d12: broadcasted input2 dimensions

    Returns: broadcasting constraints relating the input dimensions to the
        broadcasted dimensions
    """
    # One three-way disjunction per dimension position, all of which must hold.
    return Conj([
        Disj([broadcast_dim(d1, d2, d11, d12, i),
              broadcast_dim(d2, d1, d12, d11, i),
              no_broadcast_dim_with_index(d1, d2, d11, d12, i)])
        for i in range(len(d1))
    ])
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int):
    """
    Simulates broadcasting on e1 and e2 and returns the results
    respectively in e11 and e12. Because of gradual types,
    e1 and e2 may not be equal. Similarly, e11 and e12 may not
    be equal. e11 and e12 should be guaranteed to be consistent
    as they represent the shapes of the tensors to be added after
    broadcasting.
    Args:
        e1: TVar representing the type of input 1
        e2: TVar representing the type of input 2
        e11: TVar representing the representing broadcasted input 1
        e12: TVar representing the representing broadcasted input 2
        i: The rank of the resulting type of addition
        counter: for variable tracking

    Returns: Simplified broadcasting constraints

    """
    # Four fresh dimension lists of length i: one per tensor variable
    # (input 1, input 2, broadcasted input 1, broadcasted input 2).
    dims, counter = gen_lists_of_dims(4, i, counter)
    [d1, d2, d3, d4] = dims
    # All generated dimension variables must be natural numbers.
    nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims)))

    # Tie each tensor variable to a TensorType built from its dimension list.
    initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12,
                                                                                 d1, d2, d3, d4)

    [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints

    # without padding, broadcast all possibilities for tensors of size i
    final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints,
                                               generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)])

    # with padding, broadcast all possibilities for tensors of size i
    # (padding case where input 1 is the shorter tensor)
    final_tensor_constraint_padding_arg1, counter = \
        apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter)

    # symmetric padding case where input 2 is the shorter tensor;
    # note the swapped dimension-list arguments (d1, d4, d3)
    final_tensor_constraint_padding_arg2, counter = \
        apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter)

    return final_tensor_constraint_no_padding, \
        final_tensor_constraint_padding_arg1, \
        final_tensor_constraint_padding_arg2, nat_dims_i, counter
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr
|
| 2 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar
|
| 3 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim
|
| 4 |
+
from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator
|
| 5 |
+
from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint
|
| 6 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt
|
| 7 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod
|
| 8 |
+
from torch.fx.tensor_type import TensorType, Dyn
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import z3 # type: ignore[import]
|
| 12 |
+
from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D
|
| 13 |
+
HAS_Z3 = True
|
| 14 |
+
|
| 15 |
+
def transform_to_z3(constraint, counter, dimension_dict):
    """Recursively lower a migrate_gradual_types constraint into a z3 formula.

    Args:
        constraint: a Conj/Disj/T/F/BinConstraintT/BinConstraintD constraint tree
        counter: integer used to mint fresh z3 variable names
        dimension_dict: maps dimension-variable names to their assigned counter ids

    Returns: a pair (z3 expression or Python bool, updated counter)

    Raises:
        NotImplementedError: for constraint kinds or operations not yet supported
    """
    if isinstance(constraint, Conj):
        # Conjunction: lower each conjunct and AND them together.
        conjuncts = []
        for c in constraint.conjucts:
            new_c, counter = transform_to_z3(c, counter, dimension_dict)
            conjuncts.append(new_c)
        return z3.And(conjuncts), counter

    elif isinstance(constraint, Disj):
        # Disjunction: lower each disjunct and OR them together.
        disjuncts = []
        for c in constraint.disjuncts:
            new_c, counter = transform_to_z3(c, counter, dimension_dict)
            disjuncts.append(new_c)
        return z3.Or(disjuncts), counter

    elif isinstance(constraint, T):
        # Trivially-true constraint.
        return True, counter

    elif isinstance(constraint, F):
        # Trivially-false constraint.
        return False, counter

    elif isinstance(constraint, BinConstraintT):
        # Tensor-level binary constraint; only equality is supported.
        if constraint.op == op_eq:
            lhs, counter = transform_var(constraint.lhs, counter, dimension_dict)
            rhs, counter = transform_var(constraint.rhs, counter, dimension_dict)
            return (lhs == rhs), counter

        else:
            raise NotImplementedError('Method not yet implemented')

    elif isinstance(constraint, BinConstraintD):
        # Dimension-level binary constraint.
        if constraint.op == op_eq:

            if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs):
                # Boolean variable bound to a boolean expression: b == <expr>.
                transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict)
                transformed_lhs = z3.Bool(constraint.lhs.c)
                return transformed_lhs == transformed_rhs, counter

            elif is_dim(constraint.lhs) and is_dim(constraint.rhs):
                # with dimension transformations we consider the encoding
                # (a (flag, value) pair produced by transform_dimension)
                lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
                rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
                return lhs == rhs, counter

            else:
                # then we have an algebraic expression which means that we disregard the
                # first element of the encoding
                lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
                rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
                return lhs == rhs, counter

        # The assumption here is that the LHS and RHS must be dimensions
        elif constraint.op == op_neq:
            assert is_dim(constraint.lhs)
            assert is_dim(constraint.rhs)
            lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
            rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
            if constraint.rhs == Dyn or constraint.lhs == Dyn:
                # x != Dyn means x must be a static (flag == 1) dimension.
                if constraint.rhs == Dyn:
                    return lhs.arg(0) == 1, counter
                elif constraint.lhs == Dyn:
                    return rhs.arg(0) == 1, counter

            # if one of the instances is a number
            elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int):
                # Inequality to a literal: the other side is either dynamic
                # (flag 0) or static with a different value.
                if isinstance(constraint.lhs, int):
                    return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter

                elif isinstance(constraint.rhs, int):
                    return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter

            else:
                # Both sides are dimension variables: unequal when exactly one
                # is dynamic, or both are non-dynamic with different values.
                return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
                              z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
                              z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter


        elif constraint.op == op_leq:
            # if the dimensions are not dyn, this will come into effect
            # there would have been another constraint specifying if a given dimension
            # is dyn or not
            assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
            lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
            rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
            return lhs <= rhs, counter

        elif constraint.op == op_gt:
            assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
            lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
            rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
            return lhs > rhs, counter

        elif constraint.op == op_lt:
            assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
            lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
            rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
            return lhs < rhs, counter

        else:
            raise NotImplementedError('operation not yet implemented')

    else:
        raise NotImplementedError('Operation not yet implemented')
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def transform_var(tensor, counter, dimension_dict):
    """
    Transforms tensor variables to a format understood by z3
    Args:
        tensor: Tensor variable or a tensor type potentially with variable dimensions
        counter: integer used to mint fresh z3 variable names
        dimension_dict: maps dimension-variable names to their assigned counter ids
    Returns: Transformed variable to a z3 format, paired with the updated counter

    """
    if isinstance(tensor, TensorType):
        # Lower each dimension to its (flag, value) z3 encoding.
        res = []
        for t in tensor.__args__:
            transformed, counter = transform_dimension(t, counter, dimension_dict)
            res.append(transformed)

        # Only ranks 1 through 4 have z3 datatype constructors.
        assert len(res) <= 4
        if len(tensor.__args__) == 1:
            return tensor_type.tensor1(res[0]), counter
        elif len(tensor.__args__) == 2:
            return tensor_type.tensor2(res[0], res[1]), counter
        elif len(tensor.__args__) == 3:
            return tensor_type.tensor3(res[0], res[1], res[2]), counter
        elif len(tensor.__args__) == 4:
            return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter
        # NOTE(review): a TensorType with zero dimensions falls through and the
        # function implicitly returns None — confirm this case cannot occur.

    elif tensor == Dyn:
        # Fully dynamic tensor maps to the dedicated z3 Dyn constant.
        return z3_dyn, counter

    elif isinstance(tensor, TVar):
        # Tensor variable maps to a z3 constant of the tensor sort.
        return z3.Const(tensor.tvar, tensor_type), counter
|
| 149 |
+
|
| 150 |
+
def transform_dimension(dimension, counter, dimension_dict):
    """
    Takes a dimension variable or a number and transforms it to a tuple
    according to our scheme
    Args:
        dimension: The dimension to be transformed
        counter: variable tracking
        dimension_dict: maps dimension-variable names to their assigned counter ids

    Returns: tuple and the current counter

    """
    if dimension == Dyn:
        # Dynamic dimension: flag 0 plus a fresh z3 integer for the value.
        counter += 1
        return D(0, z3.Int(counter)), counter

    if isinstance(dimension, int):
        # Concrete dimension: flag 1 plus the literal value.
        return D(1, dimension), counter

    if isinstance(dimension, DVar):
        # Dimension variable: reuse its assigned id, minting one on first sight.
        if dimension.c not in dimension_dict:
            counter += 1
            dimension_dict[dimension.c] = counter
        return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def transform_algebraic_expression(expr, counter, dimension_dict):
    """
    Transforms an algebraic expression to z3 format
    Args:
        expr: An expression is either a dimension variable or an algebraic-expression
        counter: variable tracking
        dimension_dict: maps dimension-variable names to their assigned counter ids

    Returns: the transformed expression paired with the updated counter

    """
    assert is_algebraic_expression(expr) or is_dim(expr)

    if is_dim(expr):
        # A bare dimension: keep only the value part of its (flag, value) encoding.
        encoded, counter = transform_dimension(expr, counter, dimension_dict)
        return encoded.arg(1), counter

    # Prod must be checked before the generic algebraic-expression case.
    if isinstance(expr, Prod):
        factors = []
        for factor in expr.products:
            assert is_dim(factor)
            encoded, counter = transform_dimension(factor, counter, dimension_dict)
            factors.append(encoded.arg(1))
        return z3.Product(factors), counter

    if is_algebraic_expression(expr):
        # Lower both operands first, then combine with the matching operator.
        lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict)
        rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict)

        if expr.op == op_add:
            return lhs + rhs, counter
        if expr.op == op_sub:
            return lhs - rhs, counter
        if expr.op == op_mul:
            return lhs * rhs, counter
        if expr.op == op_div:
            return lhs / rhs, counter
        if expr.op == op_mod:
            return lhs % rhs, counter
        raise NotImplementedError('operation not yet implemented')

    raise RuntimeError
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def transform_all_constraints(traced, counter=0):
    """
    Given a trace, generates constraints and transforms them to z3 format

    Args:
        traced: the traced module to generate constraints for
        counter: starting value for fresh-variable numbering

    Returns: a single z3 formula encoding all generated constraints
    """
    dimension_dict = {}  # type: ignore[var-annotated]

    # Generate the raw constraint set from the traced module.
    constraint_generator = ConstraintGenerator(traced)
    constraints, counter = constraint_generator.generate_constraints(counter)

    # transform precision, matching, consistency till obtaining a fixed point
    constraints, counter = iterate_till_fixed_point(constraints, counter)

    # Lower the simplified constraint tree into a z3 formula.
    z3_formula, counter = transform_to_z3(constraints, counter, dimension_dict)
    return z3_formula
|
| 252 |
+
|
| 253 |
+
def iterate_till_fixed_point(constraints, counter):
    """
    Transform constraints till reaching a fixed point

    Args:
        constraints: the constraint tree to simplify
        counter: variable tracking

    Returns: the simplified constraints and the updated counter
    """
    previous = None
    # Keep rewriting until a pass produces no change.
    while constraints != previous:
        previous = constraints
        constraints, counter = transform_constraint(constraints, counter)
    return constraints, counter
|
| 262 |
+
|
| 263 |
+
def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0):
    """
    Takes a node and a graph and generates two sets of constraints.
    One set constraints the node's constraints and another set
    constraints the negation of the node's constraints
    Args:
        tracer_root: the root for getting the module instances
        graph: the graph so far in the tracing process
        node: node that represents a conditional
        counter: variable tracking

    Returns: Two sets of constraints. One with a conjunction with the
     the conditional constraint and the other with a conjunction with
     its negation.

    """
    # NOTE(review): the `node` parameter is not used in this body — confirm
    # whether it is kept for interface compatibility only.
    dimension_dict = {}  # type: ignore[var-annotated]

    generator = ConstraintGenerator(tracer_root, graph)
    new_constraints, counter = generator.generate_constraints(counter)

    # The last conjunct encodes the conditional itself.
    condition_constraint = new_constraints.conjucts[-1]

    # we know the constraint is a conjunction where the last constraint is about the conditional
    # so remove the last constraint
    new_constraints.conjucts = new_constraints.conjucts[:-1]

    # transform precision, matching, consistency till obtaining a fixed point
    new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)


    # since the function returns a list of one element, we get the first element
    # we are only interested in the RHS in this case because the LHS just stores
    # the result

    # we make sure the constraint is of the form:
    # c = b where b is a boolean expression
    # and we consider b (constraint.rhs) for transformation
    assert isinstance(condition_constraint.lhs, BVar)
    assert is_bool_expr(condition_constraint.rhs)
    condition_constraint_rhs = condition_constraint.rhs

    # transform the condition constraint
    condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter)

    # Lower both the main constraints and the conditional to z3.
    transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)

    transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict)

    negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)

    # Return (constraints AND condition, constraints AND NOT condition).
    return z3.And([transformed, transformed_condition_constraint]), \
        z3.And([transformed, negation_transformed_condition_constraint])
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None):
    """
    Given an IR and a node representing a conditional, evaluate the conditional
    and its negation
    Args:
        tracer_root: Tracer root for module instances
        graph: the graph so far in the tracing process
        node: The node to be evaluated
        counter: variable tracking
        user_constraints: optional extra z3 constraints added to both checks

    Returns: the results of evaluating the condition and the negation with
    the rest of the constraints

    """
    transformed_positive, transformed_negative = \
        transform_all_constraints_trace_time(tracer_root, graph, node, counter)

    def check(formula):
        # Solve the formula (plus any user constraints) with a fresh solver.
        solver = z3.Solver()
        solver.add(formula)
        if user_constraints is not None:
            solver.add(user_constraints)
        return solver.check()

    condition = check(transformed_positive)
    negation = check(transformed_negative)
    return condition, negation
|
| 346 |
+
|
| 347 |
+
except ImportError:
|
| 348 |
+
HAS_Z3 = False
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc
ADDED
|
Binary file (4.21 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc
ADDED
|
Binary file (14.9 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (5.25 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (492 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]
|
| 4 |
+
|
| 5 |
+
def raises(err, lamda):
    """Return True if calling *lamda* raises *err*, else False."""
    try:
        lamda()
    except err:
        return True
    return False
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def expand_tuples(L):
|
| 14 |
+
"""
|
| 15 |
+
>>> expand_tuples([1, (2, 3)])
|
| 16 |
+
[(1, 2), (1, 3)]
|
| 17 |
+
>>> expand_tuples([1, 2])
|
| 18 |
+
[(1, 2)]
|
| 19 |
+
"""
|
| 20 |
+
if not L:
|
| 21 |
+
return [()]
|
| 22 |
+
elif not isinstance(L[0], tuple):
|
| 23 |
+
rest = expand_tuples(L[1:])
|
| 24 |
+
return [(L[0],) + t for t in rest]
|
| 25 |
+
else:
|
| 26 |
+
rest = expand_tuples(L[1:])
|
| 27 |
+
return [(item,) + t for t in rest for item in L[0]]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Taken from theano/theano/gof/sched.py
|
| 31 |
+
# Avoids licensing issues because this was written by Matthew Rocklin
|
| 32 |
+
def _toposort(edges):
    """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
    inputs:
        edges - a dict of the form {a: {b, c}} where b and c depend on a
    outputs:
        L - an ordered list of nodes that satisfy the dependencies of edges
    >>> _toposort({1: (2, 3), 2: (3, )})
    [1, 2, 3]
    >>> # Closely follows the wikipedia page [2]
    >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
    >>> # Communications of the ACM
    >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
    """
    # Invert the edge map: incoming_edges[m] is the set of nodes m depends on.
    incoming_edges = reverse_dict(edges)
    incoming_edges = OrderedDict((k, set(val))
                                 for k, val in incoming_edges.items())
    # S holds nodes with no remaining unmet dependencies (an ordered "set").
    S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
    L = []

    while S:
        # popitem() removes the most recently inserted ready node (LIFO order).
        n, _ = S.popitem()
        L.append(n)
        for m in edges.get(n, ()):
            assert n in incoming_edges[m]
            incoming_edges[m].remove(n)
            # Once all of m's dependencies are emitted, m becomes ready.
            if not incoming_edges[m]:
                S[m] = None
    # Any node with dependencies left unresolved means the graph had a cycle.
    if any(incoming_edges.get(v, None) for v in edges):
        raise ValueError("Input has cycles")
    return L
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def reverse_dict(d):
    """Reverses direction of dependence dict
    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
    >>> reverse_dict(d)  # doctest: +SKIP
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
    :note: dict order are not deterministic. As we iterate on the
           input dict, it make the output of this function depend on the
           dict order. So this function output order should be considered
           as undeterministic.
    """
    reversed_deps = OrderedDict()  # type: ignore[var-annotated]
    for source, targets in d.items():
        for target in targets:
            # Append the source key to the tuple accumulated for this target.
            existing = reversed_deps.get(target, tuple())
            reversed_deps[target] = existing + (source,)
    return reversed_deps
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Taken from toolz
|
| 82 |
+
# Avoids licensing issues because this version was authored by Matthew Rocklin
|
| 83 |
+
def groupby(func, seq):
    """ Group a collection by a key function
    >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
    >>> groupby(len, names)  # doctest: +SKIP
    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
    >>> iseven = lambda x: x % 2 == 0
    >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
    {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
    See Also:
        ``countby``
    """
    grouped = OrderedDict()  # type: ignore[var-annotated]
    for item in seq:
        # setdefault creates the bucket on first sight of a key.
        grouped.setdefault(func(item), []).append(item)
    return grouped
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def typename(type):
    """Get the name of `type`.
    Parameters
    ----------
    type : Union[Type, Tuple[Type]]
    Returns
    -------
    str
        The name of `type` or a tuple of the names of the types in `type`.
    Examples
    --------
    >>> typename(int)
    'int'
    >>> typename((int, float))
    '(int, float)'
    """
    try:
        return type.__name__
    except AttributeError:
        # A tuple of types: a 1-tuple collapses to its lone element's name,
        # anything longer is rendered as "(name1, name2, ...)".
        if len(type) == 1:
            return typename(*type)
        return "({})".format(", ".join(typename(item) for item in type))
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc
ADDED
|
Binary file (30.5 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import operator
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
    """
    Annotate the type of getitem nodes, inferred from the type of sequence node.
    If sequence node is not annotated with a type, do nothing.
    Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.

    This is helpful since annotations on local names within function are lost during FX transforms.
    Adding back known type annotation for getitem nodes to improve jit scriptability.

    Args:
        graph (Graph): The graph to be annotated
    """
    for node in graph.nodes:
        if node.target == operator.getitem:
            # getitem args are (sequence, index); index is assumed to be a
            # constant int here (used below to index parameterized types).
            sequence_node, index_node = node.args
            if not sequence_node.type:
                continue
            # container types
            # typing generics like Tuple[...] / List[...] expose `_name`.
            if hasattr(sequence_node.type, "_name"):
                parameterized_types = sequence_node.type.__args__
                if sequence_node.type._name == "Tuple":
                    # Tuple[X, ...] (homogeneous): every element has type X.
                    if len(parameterized_types) == 2 and isinstance(
                        parameterized_types[1], type(...)
                    ):
                        node.type = parameterized_types[0]
                    else:
                        # Heterogeneous tuple: pick the type at the index.
                        assert len(parameterized_types) > index_node
                        node_type = parameterized_types[index_node]
                        node.type = node_type
                elif sequence_node.type._name == "List":
                    # List[X]: all elements share the single parameter type.
                    assert len(parameterized_types) == 1
                    node.type = parameterized_types[0]
            # NamedTuple type
            elif hasattr(sequence_node.type, "__annotations__"):
                # torch.Tensor has __annotations__ too but is not a NamedTuple.
                if sequence_node.type == torch.Tensor:
                    continue
                # Map the positional index to the field name, then to its type.
                sequence_node_field_types = sequence_node.type.__annotations__
                field_name = sequence_node.type._fields[index_node]
                node.type = sequence_node_field_types[field_name]
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py
ADDED
|
File without changes
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (224 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/cse_pass.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Tuple, Any
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.fx.passes.infra.pass_base import PassBase, PassResult
|
| 5 |
+
from torch.utils._pytree import tree_flatten
|
| 6 |
+
|
| 7 |
+
from torch.fx import GraphModule, Graph
|
| 8 |
+
from torch.fx import Node
|
| 9 |
+
|
| 10 |
+
aten = torch.ops.aten
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# stateful ops are banned from CSE
|
| 14 |
+
rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501,B950
|
| 15 |
+
|
| 16 |
+
inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def get_CSE_banned_ops():
    """Return the default set of ops that must never be CSE'd (random and in-place ops)."""
    return rand_ops | inplace_ops
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class CSEPass(PassBase):
    """Common-subexpression elimination over an FX graph.

    The pass is dialect agnostic: it works purely off fx.Node connectivity,
    so for a functional dialect the user only needs to list the random ops
    in the ban list.

    Warning: applying CSE to a non-functional dialect is unsafe unless every
    stateful operator is included in ``banned_ops``.
    """

    def __init__(self, banned_ops=None):
        """
        Args:
            banned_ops: ops (overload packets or raw targets) that must never
                be merged, e.g. random or in-place ops. Defaults to an empty set.
        """
        self.banned_ops = set() if banned_ops is None else banned_ops
        super().__init__()

    def call(self, graph_module: GraphModule) -> PassResult:
        """Return a new GraphModule with CSE applied; does not mutate the input.

        Example usage:

            from torch.fx.experimental.proxy_tensor import make_fx
            def f(a):
                b = a * a
                c = a * a
                return b+c

            p = CSEPass()
            traced_graph = make_fx(f)(torch.tensor(1))
            print(traced_graph)
            result = p(traced_graph)
            print(result.graph_module)
        """
        def packet_or_target(node):
            # OpOverloads compare by packet so .default/.Tensor variants match
            # the entries in banned_ops; non-aten targets pass through as-is.
            return getattr(node.target, 'overloadpacket', node.target)

        modified = False
        out_graph = Graph()
        # old node -> corresponding node in the new graph
        env: Dict[Node, Node] = {}
        # (target, hash of flat args/kwargs) -> canonical node in the new graph
        hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}
        # same key -> full token, used to rule out hash collisions
        token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}

        for node in graph_module.graph.nodes:
            # placeholder/output/get_attr nodes and banned (stateful/random)
            # ops are copied verbatim — never eliminated.
            if (node.op in ('placeholder', 'output', 'get_attr')
                    or packet_or_target(node) in self.banned_ops):
                env[node] = out_graph.node_copy(node, lambda x: env[x])
                continue

            # From here on node.op == 'call_function' (call_module/call_method
            # are not expected in the dialects this pass targets).

            def remap(tree):
                # Flatten nested args and rewrite Node leaves through env;
                # the spec allows the nesting to be reconstructed/compared.
                leaves, spec = tree_flatten(tree)
                for idx, leaf in enumerate(leaves):
                    if isinstance(leaf, Node) and leaf in env:
                        leaves[idx] = env[leaf]
                return tuple(leaves), spec

            args, args_spec = remap(node.args)
            kwargs, kwargs_spec = remap(node.kwargs)

            # A token uniquely identifies a computation; nodes with equal
            # tokens are interchangeable.
            token = {"target": node.target, "args": args, "args_spec": args_spec,
                     "kwargs": kwargs, "kwargs_spec": kwargs_spec}

            # Specs are not hashable, so the lookup key only hashes the
            # substituted flat args; token equality guards against collisions.
            key = (node.target, hash((args, kwargs)))

            seen = key in hash_env
            if seen and token_map[key] == token:
                # Duplicate computation: reuse the canonical node.
                modified = True
                env[node] = hash_env[key]
                continue

            fresh = out_graph.node_copy(node, lambda x: env[x])
            env[node] = fresh
            if not seen:
                hash_env[key] = fresh
                token_map[key] = token

        return PassResult(GraphModule(graph_module, out_graph), modified)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch.fx
|
| 4 |
+
from torch.fx import Node
|
| 5 |
+
from torch.fx._compatibility import compatibility
|
| 6 |
+
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
|
| 7 |
+
from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake
|
| 8 |
+
from torch.fx.node import map_aggregate
|
| 9 |
+
|
| 10 |
+
__all__ = ['FakeTensorProp']
|
| 11 |
+
|
| 12 |
+
@compatibility(is_backward_compatible=False)
|
| 13 |
+
class FakeTensorProp(torch.fx.Interpreter):
|
| 14 |
+
"""
|
| 15 |
+
Execute an FX graph Node-by-Node and record a fake tensor representing
|
| 16 |
+
the metadata for the node. Unlike ShapeProp, (1) this propagation
|
| 17 |
+
is cheap--it does the propagation with meta tensors which do not actually
|
| 18 |
+
store data, and (2) the fake tensors have much more fine grained information,
|
| 19 |
+
e.g., they have accurate alias information that can be consulted by looking
|
| 20 |
+
at the storages.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
module (GraphModule): The module to be executed
|
| 24 |
+
mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node.
|
| 25 |
+
"""
|
| 26 |
+
def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None):
|
| 27 |
+
super().__init__(module)
|
| 28 |
+
if mode is None:
|
| 29 |
+
mode = FakeTensorMode()
|
| 30 |
+
self._mode = mode
|
| 31 |
+
|
| 32 |
+
def run_node(self, n: Node):
|
| 33 |
+
import sympy
|
| 34 |
+
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
| 35 |
+
|
| 36 |
+
result = super().run_node(n)
|
| 37 |
+
sym = None
|
| 38 |
+
if (
|
| 39 |
+
'val' in n.meta and
|
| 40 |
+
isinstance(v := n.meta['val'], torch.SymInt) and
|
| 41 |
+
isinstance(v.node.expr, sympy.Symbol) and free_unbacked_symbols(v)
|
| 42 |
+
):
|
| 43 |
+
sym = v
|
| 44 |
+
|
| 45 |
+
def extract_val(obj):
|
| 46 |
+
if isinstance(obj, FakeTensor):
|
| 47 |
+
return snapshot_fake(obj)
|
| 48 |
+
elif isinstance(obj, torch.Tensor):
|
| 49 |
+
# TODO: How is it possible that we get a non fake tensor? We
|
| 50 |
+
# should be running under the mode...
|
| 51 |
+
return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True))
|
| 52 |
+
elif isinstance(obj, py_sym_types):
|
| 53 |
+
return obj
|
| 54 |
+
else:
|
| 55 |
+
return None
|
| 56 |
+
|
| 57 |
+
meta = map_aggregate(result, extract_val)
|
| 58 |
+
if meta is not None:
|
| 59 |
+
n.meta['val'] = meta
|
| 60 |
+
if sym is not None:
|
| 61 |
+
torch._check(meta == v)
|
| 62 |
+
return result
|
| 63 |
+
|
| 64 |
+
def propagate(self, *args):
|
| 65 |
+
fake_args = [
|
| 66 |
+
self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
|
| 67 |
+
for a in args
|
| 68 |
+
]
|
| 69 |
+
return self.propagate_dont_convert_inputs(*fake_args)
|
| 70 |
+
|
| 71 |
+
def propagate_dont_convert_inputs(self, *args):
|
| 72 |
+
with self._mode:
|
| 73 |
+
return super().run(*args)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_drawer.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import hashlib
|
| 3 |
+
import torch
|
| 4 |
+
import torch.fx
|
| 5 |
+
from typing import Any, Dict, Optional, TYPE_CHECKING
|
| 6 |
+
from torch.fx.node import _get_qualified_name, _format_arg
|
| 7 |
+
from torch.fx.graph import _parse_stack_trace
|
| 8 |
+
from torch.fx.passes.shape_prop import TensorMetadata
|
| 9 |
+
from torch.fx._compatibility import compatibility
|
| 10 |
+
from itertools import chain
|
| 11 |
+
|
| 12 |
+
__all__ = ['FxGraphDrawer']
# pydot is an optional dependency; when it is missing, a stub FxGraphDrawer
# that raises at construction time is defined instead (see bottom of file).
try:
    import pydot
    HAS_PYDOT = True
except ImportError:
    HAS_PYDOT = False

# Fixed fill colors for well-known node ops; any other op gets a color hashed
# from its target name (see FxGraphDrawer._get_node_style).
_COLOR_MAP = {
    "placeholder": '"AliceBlue"',
    "call_module": "LemonChiffon1",
    "get_param": "Yellow2",
    "get_attr": "LightGrey",
    "output": "PowderBlue",
}

# Palette used for name-hashed node colors (stable per target name).
_HASH_COLOR_MAP = [
    "CadetBlue1",
    "Coral",
    "DarkOliveGreen1",
    "DarkSeaGreen1",
    "GhostWhite",
    "Khaki1",
    "LavenderBlush1",
    "LightSkyBlue",
    "MistyRose1",
    "MistyRose2",
    "PaleTurquoise2",
    "PeachPuff1",
    "Salmon",
    "Thistle1",
    "Thistle3",
    "Wheat1",
]

# Base pydot attributes for parameter/buffer nodes; FxGraphDrawer.__init__
# injects the "shape" key at construction time.
_WEIGHT_TEMPLATE = {
    "fillcolor": "Salmon",
    "style": '"filled,rounded"',
    "fontcolor": "#000000",
}
|
| 51 |
+
|
| 52 |
+
if HAS_PYDOT:
    @compatibility(is_backward_compatible=False)
    class FxGraphDrawer:
        """
        Visualize a torch.fx.Graph with graphviz
        Basic usage:
            g = FxGraphDrawer(symbolic_traced, "resnet18")
            g.get_dot_graph().write_svg("a.svg")
        """

        def __init__(
            self,
            graph_module: torch.fx.GraphModule,
            name: str,
            ignore_getattr: bool = False,
            ignore_parameters_and_buffers: bool = False,
            skip_node_names_in_args: bool = True,
            parse_stack_trace: bool = False,
            dot_graph_shape: Optional[str] = None,
        ):
            self._name = name
            self.dot_graph_shape = (
                dot_graph_shape if dot_graph_shape is not None else "record"
            )
            # NOTE(review): mutates the module-level _WEIGHT_TEMPLATE, so the
            # shape chosen here leaks into all later FxGraphDrawer instances.
            _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape

            # Main graph is rendered eagerly and cached under its own name.
            self._dot_graphs = {
                name: self._to_dot(
                    graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace
                )
            }

            # Additionally render every call_module target that is itself a
            # GraphModule, cached under "<name>_<target>".
            for node in graph_module.graph.nodes:
                if node.op != "call_module":
                    continue

                leaf_node = self._get_leaf_node(graph_module, node)

                if not isinstance(leaf_node, torch.fx.GraphModule):
                    continue


                self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(
                    leaf_node,
                    f"{name}_{node.target}",
                    ignore_getattr,
                    ignore_parameters_and_buffers,
                    skip_node_names_in_args,
                    parse_stack_trace,
                )

        def get_dot_graph(self, submod_name=None) -> pydot.Dot:
            """
            Visualize a torch.fx.Graph with graphviz
            Example:
                >>> # xdoctest: +REQUIRES(module:pydot)
                >>> # define module
                >>> class MyModule(torch.nn.Module):
                >>>     def __init__(self):
                >>>         super().__init__()
                >>>         self.linear = torch.nn.Linear(4, 5)
                >>>     def forward(self, x):
                >>>         return self.linear(x).clamp(min=0.0, max=1.0)
                >>> module = MyModule()
                >>> # trace the module
                >>> symbolic_traced = torch.fx.symbolic_trace(module)
                >>> # setup output file
                >>> import ubelt as ub
                >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir()
                >>> fpath = dpath / 'linear.svg'
                >>> # draw the graph
                >>> g = FxGraphDrawer(symbolic_traced, "linear")
                >>> g.get_dot_graph().write_svg(fpath)
            """
            if submod_name is None:
                return self.get_main_dot_graph()
            else:
                return self.get_submod_dot_graph(submod_name)

        def get_main_dot_graph(self) -> pydot.Dot:
            # Cached graph for the top-level GraphModule.
            return self._dot_graphs[self._name]

        def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
            # Cached graph for a call_module submodule rendered in __init__.
            return self._dot_graphs[f"{self._name}_{submod_name}"]

        def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
            return self._dot_graphs

        def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
            """Return the pydot style attributes for a node (fill color by op)."""

            template = {
                "shape": self.dot_graph_shape,
                "fillcolor": "#CAFFE3",
                "style": '"filled,rounded"',
                "fontcolor": "#000000",
            }
            if node.op in _COLOR_MAP:
                template["fillcolor"] = _COLOR_MAP[node.op]
            else:
                # Use a random color for each node; based on its name so it's stable.
                target_name = node._pretty_print_target(node.target)
                target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16)
                template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)]
            return template

        def _get_leaf_node(
            self, module: torch.nn.Module, node: torch.fx.Node
        ) -> torch.nn.Module:
            """Resolve node.target (a dotted attribute path) against module."""
            py_obj = module
            assert isinstance(node.target, str)
            atoms = node.target.split(".")
            for atom in atoms:
                if not hasattr(py_obj, atom):
                    raise RuntimeError(
                        str(py_obj) + " does not have attribute " + atom + "!"
                    )
                py_obj = getattr(py_obj, atom)
            return py_obj

        def _typename(self, target: Any) -> str:
            """Return a dot-safe printable name for a node target."""
            if isinstance(target, torch.nn.Module):
                ret = torch.typename(target)
            elif isinstance(target, str):
                ret = target
            else:
                ret = _get_qualified_name(target)

            # Escape "{" and "}" to prevent dot files like:
            # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc
            # which triggers `Error: bad label format (...)` from dot
            return ret.replace("{", r"\{").replace("}", r"\}")

        # shorten path to avoid drawing long boxes
        # for full path = '/home/weif/pytorch/test.py'
        # return short path = 'pytorch/test.py'
        def _shorten_file_name(
            self,
            full_file_name: str,
            truncate_to_last_n: int = 2,
        ):
            splits = full_file_name.split('/')
            if len(splits) >= truncate_to_last_n:
                return '/'.join(splits[-truncate_to_last_n:])
            return full_file_name


        def _get_node_label(
            self,
            module: torch.fx.GraphModule,
            node: torch.fx.Node,
            skip_node_names_in_args: bool,
            parse_stack_trace: bool,
        ) -> str:
            """Build the full record-shaped dot label for one node."""
            def _get_str_for_args_kwargs(arg):
                # Render a tuple of args or a dict of kwargs as a dot record
                # segment; unexpected types render as nothing.
                if isinstance(arg, tuple):
                    prefix, suffix = r"|args=(\l", r",\n)\l"
                    arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
                elif isinstance(arg, dict):
                    prefix, suffix = r"|kwargs={\l", r",\n}\l"
                    arg_strs_list = [
                        f"{k}: {_format_arg(v, max_list_len=8)}"
                        for k, v in arg.items()
                    ]
                else:  # Fall back to nothing in unexpected case.
                    return ""

                # Strip out node names if requested.
                if skip_node_names_in_args:
                    arg_strs_list = [a for a in arg_strs_list if "%" not in a]
                if len(arg_strs_list) == 0:
                    return ""
                arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
                if len(arg_strs_list) == 1:
                    arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "")
                return arg_strs.replace("{", r"\{").replace("}", r"\}")


            label = "{" + f"name=%{node.name}|op_code={node.op}\n"

            if node.op == "call_module":
                leaf_module = self._get_leaf_node(module, node)
                label += r"\n" + self._typename(leaf_module) + r"\n|"
                extra = ""
                if hasattr(leaf_module, "__constants__"):
                    extra = r"\n".join(
                        [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__]  # type: ignore[union-attr]
                    )
                label += extra + r"\n"
            else:
                label += f"|target={self._typename(node.target)}" + r"\n"
                if len(node.args) > 0:
                    label += _get_str_for_args_kwargs(node.args)
                if len(node.kwargs) > 0:
                    label += _get_str_for_args_kwargs(node.kwargs)
                label += f"|num_users={len(node.users)}" + r"\n"

            tensor_meta = node.meta.get('tensor_meta')
            label += self._tensor_meta_to_label(tensor_meta)

            # for original fx graph
            # print buf=buf0, n_origin=6
            buf_meta = node.meta.get('buf_meta', None)
            if buf_meta is not None:
                label += f"|buf={buf_meta.name}" + r"\n"
                label += f"|n_origin={buf_meta.n_origin}" + r"\n"

            # for original fx graph
            # print file:lineno code
            if parse_stack_trace and node.stack_trace is not None:
                parsed_stack_trace = _parse_stack_trace(node.stack_trace)
                fname = self._shorten_file_name(parsed_stack_trace.file)
                label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n"


            return label + "}"

        def _tensor_meta_to_label(self, tm) -> str:
            """Recursively render tensor metadata (scalar/list/dict/tuple) into label text."""
            if tm is None:
                return ""
            elif isinstance(tm, TensorMetadata):
                return self._stringify_tensor_meta(tm)
            elif isinstance(tm, list):
                result = ""
                for item in tm:
                    result += self._tensor_meta_to_label(item)
                return result
            elif isinstance(tm, dict):
                result = ""
                for v in tm.values():
                    result += self._tensor_meta_to_label(v)
                return result
            elif isinstance(tm, tuple):
                result = ""
                for item in tm:
                    result += self._tensor_meta_to_label(item)
                return result
            else:
                raise RuntimeError(f"Unsupported tensor meta type {type(tm)}")

        def _stringify_tensor_meta(self, tm: TensorMetadata) -> str:
            """Render one TensorMetadata (dtype/shape/quantization info) as label text."""
            result = ""
            if not hasattr(tm, "dtype"):
                # NOTE(review): debug print left in — fires when tm is not
                # TensorMetadata-like, and the next line will then raise.
                print("tm", tm)
            result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n"
            result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n"
            result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n"
            result += "|" + "stride" + "=" + str(tm.stride) + r"\n"
            if tm.is_quantized:
                assert tm.qparams is not None
                assert "qscheme" in tm.qparams
                qscheme = tm.qparams["qscheme"]
                if qscheme in {
                        torch.per_tensor_affine,
                        torch.per_tensor_symmetric,
                }:
                    result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
                    result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
                elif qscheme in {
                        torch.per_channel_affine,
                        torch.per_channel_symmetric,
                        torch.per_channel_affine_float_qparams,
                }:
                    result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
                    result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
                    result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n"
                else:
                    raise RuntimeError(f"Unsupported qscheme: {qscheme}")
                result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
            return result

        def _get_tensor_label(self, t: torch.Tensor) -> str:
            # Compact "dtype[shape]" rendering for parameter/buffer nodes.
            return str(t.dtype) + str(list(t.shape)) + r"\n"

        # when parse_stack_trace=True
        # print file:lineno code
        def _to_dot(
            self,
            graph_module: torch.fx.GraphModule,
            name: str,
            ignore_getattr: bool,
            ignore_parameters_and_buffers: bool,
            skip_node_names_in_args: bool,
            parse_stack_trace: bool,
        ) -> pydot.Dot:
            """
            Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph.
            If ignore_parameters_and_buffers is True, the parameters and buffers
            created with the module will not be added as nodes and edges.
            """

            # "TB" means top-to-bottom rank direction in layout
            dot_graph = pydot.Dot(name, rankdir="TB")


            # Nodes sharing a multi-origin buffer are grouped into a cluster.
            buf_name_to_subgraph = {}

            for node in graph_module.graph.nodes:
                if ignore_getattr and node.op == "get_attr":
                    continue

                style = self._get_node_style(node)
                dot_node = pydot.Node(
                    node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style
                )

                current_graph = dot_graph

                buf_meta = node.meta.get('buf_meta', None)
                if buf_meta is not None and buf_meta.n_origin > 1:
                    buf_name = buf_meta.name
                    if buf_name not in buf_name_to_subgraph:
                        buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name)
                    current_graph = buf_name_to_subgraph.get(buf_name)

                current_graph.add_node(dot_node)

                def get_module_params_or_buffers():
                    # Emit one weight node per parameter/buffer of the current
                    # call_module target, wired into that module's node.
                    for pname, ptensor in chain(
                        leaf_module.named_parameters(), leaf_module.named_buffers()
                    ):
                        pname1 = node.name + "." + pname
                        # NOTE(review): due to ternary precedence, buffers get
                        # label1 = "buffer\l" without the pname1/op_code prefix;
                        # looks unintended — confirm before relying on buffer labels.
                        label1 = (
                            pname1 + "|op_code=get_" + "parameter"
                            if isinstance(ptensor, torch.nn.Parameter)
                            else "buffer" + r"\l"
                        )
                        dot_w_node = pydot.Node(
                            pname1,
                            label="{" + label1 + self._get_tensor_label(ptensor) + "}",
                            **_WEIGHT_TEMPLATE,
                        )
                        dot_graph.add_node(dot_w_node)
                        dot_graph.add_edge(pydot.Edge(pname1, node.name))

                if node.op == "call_module":
                    leaf_module = self._get_leaf_node(graph_module, node)

                    if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule):
                        get_module_params_or_buffers()

            for subgraph in buf_name_to_subgraph.values():
                subgraph.set('color', 'royalblue')
                subgraph.set('penwidth', '2')
                dot_graph.add_subgraph(subgraph)

            # Second pass: add dataflow edges (node -> each of its users).
            for node in graph_module.graph.nodes:
                if ignore_getattr and node.op == "get_attr":
                    continue

                for user in node.users:
                    dot_graph.add_edge(pydot.Edge(node.name, user.name))

            return dot_graph

else:
    # pydot missing: expose a stub with the same signature that fails loudly.
    if not TYPE_CHECKING:
        @compatibility(is_backward_compatible=False)
        class FxGraphDrawer:
            def __init__(
                self,
                graph_module: torch.fx.GraphModule,
                name: str,
                ignore_getattr: bool = False,
                ignore_parameters_and_buffers: bool = False,
                skip_node_names_in_args: bool = True,
                parse_stack_trace: bool = False,
                dot_graph_shape: Optional[str] = None,
            ):
                raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install '
                                   'pydot through your favorite Python package manager.')
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (273 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
|
| 2 |
+
import collections
|
| 3 |
+
import itertools
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
from copy import copy
|
| 7 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Set
|
| 8 |
+
|
| 9 |
+
from torch.fx.graph_module import GraphModule
|
| 10 |
+
from torch.fx.node import Node, _get_qualified_name
|
| 11 |
+
from torch.fx.passes.operator_support import OperatorSupportBase
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Module-level logger; pinned to WARNING so partitioner tracing stays quiet
# unless a caller explicitly raises the level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
|
| 16 |
+
|
| 17 |
+
class Partition:
    """A mutable group of graph nodes tagged with an optional partition id."""

    def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
        self.id = id
        # Copy the incoming iterable so the partition owns its own set.
        self.nodes: Set[Node] = set() if nodes is None else set(nodes)

    def __repr__(self) -> str:
        # Render as the underlying node set.
        return str(self.nodes)

    def add_node(self, node: Node):
        """Include ``node`` in this partition."""
        self.nodes.add(node)

    def remove_node(self, node: Node):
        """Drop ``node`` from this partition; raises KeyError if absent."""
        self.nodes.remove(node)

    def size(self):
        """Number of nodes currently in the partition."""
        return len(self.nodes)
|
| 33 |
+
|
| 34 |
+
class _DependencyViewer:
|
| 35 |
+
def __init__(self, graph_module: GraphModule):
|
| 36 |
+
self.upstreams = collections.defaultdict(set)
|
| 37 |
+
self.downstreams = collections.defaultdict(set)
|
| 38 |
+
|
| 39 |
+
for node in graph_module.graph.nodes:
|
| 40 |
+
for input_node in node.all_input_nodes:
|
| 41 |
+
# add input_node and input_node's upstream dependency
|
| 42 |
+
self.upstreams[node].add(input_node)
|
| 43 |
+
self.upstreams[node].update(self.upstreams[input_node])
|
| 44 |
+
|
| 45 |
+
for node in reversed(graph_module.graph.nodes):
|
| 46 |
+
for output_node in node.users:
|
| 47 |
+
# add output_node and output_node's downstream dependency
|
| 48 |
+
self.downstreams[node].add(output_node)
|
| 49 |
+
self.downstreams[node].update(self.downstreams[output_node])
|
| 50 |
+
|
| 51 |
+
def downstreams_of(self, node: Node) -> Set[Node]:
|
| 52 |
+
return self.downstreams[node]
|
| 53 |
+
|
| 54 |
+
def upstreams_of(self, node: Node) -> Set[Node]:
|
| 55 |
+
return self.upstreams[node]
|
| 56 |
+
|
| 57 |
+
class CapabilityBasedPartitioner:
|
| 58 |
+
|
| 59 |
+
def __init__(self,
|
| 60 |
+
graph_module: GraphModule,
|
| 61 |
+
operator_support: OperatorSupportBase,
|
| 62 |
+
allows_single_node_partition: bool = False,
|
| 63 |
+
non_compute_ops: Optional[Sequence[str]] = None,
|
| 64 |
+
allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
|
| 65 |
+
) -> None:
|
| 66 |
+
self.graph_module = graph_module
|
| 67 |
+
self.operator_support = operator_support
|
| 68 |
+
self.allows_single_node_partition = allows_single_node_partition
|
| 69 |
+
self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
|
| 70 |
+
self.allowed_single_node_partition_ops = (
|
| 71 |
+
allowed_single_node_partition_ops
|
| 72 |
+
if allowed_single_node_partition_ops is not None
|
| 73 |
+
else []
|
| 74 |
+
)
|
| 75 |
+
self.dependency_viewer = _DependencyViewer(graph_module)
|
| 76 |
+
|
| 77 |
+
def __is_node_supported(self, node: Node) -> bool:
|
| 78 |
+
return (
|
| 79 |
+
self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
def propose_partitions(self) -> List[Partition]:
|
| 83 |
+
# partition_map is a mapping from partition id to a set of partition id's.
|
| 84 |
+
# The value set contains all the partition ids that can be reached by doing a
|
| 85 |
+
# DFS starting from the partition id in the key.
|
| 86 |
+
partition_map : Dict[int, Set] = collections.defaultdict(set)
|
| 87 |
+
|
| 88 |
+
# assumptions: nodes in candidate list is sorted in topological order
|
| 89 |
+
assignment: Dict[Node, int] = {} # mapping from node to partition_id
|
| 90 |
+
partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition
|
| 91 |
+
new_partition_id = itertools.count()
|
| 92 |
+
|
| 93 |
+
# try to merge partition other_id into partition self_id
|
| 94 |
+
# merge only happens if the end graph doesn't contain cyclic dependency
|
| 95 |
+
# returns `True` when merge happens, `False` otherwise.
|
| 96 |
+
def maybe_merge_partition(self_id: int, other_id: int):
|
| 97 |
+
# merged_nodes is the union of nodes in two partition to-be-merged
|
| 98 |
+
merged_nodes = copy(partitions_by_id[self_id].nodes)
|
| 99 |
+
merged_nodes.update(partitions_by_id[other_id].nodes)
|
| 100 |
+
|
| 101 |
+
def dfs_iter_find_cycle(all_user_nodes: List[Node]):
|
| 102 |
+
for user_node in all_user_nodes:
|
| 103 |
+
visited_partition_ids = set()
|
| 104 |
+
|
| 105 |
+
for path_node in self.dependency_viewer.downstreams_of(user_node):
|
| 106 |
+
# If any of the nodes in the dfs path of this node are in the merged_nodes
|
| 107 |
+
# list then there is a cycle in the graph.
|
| 108 |
+
if path_node in merged_nodes:
|
| 109 |
+
return True
|
| 110 |
+
|
| 111 |
+
# If any of the nodes in the dfs path of this node are in the assignment
|
| 112 |
+
# map then we have to make sure that the partitions that these nodes belong
|
| 113 |
+
# to do not form a cycle with the current partitions being merged. This means
|
| 114 |
+
# iterating through all the nodes in all the parititons that are traversed in
|
| 115 |
+
# the dfs path and checking if they are in the merged_nodes list.
|
| 116 |
+
if path_node in assignment:
|
| 117 |
+
partition_id = assignment[path_node]
|
| 118 |
+
# If the partition id has already been visited then we know that it doesn't
|
| 119 |
+
# form a cycle with the current partitions being merged.
|
| 120 |
+
if partition_id in visited_partition_ids:
|
| 121 |
+
continue
|
| 122 |
+
p_map = partition_map[partition_id]
|
| 123 |
+
if self_id in p_map or other_id in p_map:
|
| 124 |
+
return True
|
| 125 |
+
|
| 126 |
+
visited_partition_ids.add(partition_id)
|
| 127 |
+
|
| 128 |
+
return False
|
| 129 |
+
|
| 130 |
+
# check if merge would create cyclic dependency.
|
| 131 |
+
all_user_nodes = []
|
| 132 |
+
for node in merged_nodes:
|
| 133 |
+
for user_node in node.users:
|
| 134 |
+
if user_node not in merged_nodes:
|
| 135 |
+
all_user_nodes.append(user_node)
|
| 136 |
+
|
| 137 |
+
if dfs_iter_find_cycle(all_user_nodes):
|
| 138 |
+
# return false indicating cyclic dependency found and
|
| 139 |
+
# merge is aborted
|
| 140 |
+
return False
|
| 141 |
+
|
| 142 |
+
# no cyclic dependency found, move forward with the merge
|
| 143 |
+
# updating partition nodes
|
| 144 |
+
partitions_by_id[self_id].nodes = merged_nodes
|
| 145 |
+
# updating assignment map
|
| 146 |
+
for node in partitions_by_id[other_id].nodes:
|
| 147 |
+
assignment[node] = self_id
|
| 148 |
+
# delete other partition
|
| 149 |
+
del partitions_by_id[other_id]
|
| 150 |
+
|
| 151 |
+
partition_map[self_id] = partition_map[self_id].union(partition_map[other_id])
|
| 152 |
+
del partition_map[other_id]
|
| 153 |
+
|
| 154 |
+
return True
|
| 155 |
+
|
| 156 |
+
def merge_single_node(node: Node, id: Optional[int]):
|
| 157 |
+
def _update_partition_map(node: Node, id: int):
|
| 158 |
+
# Iterate through all the downstream nodes of this node and update the partition map
|
| 159 |
+
# to indicate that there is a path from the partition id of this node to the target
|
| 160 |
+
# partition id.
|
| 161 |
+
downstream_nodes = self.dependency_viewer.downstreams_of(node)
|
| 162 |
+
for curr_node in downstream_nodes:
|
| 163 |
+
target_id = assignment.get(curr_node, None)
|
| 164 |
+
if target_id is not None:
|
| 165 |
+
partition_map[id].add(target_id)
|
| 166 |
+
|
| 167 |
+
# Iterate through all the upstream nodes of this node and update the partition map
|
| 168 |
+
# to indicate that there is a path from the partition id of the upstream node to the
|
| 169 |
+
# current node's partition id.
|
| 170 |
+
upstream_nodes = self.dependency_viewer.upstreams_of(node)
|
| 171 |
+
for curr_node in upstream_nodes:
|
| 172 |
+
source_id = assignment.get(curr_node, None)
|
| 173 |
+
if source_id is not None:
|
| 174 |
+
partition_map[source_id].add(id)
|
| 175 |
+
|
| 176 |
+
if node in assignment:
|
| 177 |
+
partitions_by_id[assignment[node]].remove_node(node)
|
| 178 |
+
|
| 179 |
+
if id is None:
|
| 180 |
+
assignment.pop(node)
|
| 181 |
+
elif id not in partitions_by_id:
|
| 182 |
+
assignment[node] = id
|
| 183 |
+
partitions_by_id[id] = Partition(id=id, nodes=[node])
|
| 184 |
+
_update_partition_map(node, id)
|
| 185 |
+
else:
|
| 186 |
+
assignment[node] = id
|
| 187 |
+
partitions_by_id[id].add_node(node)
|
| 188 |
+
_update_partition_map(node, id)
|
| 189 |
+
|
| 190 |
+
logger.debug("Proposing partitions...")
|
| 191 |
+
|
| 192 |
+
for node in reversed(self.graph_module.graph.nodes):
|
| 193 |
+
# use Dict as an ordered set to ensure deterministic partitioning result, don't care value
|
| 194 |
+
merge_candidates: Dict[int, None] = {}
|
| 195 |
+
|
| 196 |
+
# Note a limited horizontal fusion is enabled:
|
| 197 |
+
# when `node` is not supported, the code below attempts to fuse consumer of `node`.
|
| 198 |
+
#
|
| 199 |
+
# I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut
|
| 200 |
+
# the fusion by adding an `else` block here to skip horizontal fusion.
|
| 201 |
+
if self.__is_node_supported(node) and node not in assignment:
|
| 202 |
+
partition_id = next(new_partition_id)
|
| 203 |
+
merge_single_node(node, partition_id)
|
| 204 |
+
merge_candidates[partition_id] = None
|
| 205 |
+
|
| 206 |
+
# merge all possible partitions
|
| 207 |
+
for node in assignment:
|
| 208 |
+
merge_candidates[assignment[node]] = None
|
| 209 |
+
|
| 210 |
+
merge_candidates_list = list(merge_candidates.keys())
|
| 211 |
+
if len(merge_candidates_list) > 1:
|
| 212 |
+
self_id = merge_candidates_list[0]
|
| 213 |
+
for other_id in merge_candidates_list[1:]:
|
| 214 |
+
# note: merge partition `other_id` into partition `self_id` if
|
| 215 |
+
# it doesn't create cyclic dependency in the graph, otherwise,
|
| 216 |
+
# this is a no-op
|
| 217 |
+
maybe_merge_partition(self_id, other_id)
|
| 218 |
+
|
| 219 |
+
# post processing to re-assign "getitem" nodes into upstream partition
|
| 220 |
+
logger.debug("Reassigning getitem nodes to its producer node's partition...")
|
| 221 |
+
nodes_reassignment: Dict[Node, int] = {}
|
| 222 |
+
for node in self.graph_module.graph.nodes:
|
| 223 |
+
is_tuple_output = True
|
| 224 |
+
for user in node.users:
|
| 225 |
+
if user.op != "call_function" or \
|
| 226 |
+
_get_qualified_name(user.target) != "_operator.getitem": # type: ignore[arg-type]
|
| 227 |
+
is_tuple_output = False
|
| 228 |
+
break
|
| 229 |
+
|
| 230 |
+
# node has tuple outputs, re-assign all following getitem node into node's partition
|
| 231 |
+
if is_tuple_output:
|
| 232 |
+
id = assignment.get(node, None) # type: ignore[arg-type]
|
| 233 |
+
for user in node.users:
|
| 234 |
+
if assignment.get(user, None) != id: # type: ignore[arg-type]
|
| 235 |
+
nodes_reassignment[user] = id # type: ignore[assignment]
|
| 236 |
+
for node, id in nodes_reassignment.items():
|
| 237 |
+
merge_single_node(node, id)
|
| 238 |
+
|
| 239 |
+
# filter out single node partitions
|
| 240 |
+
if not self.allows_single_node_partition:
|
| 241 |
+
logger.debug("Filtering out single node partitions...")
|
| 242 |
+
default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
|
| 243 |
+
non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
|
| 244 |
+
partitions_to_remove: List[int] = []
|
| 245 |
+
for id, partition in partitions_by_id.items():
|
| 246 |
+
compute_node_count = 0
|
| 247 |
+
for node in partition.nodes:
|
| 248 |
+
if node.op == "call_function":
|
| 249 |
+
assert callable(node.target)
|
| 250 |
+
if _get_qualified_name(node.target) not in non_compute_ops:
|
| 251 |
+
compute_node_count += 1
|
| 252 |
+
if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
|
| 253 |
+
compute_node_count += 1
|
| 254 |
+
if compute_node_count <= 1:
|
| 255 |
+
partitions_to_remove.append(id)
|
| 256 |
+
for id in partitions_to_remove:
|
| 257 |
+
del partitions_by_id[id]
|
| 258 |
+
|
| 259 |
+
logger.debug("Partitions proposed:")
|
| 260 |
+
for id, partition in partitions_by_id.items():
|
| 261 |
+
logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes])
|
| 262 |
+
|
| 263 |
+
return list(partitions_by_id.values())
|
| 264 |
+
|
| 265 |
+
def fuse_partitions(self, partitions: List[Partition]) -> GraphModule:
|
| 266 |
+
logger.debug("Fusing partitions...")
|
| 267 |
+
# fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ]
|
| 268 |
+
return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions])
|
| 269 |
+
|
| 270 |
+
# remove non-compute-ops that sits at the boundary of a partition.
|
| 271 |
+
def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
|
| 272 |
+
non_compute_ops = set(self.non_compute_ops)
|
| 273 |
+
|
| 274 |
+
def is_non_compute_node(node: Node):
|
| 275 |
+
return node.op == "call_function" and \
|
| 276 |
+
_get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type]
|
| 277 |
+
|
| 278 |
+
# cache transparent nodes
|
| 279 |
+
transparent_input_nodes: Dict[Node, bool] = {}
|
| 280 |
+
transparent_output_nodes: Dict[Node, bool] = {}
|
| 281 |
+
|
| 282 |
+
def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
|
| 283 |
+
if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
|
| 284 |
+
return True
|
| 285 |
+
if node in transparent_input_nodes:
|
| 286 |
+
return transparent_input_nodes[node]
|
| 287 |
+
if is_non_compute_node(node):
|
| 288 |
+
for input_n in node.all_input_nodes:
|
| 289 |
+
if not is_transparent_input_node(input_n, partition, removed_nodes):
|
| 290 |
+
transparent_input_nodes[node] = False
|
| 291 |
+
return False
|
| 292 |
+
transparent_input_nodes[node] = True
|
| 293 |
+
return True
|
| 294 |
+
transparent_input_nodes[node] = False
|
| 295 |
+
return False
|
| 296 |
+
|
| 297 |
+
def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
|
| 298 |
+
if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
|
| 299 |
+
return True
|
| 300 |
+
if node in transparent_output_nodes:
|
| 301 |
+
return transparent_output_nodes[node]
|
| 302 |
+
if is_non_compute_node(node):
|
| 303 |
+
for output_n in node.users:
|
| 304 |
+
if not is_transparent_output_node(output_n, partition, removed_nodes):
|
| 305 |
+
transparent_output_nodes[node] = False
|
| 306 |
+
return False
|
| 307 |
+
transparent_output_nodes[node] = True
|
| 308 |
+
return True
|
| 309 |
+
transparent_output_nodes[node] = False
|
| 310 |
+
return False
|
| 311 |
+
|
| 312 |
+
for partition in partitions:
|
| 313 |
+
# Note it's ok to use `set` here, since we are only query if a node
|
| 314 |
+
# has been removed. We are NEVER going to iterate on nodes inside
|
| 315 |
+
# the set.
|
| 316 |
+
remove_node: Set[Node] = set()
|
| 317 |
+
for node in partition.nodes:
|
| 318 |
+
if is_non_compute_node(node) and \
|
| 319 |
+
(is_transparent_input_node(node, partition.nodes, remove_node) or
|
| 320 |
+
is_transparent_output_node(node, partition.nodes, remove_node)):
|
| 321 |
+
remove_node.add(node)
|
| 322 |
+
|
| 323 |
+
if len(remove_node) != 0:
|
| 324 |
+
partition.nodes = partition.nodes - remove_node
|
| 325 |
+
|
| 326 |
+
def partition_and_fuse(self) -> GraphModule:
|
| 327 |
+
partitions = self.propose_partitions()
|
| 328 |
+
fused_gm = self.fuse_partitions(partitions)
|
| 329 |
+
return fused_gm
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/operator_support.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import typing as t
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.fx
|
| 6 |
+
from torch.fx._compatibility import compatibility
|
| 7 |
+
from .shape_prop import TensorMetadata
|
| 8 |
+
from .tools_common import get_node_target, CALLABLE_NODE_OPS
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain']
|
| 12 |
+
|
| 13 |
+
# fx.Node.target typename, as returned by `get_node_target()`
|
| 14 |
+
TargetTypeName = str
|
| 15 |
+
|
| 16 |
+
# Arguments' dtypes for a given node, see `OperatorSupport`
|
| 17 |
+
SupportedArgumentDTypes = t.Optional[
|
| 18 |
+
t.Tuple[
|
| 19 |
+
t.Sequence[t.Sequence[torch.dtype]],
|
| 20 |
+
t.Dict[str, t.Sequence[torch.dtype]],
|
| 21 |
+
]
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@compatibility(is_backward_compatible=False)
|
| 28 |
+
class OperatorSupportBase(abc.ABC):
|
| 29 |
+
"""Interface for determining if a fx.Node is supported by a backend"""
|
| 30 |
+
@abc.abstractmethod
|
| 31 |
+
def is_node_supported(
|
| 32 |
+
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
|
| 33 |
+
) -> bool:
|
| 34 |
+
raise NotImplementedError()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@compatibility(is_backward_compatible=False)
|
| 38 |
+
class OperatorSupport(OperatorSupportBase):
|
| 39 |
+
"""
|
| 40 |
+
`_support_dict` maps node.target typename to supported inputs dtypes.
|
| 41 |
+
|
| 42 |
+
node.target typename is retrieved using helper function `get_node_target()`
|
| 43 |
+
|
| 44 |
+
If supported inputs dtypes is None, it means any dtype is supported, else
|
| 45 |
+
we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}).
|
| 46 |
+
|
| 47 |
+
The first tuple ([dtypes], ...) indicates what dtypes are supported for
|
| 48 |
+
inputs in node.args and the second dict {"name": [dtypes], ...} indicates
|
| 49 |
+
what dtypes are supported for inputs in node.kwargs.
|
| 50 |
+
|
| 51 |
+
For inputs in args, if we don't want to check it, we can put None there,
|
| 52 |
+
e.g. (None, [torch.float]) indicates that we don't care about the type of
|
| 53 |
+
the first input in args. And for inputs in kwargs, if not listed, will not
|
| 54 |
+
be checked.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
_support_dict: SupportDict
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
support_dict: t.Optional[SupportDict] = None
|
| 62 |
+
):
|
| 63 |
+
self._support_dict = support_dict or {}
|
| 64 |
+
|
| 65 |
+
def is_node_supported(
|
| 66 |
+
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
|
| 67 |
+
) -> bool:
|
| 68 |
+
"""
|
| 69 |
+
Args:
|
| 70 |
+
`submodules`: mapping from module name to the module. This can be
|
| 71 |
+
retrieved by calling model.named_modules().
|
| 72 |
+
|
| 73 |
+
`node`: a Fx node that we want to determine whether it's supported.
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
`is_supported`: whether the arg `node` is supported.
|
| 77 |
+
"""
|
| 78 |
+
if node.op not in CALLABLE_NODE_OPS:
|
| 79 |
+
return True
|
| 80 |
+
|
| 81 |
+
target = get_node_target(submodules, node)
|
| 82 |
+
|
| 83 |
+
# Target not found in _support_dict meaning that we don't support this op at all
|
| 84 |
+
if target not in self._support_dict:
|
| 85 |
+
return False
|
| 86 |
+
|
| 87 |
+
# The rule for target is None meaning that we accept any dtype
|
| 88 |
+
if self._support_dict[target] is None:
|
| 89 |
+
return True
|
| 90 |
+
|
| 91 |
+
args_dtypes, kwargs_dtypes = self._support_dict[target] # type: ignore[misc]
|
| 92 |
+
|
| 93 |
+
# Check args dtypes
|
| 94 |
+
for i, dtypes in enumerate(args_dtypes):
|
| 95 |
+
if len(node.args) <= i:
|
| 96 |
+
break
|
| 97 |
+
|
| 98 |
+
# None indicates we don't care about the dtype of args[i]
|
| 99 |
+
if dtypes is None:
|
| 100 |
+
continue
|
| 101 |
+
|
| 102 |
+
# If arg is not a node then we don't check it
|
| 103 |
+
if not isinstance(node.args[i], torch.fx.Node):
|
| 104 |
+
continue
|
| 105 |
+
|
| 106 |
+
arg_dtype = _get_arg_dtype(node.args[i]) # type: ignore[arg-type]
|
| 107 |
+
if arg_dtype not in dtypes:
|
| 108 |
+
return False
|
| 109 |
+
|
| 110 |
+
# Check kwargs dtypes
|
| 111 |
+
for k, dtypes in kwargs_dtypes.items():
|
| 112 |
+
if k not in node.kwargs:
|
| 113 |
+
continue
|
| 114 |
+
|
| 115 |
+
# If arg is not a node then we don't check it
|
| 116 |
+
if not isinstance(node.kwargs[k], torch.fx.Node):
|
| 117 |
+
continue
|
| 118 |
+
|
| 119 |
+
kwarg_dtype = _get_arg_dtype(node.kwargs[k]) # type: ignore[arg-type]
|
| 120 |
+
if kwarg_dtype not in dtypes:
|
| 121 |
+
return False
|
| 122 |
+
|
| 123 |
+
return True
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ======================================================================
|
| 127 |
+
# Functional interfaces and utils for defining basic operator support logic
|
| 128 |
+
# and composing them into more complex ones
|
| 129 |
+
# ======================================================================
|
| 130 |
+
|
| 131 |
+
IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool]
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@compatibility(is_backward_compatible=False)
|
| 135 |
+
def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase:
|
| 136 |
+
"""Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance
|
| 137 |
+
|
| 138 |
+
`IsNodeSupported` has the same call signature as
|
| 139 |
+
`OperatorSupportBase.is_node_supported`
|
| 140 |
+
"""
|
| 141 |
+
class FunctionalOperatorSupport(OperatorSupportBase):
|
| 142 |
+
def is_node_supported(
|
| 143 |
+
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
|
| 144 |
+
) -> bool:
|
| 145 |
+
return is_node_supported(submodules, node)
|
| 146 |
+
return FunctionalOperatorSupport()
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
@compatibility(is_backward_compatible=False)
|
| 150 |
+
def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
|
| 151 |
+
"""Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
|
| 152 |
+
instance by evaluating each input `OperatorSupportBase` instance, and returns False if
|
| 153 |
+
any of it reports False.
|
| 154 |
+
"""
|
| 155 |
+
def _chain(submods, node) -> bool:
|
| 156 |
+
return all(
|
| 157 |
+
x.is_node_supported(submods, node)
|
| 158 |
+
for x in op_support
|
| 159 |
+
)
|
| 160 |
+
return create_op_support(_chain)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@compatibility(is_backward_compatible=False)
|
| 164 |
+
def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
|
| 165 |
+
"""Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
|
| 166 |
+
instance by evaluating each input `OperatorSupportBase` instance, and returns True if
|
| 167 |
+
any of it reports True.
|
| 168 |
+
"""
|
| 169 |
+
def _any_chain(submods, node) -> bool:
|
| 170 |
+
return any(
|
| 171 |
+
x.is_node_supported(submods, node)
|
| 172 |
+
for x in op_support
|
| 173 |
+
)
|
| 174 |
+
return create_op_support(_any_chain)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
@compatibility(is_backward_compatible=False)
|
| 178 |
+
class OpSupports:
|
| 179 |
+
"""A set of atomic `OperatorSupportBase` instances that can be combined together
|
| 180 |
+
to form more complex operator support logic.
|
| 181 |
+
"""
|
| 182 |
+
@classmethod
|
| 183 |
+
def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
|
| 184 |
+
"""Report a node as non-supported, if any of its arguments is of dtype"""
|
| 185 |
+
|
| 186 |
+
def _decline_if_input_dtype(
|
| 187 |
+
submodules: t.Mapping[str, torch.nn.Module],
|
| 188 |
+
node: torch.fx.Node,
|
| 189 |
+
) -> bool:
|
| 190 |
+
for arg in node.all_input_nodes:
|
| 191 |
+
arg_dtype = _get_arg_dtype(arg)
|
| 192 |
+
if arg_dtype == dtype:
|
| 193 |
+
return False
|
| 194 |
+
return True
|
| 195 |
+
return create_op_support(_decline_if_input_dtype)
|
| 196 |
+
|
| 197 |
+
@classmethod
|
| 198 |
+
def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
|
| 199 |
+
"""
|
| 200 |
+
If a node has a name that is in the disallow set, reported it as non-supported.
|
| 201 |
+
"""
|
| 202 |
+
def _decline_if_node_in_names(
|
| 203 |
+
submodules: t.Mapping[str, torch.nn.Module],
|
| 204 |
+
node: torch.fx.Node,
|
| 205 |
+
) -> bool:
|
| 206 |
+
if node.name in disallow_set:
|
| 207 |
+
return False
|
| 208 |
+
else:
|
| 209 |
+
return True
|
| 210 |
+
return create_op_support(_decline_if_node_in_names)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
|
| 214 |
+
assert isinstance(arg, torch.fx.Node)
|
| 215 |
+
tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr]
|
| 216 |
+
dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"]
|
| 217 |
+
return dtype
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/param_fetch.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.fx.graph_module import GraphModule
|
| 2 |
+
from typing import Any, Callable, Dict, List, Tuple, Type
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from torch.fx._compatibility import compatibility
|
| 7 |
+
|
| 8 |
+
__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes']
|
| 9 |
+
|
| 10 |
+
# Matching method matches the attribute name of current version to the attribute name of `target_version`
|
| 11 |
+
@compatibility(is_backward_compatible=False)
|
| 12 |
+
def default_matching(name: str, target_version: int) -> str:
|
| 13 |
+
"""Default matching method
|
| 14 |
+
"""
|
| 15 |
+
return name
|
| 16 |
+
|
| 17 |
+
# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering.
|
| 18 |
+
# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list.
|
| 19 |
+
# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module.
|
| 20 |
+
module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
|
| 21 |
+
torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
|
| 22 |
+
torch.nn.modules.conv.Conv2d: (
|
| 23 |
+
1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching
|
| 24 |
+
),
|
| 25 |
+
torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching),
|
| 26 |
+
torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
|
| 27 |
+
torch.nn.modules.pooling.MaxPool2d: (
|
| 28 |
+
1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching
|
| 29 |
+
),
|
| 30 |
+
torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
@compatibility(is_backward_compatible=False)
|
| 34 |
+
def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
|
| 35 |
+
"""If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book`
|
| 36 |
+
after checking module's version is compatible with the `module_fetch_book`.
|
| 37 |
+
"""
|
| 38 |
+
attrs_for_lowering: Dict[str, Any] = {}
|
| 39 |
+
attrs_for_lowering["name"] = torch.typename(mod)
|
| 40 |
+
|
| 41 |
+
if type(mod) in module_fetch_book:
|
| 42 |
+
version, param_to_fetch, matching_method = module_fetch_book[type(mod)]
|
| 43 |
+
if version < mod._version:
|
| 44 |
+
raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, "
|
| 45 |
+
"please upgrade the module_fetch_book, open an issue and @842974287 "
|
| 46 |
+
"or report a bug to AIACC team directly.")
|
| 47 |
+
for attr in param_to_fetch:
|
| 48 |
+
attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
|
| 49 |
+
else:
|
| 50 |
+
raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, "
|
| 51 |
+
"please add it to the module_fetch_book, open an issue and @842974287 "
|
| 52 |
+
"or report a bug to AIACC team directly.")
|
| 53 |
+
return attrs_for_lowering
|
| 54 |
+
|
| 55 |
+
@compatibility(is_backward_compatible=False)
|
| 56 |
+
def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
|
| 57 |
+
"""Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module.
|
| 58 |
+
"""
|
| 59 |
+
submodules = dict(fx_module.named_modules())
|
| 60 |
+
|
| 61 |
+
for node in fx_module.graph.nodes:
|
| 62 |
+
if node.op == "call_module":
|
| 63 |
+
if isinstance(submodules[node.target], GraphModule):
|
| 64 |
+
lift_lowering_attrs_to_nodes(submodules[node.target])
|
| 65 |
+
else:
|
| 66 |
+
node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target])
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/shape_prop.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.fx
|
| 5 |
+
import traceback
|
| 6 |
+
|
| 7 |
+
from torch._dispatch.python import enable_python_dispatcher
|
| 8 |
+
from torch.fx.node import Node, map_aggregate
|
| 9 |
+
from typing import Any, Tuple, NamedTuple, Optional, Dict
|
| 10 |
+
from torch.fx._compatibility import compatibility
|
| 11 |
+
from torch._guards import detect_fake_mode
|
| 12 |
+
|
| 13 |
+
__all__ = ['TensorMetadata', 'ShapeProp']
|
| 14 |
+
|
| 15 |
+
@compatibility(is_backward_compatible=True)
|
| 16 |
+
class TensorMetadata(NamedTuple):
|
| 17 |
+
# TensorMetadata is a structure containing pertinent information
|
| 18 |
+
# about a tensor within a PyTorch program.
|
| 19 |
+
|
| 20 |
+
# General Tensor metadata
|
| 21 |
+
shape : torch.Size
|
| 22 |
+
dtype : torch.dtype
|
| 23 |
+
requires_grad : bool
|
| 24 |
+
stride : Tuple[int, ...]
|
| 25 |
+
memory_format : Optional[torch.memory_format]
|
| 26 |
+
|
| 27 |
+
# Quantization metadata
|
| 28 |
+
is_quantized : bool
|
| 29 |
+
qparams: Dict[str, Any]
|
| 30 |
+
|
| 31 |
+
def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:
|
| 32 |
+
"""
|
| 33 |
+
Extract a TensorMetadata NamedTuple describing `result`.
|
| 34 |
+
"""
|
| 35 |
+
shape = result.shape
|
| 36 |
+
dtype = result.dtype
|
| 37 |
+
requires_grad = result.requires_grad
|
| 38 |
+
stride = result.stride()
|
| 39 |
+
|
| 40 |
+
memory_format = None
|
| 41 |
+
|
| 42 |
+
if include_contiguity:
|
| 43 |
+
memory_formats = {
|
| 44 |
+
torch.contiguous_format,
|
| 45 |
+
torch.channels_last,
|
| 46 |
+
torch.channels_last_3d,
|
| 47 |
+
}
|
| 48 |
+
for query_format in memory_formats:
|
| 49 |
+
if result.is_contiguous(memory_format=query_format):
|
| 50 |
+
memory_format = query_format
|
| 51 |
+
break
|
| 52 |
+
|
| 53 |
+
is_quantized = result.is_quantized
|
| 54 |
+
qparams: Dict[str, Any] = {}
|
| 55 |
+
if is_quantized:
|
| 56 |
+
qscheme = result.qscheme()
|
| 57 |
+
qparams["qscheme"] = qscheme
|
| 58 |
+
if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
|
| 59 |
+
qparams["scale"] = result.q_scale() # type: ignore[assignment]
|
| 60 |
+
qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment]
|
| 61 |
+
elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
|
| 62 |
+
# In this branch, scale and zero_point are expected to be tensors,
|
| 63 |
+
# we store the values as immutable_list in TensorMetadata for
|
| 64 |
+
# easier serialization downstream
|
| 65 |
+
qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment]
|
| 66 |
+
qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment]
|
| 67 |
+
qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment]
|
| 68 |
+
|
| 69 |
+
return TensorMetadata(
|
| 70 |
+
shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
|
| 71 |
+
|
| 72 |
+
@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and
    record the shape and type of the result
    into the corresponding node.

    Example:
         In this example, we record the shape
         and data type of a module given
         an example input ``torch.randn(50, D_in)``.
         We print the name, shape and dtype of each node.

         class TwoLayerNet(torch.nn.Module):
             def __init__(self, D_in, H, D_out):
                 super().__init__()
                 self.linear1 = torch.nn.Linear(D_in, H)
                 self.linear2 = torch.nn.Linear(H, D_out)
             def forward(self, x):
                 h_relu = self.linear1(x).clamp(min=0)
                 y_pred = self.linear2(h_relu)
                 return y_pred
         N, D_in, H, D_out = 64, 1000, 100, 10
         x = torch.randn(N, D_in)
         y = torch.randn(N, D_out)
         model = TwoLayerNet(D_in, H, D_out)
         gm = torch.fx.symbolic_trace(model)
         sample_input = torch.randn(50, D_in)
         ShapeProp(gm).propagate(sample_input)

         for node in gm.graph.nodes:
             print(node.name, node.meta['tensor_meta'].dtype,
                 node.meta['tensor_meta'].shape)

         The output of this code is:

         x torch.float32 torch.Size([50, 1000])
         linear1 torch.float32 torch.Size([50, 100])
         clamp_1 torch.float32 torch.Size([50, 100])
         linear2 torch.float32 torch.Size([50, 10])
         output torch.float32 torch.Size([50, 10])

    Args:
         module (GraphModule): The module to be executed
         fake_mode (FakeTensorMode): A fake mode for copying the gm

    """
    def __init__(self, gm, fake_mode=None):
        super().__init__(gm)
        if fake_mode is None:
            fake_mode = detect_fake_mode()
        if fake_mode is not None:
            from torch._dynamo.utils import deepcopy_to_fake_tensor
            # Note:
            # We need fake execution cause the inputs are fake, however, we cannot fakify the module
            # - because we need to write to the tensor_meta of the real module. So we fakify to
            # produce a result (L131 below), to extract tensor meta, and then keep going.
            #
            # If we were to fakify, we would write to the wrong node, and then downstream fusion
            # would be missing the tensor_meta.
            #
            # See torch/_inductor/overrides.py for where this is called upstream of fusion.
            self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
            self.fake_mode = fake_mode
        else:
            self.fake_module = None
            self.fake_mode = None

        # Keep a handle to the real module so we can restore it after each
        # node execution (see the swap in `run_node`).
        self.real_module = self.module

    def run_node(self, n : Node) -> Any:
        """Execute node `n` (under the fake mode when one is active), then
        record the result's tensor metadata and Python type into `n.meta`."""
        try:
            if self.fake_module is not None:
                # Hacky swap. Alternatively, we could do this with overriding
                # call_module and get_attr.
                self.module = self.fake_module
            try:
                if self.fake_mode is not None:
                    with self.fake_mode, enable_python_dispatcher():
                        result = super().run_node(n)
                else:
                    result = super().run_node(n)
            finally:
                # Always restore the real module, even when execution raised,
                # so subsequent nodes (and the caller) see the original.
                self.module = self.real_module
        except Exception as e:
            traceback.print_exc()
            raise RuntimeError(
                f"ShapeProp error for: node={n.format_node()} with "
                f"meta={n.meta}"
            ) from e

        found_tensor = False

        def extract_tensor_meta(obj):
            # Replace tensors in the result structure with their metadata;
            # leave every other value untouched.
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return _extract_tensor_metadata(obj)
            else:
                return obj

        # Walk the (possibly nested) result; only attach 'tensor_meta' when
        # at least one tensor was actually present in the result.
        meta = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['tensor_meta'] = meta

        n.meta['type'] = type(result)
        return result

    def propagate(self, *args):
        """
        Run `module` via interpretation and return the result and
        record the shape and type of each node.

        Args:
            *args (Tensor): the sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        if self.fake_mode is not None:
            # Convert real sample inputs into fake tensors so execution under
            # the fake mode does not compute on real data.
            fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
        else:
            fake_args = args
        return super().run(*fake_args)
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_utils.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Dict, List, Optional, Tuple, Type, Union
|
| 4 |
+
|
| 5 |
+
import torch.fx
|
| 6 |
+
from torch.fx._compatibility import compatibility
|
| 7 |
+
from torch.fx.graph import map_arg
|
| 8 |
+
from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module
|
| 9 |
+
|
| 10 |
+
from .tools_common import NodeList
|
| 11 |
+
|
| 12 |
+
__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
    """Resolve the dotted attribute path `name` starting from `obj`.

    Returns the attribute at the end of the path, or ``None`` as soon as
    any segment of the path is missing.
    """
    current = obj
    for attr_name in name.split("."):
        try:
            current = getattr(current, attr_name)
        except AttributeError:
            return None
    return current
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@compatibility(is_backward_compatible=False)
def setattr_recursive(obj, attr, value):
    """Set `value` at the dotted attribute path `attr` under `obj`.

    Every intermediate attribute along the path must already exist; only
    the final (leaf) attribute is created or overwritten.
    """
    parent_path, _, leaf = attr.rpartition(".")
    target = obj
    if parent_path:
        for part in parent_path.split("."):
            target = getattr(target, part)
    setattr(target, leaf, value)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@compatibility(is_backward_compatible=False)
@dataclass
class Component:
    """
    A component serves as a container for a subgraph we want to create afterwards.

    One Component is created per tag in `split_by_tags`; nodes carrying that
    tag are copied into the component's `graph`, and the component is later
    lifted into its own GraphModule (stored in `gm`).
    """

    # The (initially empty) graph that tagged nodes are copied into.
    graph: torch.fx.Graph
    # Position of this component in the tag order; a node's component must
    # not come before the components of any of its inputs.
    order: int
    # Name of the component; becomes the submodule name after the split.
    name: str

    # Stores the placeholder nodes in `graph`.
    input_placeholders: List = field(default_factory=list)

    # Store the nodes in original graph that are placeholder in `graph`.
    orig_inputs: List = field(default_factory=list)

    # Store the nodes in original graph that are outputs in `graph`.
    orig_outputs: List = field(default_factory=list)

    # Mapping from get_attr node in original graph to get_attr node in `graph`.
    getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
    # NOTE(review): not populated anywhere in this module — presumably kept
    # for external callers; verify before relying on it.
    constructor_args: List[str] = field(default_factory=list)
    # The GraphModule created from `graph` once the component is finalized.
    gm: Optional[torch.fx.GraphModule] = None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@compatibility(is_backward_compatible=False)
def split_by_tags(
    gm: torch.fx.GraphModule,
    tags: List[str],
    return_fqn_mapping: bool = False,
    return_tuple: bool = False,
    GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule,
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
    """
    Splits a GraphModule using tags on its graph nodes. We honor the order of
    tags. For example, we have tags = ["a", "b", "c"], the function will create
    the initial submodules in the order of "a", "b", "c".

    To set a tag:
    gm.graph.nodes[idx].tag = "mytag"

    This will result in all nodes with the same tag being extracted and placed in their
    own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder
    and output nodes are created when needed while get_attr nodes get copied to submodules
    where they are used.

    Given the following module def:

    class SimpleModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(...)
            self.linear2 = torch.nn.Linear(...)
            self.linear3 = torch.nn.Linear(...)

        def forward(self, in1, in2):
            r1 = self.linear1(in1)
            r2 = self.linear2(in2)
            r3 = torch.cat([r1, r2])
            return self.linear3(r3)

    Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:

    ro:
    def forward(self, in1):
        self = self.root
        linear1 = self.linear1(in1)
        return linear1

    main:
    def forward(self, in2, linear1):
        self = self.root
        linear2 = self.linear2(in2)
        cat_1 = torch.cat([linear1, linear2])
        linear3 = self.linear3(cat_1)
        return linear3

    main:
    def forward(self, in1, in2):
        self = self.root
        ro_0 = self.ro_0(in1)
        main_1 = self.main_1(in2, ro_0)
        return main_1

    Returns:
        split_gm: torch fx graph after split
        orig_to_split_fqn_mapping: a map between the original fqn and the fqn
            after split for call_module and get_attr.
    """

    def flatten(x: torch.fx.node.Argument) -> NodeList:
        """
        Stores nodes in x to a list and returns the list.
        """
        r: NodeList = []
        map_arg(x, r.append)
        return r

    # Mapping from node in original module to node in created submodule.
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Mapping from node in original module or created submodules to
    # corresponding component.
    node_to_component: Dict[torch.fx.Node, Component] = {}

    # Mapping from tag to the corresponding component.
    tag_to_component: Dict[str, Component] = {}

    # Stores all components.
    all_components: List[Component] = []

    # Stores nodes that will be used in main graph.
    # (Used as an ordered set — dict preserves insertion order.)
    used_in_main: Dict[torch.fx.Node, None] = {}

    # Main graph after split.
    main_g = torch.fx.Graph()

    # Mapping from node in original module to node in main graph after split.
    main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Output node of original module.
    output_node: Optional[torch.fx.Node] = None

    # Create a component for each tag, we don't expect to create other components afterwards.
    for tag in tags:
        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
        all_components.append(comp)
        tag_to_component[tag] = comp

    # Traverse the nodes in original graph and take care of them.
    for node in gm.graph.nodes:
        if node.op == "output":
            if output_node is not None:
                raise RuntimeError("Multiple output nodes in graph!")
            output_node = node
            continue

        # Placeholders in the original graph get copied to main graph.
        if node.op == "placeholder":
            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
            main_remapping[node].meta = copy.copy(node.meta)
            continue

        # Get_attr nodes are ignored because we are not tagging them.
        # Instead, we copy them directly to the submodules use them afterwards.
        if node.op == "get_attr":
            continue

        # Now we process callable nodes which are nodes with op of call_module,
        # call_function or call_method. Every callable nodes should be tagged.
        assert hasattr(node, "tag")

        upstream_components = [
            node_to_component[x]
            for x in flatten(node.args) + flatten(node.kwargs)
            if x.op not in {"placeholder", "get_attr"}
        ]

        comp = tag_to_component[node.tag]
        node_to_component[node] = comp

        # Max order of upstream components.
        mx = max((c.order for c in upstream_components), default=0)

        # Expect the component for `node` to have an order no lower than its
        # upstream components' — i.e. tags must respect data-flow order.
        assert comp.order >= mx

        # Map a input of `node` to nodes in the component's graph.
        # NOTE: the closure reads `comp` from the enclosing loop iteration; it
        # is only invoked by the node_copy call immediately below, so the
        # late binding is safe.
        def remap_func(x):
            # If input is a get_attr node, copy it to current component's graph.
            # Returns the get_attr node in current component's graph.
            if x.op == "get_attr":
                if x not in comp.getattr_maps:
                    comp.getattr_maps[x] = comp.graph.get_attr(
                        x.target, type_expr=x.type
                    )
                return comp.getattr_maps[x]

            # If input is not a placeholder, it should have been put into a component
            # already. If it's the current component then we return the corresponding
            # node in the component.
            if x.op != "placeholder" and node_to_component[x] == comp:
                return node_remapping[x]

            # If input is a placeholder or it's in other components, we want to make it
            # as a placeholder in current component's graph.
            if x not in comp.orig_inputs:
                comp.orig_inputs.append(x)
                placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
                placeholder.meta = copy.copy(x.meta)
                comp.input_placeholders.append(placeholder)
                used_in_main[x] = None

            return comp.input_placeholders[comp.orig_inputs.index(x)]

        n = comp.graph.node_copy(node, remap_func)
        n.tag = node.tag  # type: ignore[attr-defined]
        node_remapping[node] = n
        node_to_component[n] = comp

    if output_node is None:
        raise RuntimeError("Graph had no output node!")

    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            # We don't need components mapping for nodes of type "get_attr"
            # that are consumed by the output. Only need to make sure we create
            # corresponding counterparts in the resulting graph.
            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
        else:
            # All component results consumed by the output node should be
            # marked as "used in main".
            used_in_main[x] = None

    # If a node is used in main graph then we mark it as an output in the component
    # it belongs to.
    for n in used_in_main:
        if n.op != "placeholder":
            node_to_component[n].orig_outputs.append(n)

    # Now we create a graphmodule for each component.
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for comp in all_components:
        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

        if return_tuple:
            comp.graph.output(outs)
        else:
            # Take care of the args of FX output node. If there's a single
            # output then the output node args is like (output_single), else
            # if there're multiple outputs then the output node args is like
            # ((output_0, output_1, ...)).
            comp.graph.output(outs[0] if len(outs) == 1 else outs)

        comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
            gm, subgraph=comp.graph, comp_name=comp.name
        )
        orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)

        # Create a call_module node in main graph.
        main_node = main_g.call_module(
            comp.name,
            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
            kwargs=None,
        )

        if len(outs) == 1 and not return_tuple:
            main_remapping[comp.orig_outputs[0]] = main_node
        else:
            for i, o in enumerate(comp.orig_outputs):
                # Use Proxy to record getitem access.
                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]

    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
    main_root = HolderModule({comp.name: comp.gm for comp in all_components})
    # Preserve the original graph's codegen so the split module prints/compiles
    # with the same conventions as the input module.
    main_g._codegen = gm.graph._codegen

    # If the output nodes consumes get_attr directly in the original graph,
    # then we need to make sure get_attr is copied to the new graph.
    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            setattr(main_root, x.name, getattr_recursive(gm, x.target))  # type: ignore[arg-type]

    result_gm = GraphModuleCls(main_root, main_g)
    if return_fqn_mapping:
        return result_gm, orig_to_split_fqn_mapping

    return result_gm
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (350 Bytes). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-311.pyc
ADDED
|
Binary file (6.99 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/common.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Tuple
|
| 2 |
+
|
| 3 |
+
from torch.fx._compatibility import compatibility
|
| 4 |
+
from torch.fx.graph import Graph
|
| 5 |
+
|
| 6 |
+
from torch.fx.graph_module import GraphModule
|
| 7 |
+
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
|
| 8 |
+
from torch.nn import Module
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@compatibility(is_backward_compatible=False)
class HolderModule(Module):
    """
    A bare ``nn.Module`` that registers the modules handed to it.

    Used when splitting a graph module: attributes from the original module
    are re-homed onto HolderModule instances so that split-out submodules
    can still resolve their targets.
    """

    def __init__(self, d):
        super().__init__()
        for child_name, child_module in d.items():
            self.add_module(child_name, child_module)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@compatibility(is_backward_compatible=False)
def lift_subgraph_as_module(
    gm: GraphModule,
    subgraph: Graph,
    comp_name: str = "",
    class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]:
    """
    Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.

    Args:
        gm (GraphModule): parent graph module

        subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph

        comp_name (str): name for the new component

        class_name (str): name for the submodule

    Returns:
        The new GraphModule and a mapping from each lifted target's original
        fully-qualified name to its name inside the new component.
    """

    # For every call_module / get_attr target referenced by the subgraph,
    # rebuild the dotted path with HolderModules. E.g. for a get_attr target
    # "conv.weight" we create root -> HolderModule "conv", then point its
    # "weight" attribute at the original module's conv.weight.
    new_root = HolderModule({})
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for node in subgraph.nodes:
        if node.op not in ("call_module", "get_attr"):
            continue

        target = node.target
        assert isinstance(target, str)
        *parent_names, leaf_name = target.split(".")

        # Walk the path in lockstep: `dest` descends (and builds) the new
        # holder hierarchy, `src` descends the original module.
        dest = new_root
        src = gm
        for part in parent_names:
            if not hasattr(dest, part):
                dest.add_module(part, HolderModule({}))
            dest = getattr(dest, part)
            src = getattr(src, part)

        orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}"
        # Relies on Module's custom __setattr__ to register the leaf properly.
        setattr(dest, leaf_name, getattr(src, leaf_name))

    return GraphModule(new_root, subgraph, class_name), orig_to_split_fqn_mapping
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@compatibility(is_backward_compatible=False)
def compare_graphs(left: Graph, right: Graph) -> bool:
    """
    Return True if two graphs are identical, i.e they
    - have the same number of outputs in the same order
    - have the same number of inputs in the same order
    - have the same set of nodes, and identical connectivity
    """
    # Treat `left` as an exact pattern (outputs and placeholders must match
    # too) and report whether it occurs in `right`.
    exact_matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True)
    return len(exact_matcher.match(right)) > 0
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_utils.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import copy
|
| 4 |
+
import torch
|
| 5 |
+
from torch.fx import (
|
| 6 |
+
Node,
|
| 7 |
+
Graph,
|
| 8 |
+
)
|
| 9 |
+
from torch.fx._compatibility import compatibility
|
| 10 |
+
from typing import Dict, List, Set, Any, Union, Tuple
|
| 11 |
+
import logging
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
__all__ = ['SubgraphMatcher', 'InternalMatch']
|
| 15 |
+
|
| 16 |
+
# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
|
| 17 |
+
def _init_logger():
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
|
| 21 |
+
logger.setLevel(level)
|
| 22 |
+
console = logging.StreamHandler()
|
| 23 |
+
formatter = logging.Formatter("%(filename)s > %(message)s")
|
| 24 |
+
console.setFormatter(formatter)
|
| 25 |
+
console.setLevel(level)
|
| 26 |
+
# add the handlers to the logger
|
| 27 |
+
logger.addHandler(console)
|
| 28 |
+
logger.propagate = False
|
| 29 |
+
return logger
|
| 30 |
+
|
| 31 |
+
logger = _init_logger()
|
| 32 |
+
|
| 33 |
+
@compatibility(is_backward_compatible=False)
@dataclass
class InternalMatch:
    """One match of a pattern graph inside a target graph.

    Produced by `SubgraphMatcher.match`; all node references point into the
    target graph except the keys of `nodes_map`, which are pattern nodes.
    """

    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)

    # nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)

    # nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)

    # map from a string name to a node in the target graph
    # only available if the matcher is `SubgraphMatcherWithNameNodesMap`
    name_node_map: Dict[str, Node] = field(default_factory=dict)

    def __copy__(self):
        # Shallow-copy every container field so mutating the copy's maps or
        # lists does not affect the original. Bug fix: `name_node_map` was
        # previously omitted here, so copying a match silently discarded the
        # named-node mapping.
        return InternalMatch(anchors=self.anchors, nodes_map=self.nodes_map.copy(),
                             placeholder_nodes=self.placeholder_nodes.copy(),
                             returning_nodes=self.returning_nodes.copy(),
                             name_node_map=self.name_node_map.copy())
|
| 55 |
+
|
| 56 |
+
@compatibility(is_backward_compatible=False)
|
| 57 |
+
class SubgraphMatcher:
|
| 58 |
+
def __init__(self, pattern: Graph,
|
| 59 |
+
match_output: bool = False,
|
| 60 |
+
match_placeholder: bool = False,
|
| 61 |
+
remove_overlapping_matches: bool = True,
|
| 62 |
+
ignore_literals: bool = False) -> None:
|
| 63 |
+
"""
|
| 64 |
+
Args:
|
| 65 |
+
pattern: the targeted matching pattern, represented in fx.Graph.
|
| 66 |
+
match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern.
|
| 67 |
+
If False, output node is ignored during match.
|
| 68 |
+
match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of
|
| 69 |
+
the targeted pattern. If False, placeholder nodes will be used a wildcard.
|
| 70 |
+
remove_overlapping_matches: If True, in the case of overlapping matches, only the first match
|
| 71 |
+
will be returned.
|
| 72 |
+
ignore_literals: If True, will not check if literals are equal and
|
| 73 |
+
will instead treat them as wildcards.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
self.pattern = pattern
|
| 77 |
+
self.match_output = match_output
|
| 78 |
+
self.match_placeholder = match_placeholder
|
| 79 |
+
self.remove_overlapping_matches = remove_overlapping_matches
|
| 80 |
+
self.ignore_literals = ignore_literals
|
| 81 |
+
|
| 82 |
+
if len(pattern.nodes) == 0:
|
| 83 |
+
raise ValueError("SubgraphMatcher cannot be initialized with an empty pattern")
|
| 84 |
+
|
| 85 |
+
for node in pattern.nodes:
|
| 86 |
+
if node.op != "output":
|
| 87 |
+
assert len(node.users) > 0, \
|
| 88 |
+
"SubgraphMatcher cannot be initialized with an pattern with dead code"
|
| 89 |
+
|
| 90 |
+
# TODO: assert pattern is a connected graph
|
| 91 |
+
|
| 92 |
+
self.pattern_placeholder_nodes = [n for n in pattern.nodes if n.op == "placeholder"]
|
| 93 |
+
output_node = next(iter(reversed(pattern.nodes)))
|
| 94 |
+
# nodes returned by outputs
|
| 95 |
+
self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes
|
| 96 |
+
|
| 97 |
+
self.pattern_anchors: List[Node] = []
|
| 98 |
+
if match_output:
|
| 99 |
+
self.pattern_anchors = [output_node]
|
| 100 |
+
else:
|
| 101 |
+
# If a node has output_node as the ONLY user, then this node is a graph sink,
|
| 102 |
+
# and should be matched against as an anchor
|
| 103 |
+
self.pattern_anchors = [n for n in output_node.all_input_nodes if len(n.users) == 1]
|
| 104 |
+
|
| 105 |
+
def _match_attributes(self, pn: Node, gn: Node) -> bool:
|
| 106 |
+
# Attributes matching is complicated. Right now we only support matching constant tensor
|
| 107 |
+
assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string."
|
| 108 |
+
assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string."
|
| 109 |
+
|
| 110 |
+
# TODO(tmanlaibaatar) should probably make this actual API
|
| 111 |
+
def _getattr(model: torch.fx.GraphModule, attr_name: str):
|
| 112 |
+
*prefix, field = attr_name.split(".")
|
| 113 |
+
t = model
|
| 114 |
+
for item in prefix:
|
| 115 |
+
t = getattr(t, item, None) # type: ignore[assignment]
|
| 116 |
+
assert t is not None
|
| 117 |
+
|
| 118 |
+
return getattr(t, field)
|
| 119 |
+
|
| 120 |
+
pn_value = _getattr(pn.graph.owning_module, pn.target)
|
| 121 |
+
gn_value = _getattr(gn.graph.owning_module, gn.target)
|
| 122 |
+
|
| 123 |
+
if type(pn_value) != type(gn_value):
|
| 124 |
+
return False
|
| 125 |
+
|
| 126 |
+
# Don't require exact match on tensor values.
|
| 127 |
+
if isinstance(pn_value, torch.Tensor):
|
| 128 |
+
return isinstance(gn_value, torch.Tensor)
|
| 129 |
+
else:
|
| 130 |
+
raise RuntimeError(f"Unsupported type {pn_value} when matching attributes")
|
| 131 |
+
return False
|
| 132 |
+
|
| 133 |
+
def _nodes_are_equal(self, pn: Node, gn: Node) -> bool:
|
| 134 |
+
# if exact match for placeholder is not required, then use placeholder as a wildcard
|
| 135 |
+
if not self.match_placeholder and pn.op == "placeholder":
|
| 136 |
+
return True
|
| 137 |
+
|
| 138 |
+
if pn.op == gn.op:
|
| 139 |
+
if pn.op == "placeholder" or pn.op == "output":
|
| 140 |
+
return True
|
| 141 |
+
elif pn.op == "get_attr":
|
| 142 |
+
return self._match_attributes(pn, gn)
|
| 143 |
+
return pn.target == gn.target
|
| 144 |
+
return False
|
| 145 |
+
|
| 146 |
+
def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool:
|
| 147 |
+
# `lookup` represents all the nodes in `original_graph`
|
| 148 |
+
# that are part of `pattern`
|
| 149 |
+
|
| 150 |
+
# Placeholders can be used by other nodes in the graphs
|
| 151 |
+
lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items() if pn.op != "placeholder"}
|
| 152 |
+
|
| 153 |
+
for gn, pn in lookup.items():
|
| 154 |
+
# nodes returned by output are allowed to be used in other areas of the graph
|
| 155 |
+
if pn in self.pattern_returning_nodes:
|
| 156 |
+
continue
|
| 157 |
+
|
| 158 |
+
for user in gn.users:
|
| 159 |
+
# If this node has users that were not in `lookup`, then it must leak out of the
|
| 160 |
+
# pattern subgraph
|
| 161 |
+
if user not in lookup:
|
| 162 |
+
return False
|
| 163 |
+
return True
|
| 164 |
+
|
| 165 |
+
def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[InternalMatch]:
|
| 166 |
+
non_overlapping_matches: List[InternalMatch] = list()
|
| 167 |
+
nodes_matched: Set[Node] = set()
|
| 168 |
+
|
| 169 |
+
for match in matches:
|
| 170 |
+
found_overlap = False
|
| 171 |
+
for pn, gn in match.nodes_map.items():
|
| 172 |
+
if pn.op not in {"placeholder", "output"} and gn in nodes_matched:
|
| 173 |
+
found_overlap = True
|
| 174 |
+
break
|
| 175 |
+
|
| 176 |
+
if not found_overlap:
|
| 177 |
+
non_overlapping_matches.append(match)
|
| 178 |
+
for pn, gn in match.nodes_map.items():
|
| 179 |
+
if pn.op not in {"placeholder", "output"}:
|
| 180 |
+
nodes_matched.add(gn)
|
| 181 |
+
return non_overlapping_matches
|
| 182 |
+
|
| 183 |
+
def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
|
| 184 |
+
assert not (isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node"
|
| 185 |
+
|
| 186 |
+
if isinstance(pn, Node) and not isinstance(gn, Node):
|
| 187 |
+
if pn.op == "placeholder":
|
| 188 |
+
# Check if we've already matched these nodes in the current
|
| 189 |
+
# traversal
|
| 190 |
+
if pn in match.nodes_map:
|
| 191 |
+
return match.nodes_map[pn] == gn
|
| 192 |
+
|
| 193 |
+
match.nodes_map[pn] = gn
|
| 194 |
+
return True
|
| 195 |
+
else:
|
| 196 |
+
return False
|
| 197 |
+
elif not isinstance(pn, Node) and isinstance(gn, Node):
|
| 198 |
+
return False
|
| 199 |
+
else:
|
| 200 |
+
return type(gn) == type(pn) and gn == pn
|
| 201 |
+
|
| 202 |
+
def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
|
| 203 |
+
logger.info(" matching %s to %s", pn, gn)
|
| 204 |
+
|
| 205 |
+
assert isinstance(pn, Node) and isinstance(gn, Node), str(f"pn and gn must be Node, pn: {pn}, gn: {gn}")
|
| 206 |
+
|
| 207 |
+
# Check if we've already matched these nodes in the current
|
| 208 |
+
# traversal
|
| 209 |
+
if pn in match.nodes_map:
|
| 210 |
+
return match.nodes_map[pn] == gn
|
| 211 |
+
|
| 212 |
+
# TODO: use a more efficient way to check if gn is matched before: two-way dict
|
| 213 |
+
if gn in match.nodes_map.values():
|
| 214 |
+
return False
|
| 215 |
+
|
| 216 |
+
if not self._nodes_are_equal(pn, gn):
|
| 217 |
+
return False
|
| 218 |
+
|
| 219 |
+
# Optimistically mark `pn` as a match for `gn`, and save a local copy of match
|
| 220 |
+
saved_match = copy.copy(match)
|
| 221 |
+
match.nodes_map[pn] = gn
|
| 222 |
+
|
| 223 |
+
# Placeholder is a wildcard and can be matched with any python object
|
| 224 |
+
# (including list/tuple)
|
| 225 |
+
if pn.op == "placeholder":
|
| 226 |
+
return True
|
| 227 |
+
|
| 228 |
+
# Recursively traverse upwards to check if `pn` is a true
|
| 229 |
+
# match for `gn`
|
| 230 |
+
match_found = True
|
| 231 |
+
|
| 232 |
+
def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool:
|
| 233 |
+
if len(args1) != len(args2):
|
| 234 |
+
return False
|
| 235 |
+
|
| 236 |
+
for a1, a2 in zip(args1, args2):
|
| 237 |
+
if isinstance(a1, Node) and isinstance(a2, Node):
|
| 238 |
+
matched = self._match_nodes(a1, a2, match)
|
| 239 |
+
elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)):
|
| 240 |
+
matched = _match_args(a1, a2)
|
| 241 |
+
else:
|
| 242 |
+
matched = self._match_literals(a1, a2, match) or self.ignore_literals
|
| 243 |
+
|
| 244 |
+
if not matched:
|
| 245 |
+
return False
|
| 246 |
+
|
| 247 |
+
return True
|
| 248 |
+
|
| 249 |
+
# Flatten all args/kwargs into 1 list of args
|
| 250 |
+
pn_args, gn_args = None, None
|
| 251 |
+
if (
|
| 252 |
+
(len(pn.args) != len(gn.args) or list(pn.kwargs.keys()) != list(gn.kwargs.keys())) and
|
| 253 |
+
pn.op == "call_function" and
|
| 254 |
+
isinstance(pn.target, torch._ops.OpOverload)
|
| 255 |
+
):
|
| 256 |
+
args_schema = pn.target._schema.arguments
|
| 257 |
+
|
| 258 |
+
def get_all_arguments(orig_args, orig_kwargs):
|
| 259 |
+
all_args = []
|
| 260 |
+
for i, schema in enumerate(args_schema):
|
| 261 |
+
if schema.name in orig_kwargs:
|
| 262 |
+
all_args.append(orig_kwargs[schema.name])
|
| 263 |
+
elif not schema.kwarg_only and i < len(orig_args):
|
| 264 |
+
all_args.append(orig_args[i])
|
| 265 |
+
else:
|
| 266 |
+
all_args.append(schema.default_value)
|
| 267 |
+
return all_args
|
| 268 |
+
|
| 269 |
+
pn_args = get_all_arguments(pn.args, pn.kwargs)
|
| 270 |
+
gn_args = get_all_arguments(gn.args, gn.kwargs)
|
| 271 |
+
|
| 272 |
+
elif len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list(gn.kwargs.keys()):
|
| 273 |
+
pn_args = list(pn.args)
|
| 274 |
+
gn_args = list(gn.args)
|
| 275 |
+
pn_args.extend(list(pn.kwargs.values()))
|
| 276 |
+
gn_args.extend(list(gn.kwargs.values()))
|
| 277 |
+
else:
|
| 278 |
+
match_found = False
|
| 279 |
+
|
| 280 |
+
match_found = (
|
| 281 |
+
match_found and
|
| 282 |
+
pn_args is not None and
|
| 283 |
+
gn_args is not None and
|
| 284 |
+
_match_args(pn_args, gn_args)
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
if not match_found:
|
| 288 |
+
# revert to saved_match before matching with current node
|
| 289 |
+
match = copy.copy(saved_match)
|
| 290 |
+
return False
|
| 291 |
+
|
| 292 |
+
return True
|
| 293 |
+
|
| 294 |
+
def match(self, graph: Graph) -> List[InternalMatch]:
|
| 295 |
+
"""
|
| 296 |
+
Returns:
|
| 297 |
+
The matched subgraphs.
|
| 298 |
+
Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder
|
| 299 |
+
and nodes returned by output) can only be consumed by nodes within the matched subgraph.
|
| 300 |
+
|
| 301 |
+
Subgraph pattern matcher is implemented with the backtracking style in the following steps:
|
| 302 |
+
|
| 303 |
+
1. We first identify all the anchor nodes in the pattern graph. The anchor nodes
|
| 304 |
+
are the "sinks" (nodes with no user other than the output node) of the pattern graph.
|
| 305 |
+
One pattern graph could have multiple anchors if it has multiple return values.
|
| 306 |
+
|
| 307 |
+
2. In the target graph, we identify the potential candidate nodes that can be matched
|
| 308 |
+
with each anchor. These anchor-candidate pairs are the starting points for
|
| 309 |
+
pairwise per-node matching.
|
| 310 |
+
|
| 311 |
+
3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both
|
| 312 |
+
pattern and target graphs. For every pattern nodes along traversal path, we compare it
|
| 313 |
+
against the target nodes. In case any comparison failed, the match for this anchor-candidate
|
| 314 |
+
pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes`
|
| 315 |
+
for more details.
|
| 316 |
+
|
| 317 |
+
4. In the case of multiple anchors, every anchor will need to find a match using step 3.
|
| 318 |
+
In addition, the matches found between anchors need to have a common intersection node
|
| 319 |
+
in order for the match to be valid. This is implemented with backtracking. See `backtracking`
|
| 320 |
+
for more details.
|
| 321 |
+
|
| 322 |
+
Notice: graph traversal must be done in the reverser order because a tensor can have multiple
|
| 323 |
+
consumers, but can only have a single producer. Only with reverser order, we can we jointly
|
| 324 |
+
traverse the pattern and target graph in a deterministic path.
|
| 325 |
+
|
| 326 |
+
Warning: In theory, this backtracking algorithm have an **exponential** time complexity. However,
|
| 327 |
+
in practice, it's unlikely to blow up.
|
| 328 |
+
|
| 329 |
+
"""
|
| 330 |
+
from torch.fx.passes.utils.fuser_utils import validate_partition
|
| 331 |
+
|
| 332 |
+
# find candidate nodes to match with pattern anchors
|
| 333 |
+
match_candidates: Dict[Node, List[Node]] = defaultdict(list)
|
| 334 |
+
for pattern_anchor in self.pattern_anchors:
|
| 335 |
+
for node in graph.nodes:
|
| 336 |
+
if self._nodes_are_equal(pattern_anchor, node):
|
| 337 |
+
match_candidates[pattern_anchor].append(node)
|
| 338 |
+
match_candidates_list = list(match_candidates.items())
|
| 339 |
+
|
| 340 |
+
logger.info("Initial match_candidates_list: %s\n", match_candidates_list)
|
| 341 |
+
|
| 342 |
+
matches: List[InternalMatch] = []
|
| 343 |
+
|
| 344 |
+
def backtracking(anchor_index, match):
|
| 345 |
+
if anchor_index == len(match_candidates_list):
|
| 346 |
+
match.placeholder_nodes = [match.nodes_map[pn] for pn in self.pattern_placeholder_nodes]
|
| 347 |
+
match.returning_nodes = [match.nodes_map[pn] for pn in self.pattern_returning_nodes]
|
| 348 |
+
matches.append(match)
|
| 349 |
+
|
| 350 |
+
logger.info("Found a match: %s\n", match)
|
| 351 |
+
return
|
| 352 |
+
|
| 353 |
+
pattern_anchor, candidate_nodes = match_candidates_list[anchor_index]
|
| 354 |
+
saved_match = copy.copy(match)
|
| 355 |
+
|
| 356 |
+
for node in candidate_nodes:
|
| 357 |
+
logger.info("Trying to match anchor %s to %s", pattern_anchor, node)
|
| 358 |
+
|
| 359 |
+
match_found = self._match_nodes(pattern_anchor, node, match)
|
| 360 |
+
if match_found:
|
| 361 |
+
# match next anchor
|
| 362 |
+
backtracking(anchor_index + 1, match)
|
| 363 |
+
else:
|
| 364 |
+
logger.info("Failed to match anchor %s to %s\n", pattern_anchor, node)
|
| 365 |
+
|
| 366 |
+
# revert to saved_match before matching with current anchor
|
| 367 |
+
match = copy.copy(saved_match)
|
| 368 |
+
|
| 369 |
+
match = InternalMatch(anchors=self.pattern_anchors)
|
| 370 |
+
if match_candidates_list:
|
| 371 |
+
backtracking(0, match)
|
| 372 |
+
|
| 373 |
+
# filter out the matches where the subgraph is not fully_contained
|
| 374 |
+
before = len(matches)
|
| 375 |
+
matches = [match for match in matches if self._is_contained(match.nodes_map)]
|
| 376 |
+
after = len(matches)
|
| 377 |
+
if before != after:
|
| 378 |
+
logger.info("Filtered out %s matches because they are not fully contained", before - after)
|
| 379 |
+
|
| 380 |
+
# filter out the matches that form a cycle if the subgraph is fused
|
| 381 |
+
valid_matches = []
|
| 382 |
+
for match in matches:
|
| 383 |
+
matched_compute_nodes = \
|
| 384 |
+
[gn for pn, gn in match.nodes_map.items() if pn.op not in {"placeholder", "output"}]
|
| 385 |
+
if validate_partition(matched_compute_nodes):
|
| 386 |
+
valid_matches.append(match)
|
| 387 |
+
if len(valid_matches) != len(matches):
|
| 388 |
+
logger.info("Filtered out %s matches because \
|
| 389 |
+
matched subgraph would form a cycle if fused", len(matches) - len(valid_matches))
|
| 390 |
+
|
| 391 |
+
if self.remove_overlapping_matches:
|
| 392 |
+
before = len(valid_matches)
|
| 393 |
+
matches = self._remove_overlapping_matches(valid_matches)
|
| 394 |
+
after = len(matches)
|
| 395 |
+
if before != after:
|
| 396 |
+
logger.info("Filtered out %s matches because matched subgraphs are overlapping", before - after)
|
| 397 |
+
|
| 398 |
+
logger.info("Matches returned: %s", matches)
|
| 399 |
+
|
| 400 |
+
return matches
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/pool.cpython-311.pyc
ADDED
|
Binary file (2.89 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-311.pyc
ADDED
|
Binary file (3.52 kB). View file
|
|
|
tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__init__.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .quantize import * # noqa: F403
|
| 2 |
+
from .observer import * # noqa: F403
|
| 3 |
+
from .qconfig import * # noqa: F403
|
| 4 |
+
from .fake_quantize import * # noqa: F403
|
| 5 |
+
from .fuse_modules import fuse_modules
|
| 6 |
+
from .stubs import * # noqa: F403
|
| 7 |
+
from .quant_type import * # noqa: F403
|
| 8 |
+
from .quantize_jit import * # noqa: F403
|
| 9 |
+
|
| 10 |
+
# from .quantize_fx import *
|
| 11 |
+
from .quantization_mappings import * # noqa: F403
|
| 12 |
+
from .fuser_method_mappings import * # noqa: F403
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def default_eval_fn(model, calib_data):
|
| 16 |
+
r"""
|
| 17 |
+
Default evaluation function takes a torch.utils.data.Dataset or a list of
|
| 18 |
+
input Tensors and run the model on the dataset
|
| 19 |
+
"""
|
| 20 |
+
for data, target in calib_data:
|
| 21 |
+
model(data)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
__all__ = [
|
| 25 |
+
"QuantWrapper",
|
| 26 |
+
"QuantStub",
|
| 27 |
+
"DeQuantStub",
|
| 28 |
+
# Top level API for eager mode quantization
|
| 29 |
+
"quantize",
|
| 30 |
+
"quantize_dynamic",
|
| 31 |
+
"quantize_qat",
|
| 32 |
+
"prepare",
|
| 33 |
+
"convert",
|
| 34 |
+
"prepare_qat",
|
| 35 |
+
# Top level API for graph mode quantization on TorchScript
|
| 36 |
+
"quantize_jit",
|
| 37 |
+
"quantize_dynamic_jit",
|
| 38 |
+
"_prepare_ondevice_dynamic_jit",
|
| 39 |
+
"_convert_ondevice_dynamic_jit",
|
| 40 |
+
"_quantize_ondevice_dynamic_jit",
|
| 41 |
+
# Top level API for graph mode quantization on GraphModule(torch.fx)
|
| 42 |
+
# 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
|
| 43 |
+
# 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
|
| 44 |
+
"QuantType", # quantization type
|
| 45 |
+
# custom module APIs
|
| 46 |
+
"get_default_static_quant_module_mappings",
|
| 47 |
+
"get_static_quant_module_class",
|
| 48 |
+
"get_default_dynamic_quant_module_mappings",
|
| 49 |
+
"get_default_qat_module_mappings",
|
| 50 |
+
"get_default_qconfig_propagation_list",
|
| 51 |
+
"get_default_compare_output_module_list",
|
| 52 |
+
"get_quantized_operator",
|
| 53 |
+
"get_fuser_method",
|
| 54 |
+
# Sub functions for `prepare` and `swap_module`
|
| 55 |
+
"propagate_qconfig_",
|
| 56 |
+
"add_quant_dequant",
|
| 57 |
+
"swap_module",
|
| 58 |
+
"default_eval_fn",
|
| 59 |
+
# Observers
|
| 60 |
+
"ObserverBase",
|
| 61 |
+
"WeightObserver",
|
| 62 |
+
"HistogramObserver",
|
| 63 |
+
"observer",
|
| 64 |
+
"default_observer",
|
| 65 |
+
"default_weight_observer",
|
| 66 |
+
"default_placeholder_observer",
|
| 67 |
+
"default_per_channel_weight_observer",
|
| 68 |
+
# FakeQuantize (for qat)
|
| 69 |
+
"default_fake_quant",
|
| 70 |
+
"default_weight_fake_quant",
|
| 71 |
+
"default_fixed_qparams_range_neg1to1_fake_quant",
|
| 72 |
+
"default_fixed_qparams_range_0to1_fake_quant",
|
| 73 |
+
"default_per_channel_weight_fake_quant",
|
| 74 |
+
"default_histogram_fake_quant",
|
| 75 |
+
# QConfig
|
| 76 |
+
"QConfig",
|
| 77 |
+
"default_qconfig",
|
| 78 |
+
"default_dynamic_qconfig",
|
| 79 |
+
"float16_dynamic_qconfig",
|
| 80 |
+
"float_qparams_weight_only_qconfig",
|
| 81 |
+
# QAT utilities
|
| 82 |
+
"default_qat_qconfig",
|
| 83 |
+
"prepare_qat",
|
| 84 |
+
"quantize_qat",
|
| 85 |
+
# module transformations
|
| 86 |
+
"fuse_modules",
|
| 87 |
+
]
|