diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/gen_example.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/gen_example.py new file mode 100644 index 0000000000000000000000000000000000000000..301cf42beb062dd5ad9763507417de57fcc6e48d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/db/gen_example.py @@ -0,0 +1,28 @@ +import os +import sys + +import torch._export.db.examples as examples + +TEMPLATE = '''import torch + +from torch._export.db.case import export_case + + +@export_case( + example_inputs=(torch.randn(3, 2),), + tags={{}}, +) +def {case_name}(x): + """ + """ + + return +''' + +if __name__ == "__main__": + assert len(sys.argv) == 2 + root_dir = examples.__name__.replace(".", "/") + assert os.path.exists(root_dir) + with open(os.path.join(root_dir, sys.argv[1] + ".py"), "w") as f: + print("Writing to", f.name, "...") + f.write(TEMPLATE.format(case_name=sys.argv[1])) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/exported_program.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/exported_program.py new file mode 100644 index 0000000000000000000000000000000000000000..5d28ea31549087f2b118e6c431d812da314e6497 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/exported_program.py @@ -0,0 +1,50 @@ +import warnings + + +import torch +import torch.fx + + +# TODO(ycao): This is added to avoid breaking existing code temporarily. +# Remove when migration is done. +from torch.export.graph_signature import ( + ExportBackwardSignature, + ExportGraphSignature, +) + +from torch.export.exported_program import ( + ExportedProgram, + ModuleCallEntry, + ModuleCallSignature, +) + + + +__all__ = [ + "ExportBackwardSignature", + "ExportGraphSignature", + "ExportedProgram", + "ModuleCallEntry", + "ModuleCallSignature", +] + + +def _create_graph_module_for_export(root, graph): + try: + gm = torch.fx.GraphModule(root, graph) + except SyntaxError: + # If custom objects stored in memory are being used in the graph, + # the generated python code will result in a syntax error on the custom + # object, since it is unable to parse the in-memory object. However + # we can still run the graph eagerly through torch.fx.Interpreter, + # so we will bypass this error. + warnings.warn( + "Unable to execute the generated python source code from " + "the graph. The graph module will no longer be directly callable, " + "but you can still run the ExportedProgram, and if needed, you can " + "run the graph module eagerly using torch.fx.Interpreter." + ) + gm = torch.fx.GraphModule(root, torch.fx.Graph()) + gm._graph = graph + + return gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf9bcefcfc6e3071c97f0dff3c9fad5f2cbdfa8 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py @@ -0,0 +1,258 @@ +import inspect +from collections import defaultdict +from typing import Any, Callable, Dict, List, Tuple, Union + +import torch +from torch._dynamo.source import ( + AttrSource, + GetItemSource, + LocalSource, + TensorProperty, + TensorPropertySource, +) +from torch._dynamo.variables.builder import TrackedFake +from torch._export.passes.add_runtime_assertions_for_constraints_pass import InputDim +from torch._guards import Source +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.export import Constraint +from torch.export.graph_signature import CustomObjArgument +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + DimDynamic, + EqualityConstraint, + ShapeEnv, + StatelessSymbolicContext, +) +from torch.utils._pytree import ( + GetAttrKey, + KeyPath, + MappingKey, + SequenceKey, + tree_map_with_path, +) + + +def key_path_to_source(kp: KeyPath) -> Source: + """ + Given a key path, return the source for the key path. + """ + source: Source = LocalSource("args") + for k in kp: + if isinstance(k, SequenceKey): + source = GetItemSource(source, k.idx) + elif isinstance(k, MappingKey): + source = GetItemSource(source, k.key) + elif isinstance(k, GetAttrKey): + source = AttrSource(source, k.name) + else: + raise ValueError(f"Unknown KeyEntry {k}") + + return source + + +def _is_constant_argument(t): + return t is None or isinstance(t, (int, float, bool, str)) + + +def fakify( + mode: FakeTensorMode, + kp: KeyPath, + t: Any, + t_constraints: Dict[int, Dict[int, Constraint]], + sources: Dict[Tuple[int, int], List[Source]], +): + source = key_path_to_source(kp) + if _is_constant_argument(t) or isinstance(t, torch.ScriptObject): + return t + if not isinstance(t, torch.Tensor): + raise ValueError(f"Unsupported input type {type(t)}") + n_dims = len(t.shape) + symbolic_context = StatelessSymbolicContext( + dynamic_sizes=[DimDynamic.STATIC] * n_dims, + constraint_sizes=[None] * n_dims, + ) + t_id = id(t) + if t_id in t_constraints: + for i, constraint in t_constraints[t_id].items(): + symbolic_context.constraint_sizes[i] = constraint.constraint_range + symbolic_context.dynamic_sizes[i] = DimDynamic.DYNAMIC + src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i) + sources[(t_id, i)].append(src) + mode.shape_env.source_name_to_debug_name[src.name()] = constraint.debug_name + fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context) + mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context)) + return fake + + +def make_fake_params_buffers( + fake_mode: FakeTensorMode, + params_buffers: Dict[str, torch.Tensor], +) -> Dict[str, Union[torch.Tensor, torch.nn.Parameter]]: + faked_params_buffers = {} + for key, value in params_buffers.items(): + faked_params_buffers[key] = fake_mode.from_tensor(value, static_shapes=True) + return faked_params_buffers + + +def make_fake_inputs(nn_module, args, kwargs, constraints): + """ + Given an nn module, example inputs, and constraints, return a new fake mode, + fake inputs created in that mode whose dynamic shape dimensions are constrained + by the given ranges, and sources for pairs of dynamic shape dimensions that are + constrained to be equal. + """ + # TODO(avik): refactor Dynamo to avoid duplication of the following code + # between non-strict and strict. + # Specifically, here (non-strict) we do the following pre-tracing steps: + # - Fakify inputs. + # - Process input shape equalities. + # In strict, these steps are spread across multiple files: + # - output_graph.py fakifies inputs. + # - [post-tracing] guards.py processes input shape equalities. + + t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict) + for constraint in constraints: + t_constraints[constraint.t_id][constraint.dim] = constraint + if constraint.shared is not None: + t_constraints[constraint.shared.t_id][constraint.shared.dim] = constraint + + code = nn_module.forward.__code__ + co_fields = { + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + } + + fake_mode = FakeTensorMode( + shape_env=ShapeEnv(tracked_fakes=[], co_fields=co_fields), + allow_non_fake_inputs=True, + ) + if fake_mode.shape_env is None or fake_mode.shape_env.tracked_fakes is None: + raise ValueError( + "Detected fake_mode does not have a shape_env with tracked fakes. " + "If you constructed the module under a FakeTensorMode, " + "please initialize it like: FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))" + ) + + with fake_mode: + original_signature = inspect.signature(nn_module.forward) + sources: Dict[Tuple[int, int], List[Source]] = defaultdict(list) + fake_args, fake_kwargs = tree_map_with_path( + lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources), + (args, kwargs), + ) + + from sympy import Symbol + + source_pairs: List[Tuple[Source, Source]] = [] + derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = [] + phantom_symbols: Dict[str, Symbol] = {} + for constraint in constraints: + torch.export.dynamic_shapes._process_equalities( + constraint, + lambda t_id, dim: sources[(t_id, dim)], + fake_mode.shape_env, + source_pairs, + derived_equalities, + phantom_symbols, + ) + + equalities_inputs = EqualityConstraint( + source_pairs=source_pairs, + derived_equalities=derived_equalities, + phantom_symbols=list(phantom_symbols.values()), + warn_only=False, + ) + return fake_mode, fake_args, fake_kwargs, equalities_inputs, original_signature + + +def make_constraints( + fake_mode, + equalities_inputs, + original_signature, + gm, +): + """ + Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions, + and a graph module, produce guards on the fake mode's shape env (raising constraint + violations if any), solve (to suggest simplifications or fixes), and return the + resulting range constraints and equality constraints. + """ + # TODO(avik): refactor Dynamo to avoid duplication of the following code + # between non-strict and strict. + # Specifically, here (non-strict) we do the following post-tracing steps: + # - Produce guards. + # - Solve constraints. + # - Install shape metadata in IR. + # In strict, these steps are spread across multiple files: + # - guards.py produces guards. + # - eval_frame.py solves constraints + # - _trace.py installs shape metadata in IR. + + shape_env = fake_mode.shape_env + placeholders = [tf.fake for tf in shape_env.tracked_fakes] + sources = [tf.source for tf in shape_env.tracked_fakes] + input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes] + constraint_violation_error = None + try: + shape_env.produce_guards( + placeholders, + sources, + input_contexts=input_contexts, + equalities_inputs=equalities_inputs, + ignore_static=False, + ) + except ConstraintViolationError as e: + constraint_violation_error = e + + shape_env.frozen = True + dim_constraints = shape_env.dim_constraints + if dim_constraints is None: + # Expected when shape_env.produce_guards throws an early constraint violation error. + # There is nothing to solve for in this case. + # TODO(avik): Maybe record the constraint violation error instead and replay later? + assert constraint_violation_error + raise constraint_violation_error + dim_constraints.solve() + dim_constraints.remove_redundant_dynamic_results() + forced_specializations = dim_constraints.forced_specializations() + msg = dim_constraints.prettify_results( + original_signature, constraint_violation_error, forced_specializations + ) + if constraint_violation_error: + constraint_violation_error.args = (constraint_violation_error.args[0] + msg,) + elif forced_specializations: + constraint_violation_error = ConstraintViolationError(msg) + if constraint_violation_error: + raise constraint_violation_error + + range_constraints = {} + input_dims = defaultdict(list) + free_symbols = set() + for node in gm.graph.nodes: + if node.op != "placeholder": + continue + if _is_constant_argument(node.meta["val"]) or isinstance( + node.meta["val"], CustomObjArgument + ): + continue + for i, d in enumerate(node.meta["val"].shape): + if isinstance(d, torch.SymInt): + # Look up the range constraint for the symbol corresponding to this shape dimension + # and store it indexed by the symbolic expression corresponding to it. + # NOTE(avik): Use node._expr instead of node.expr for the lookup here because + # we want the symbol, not its replacement, which could be an expression. Maybe + # there's a better way to do this, e.g., by (re)computing value ranges for expressions? + range_constraints[d.node.expr] = shape_env.var_to_range[d.node._expr] + input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i)) + free_symbols.update(d.node.expr.free_symbols) + + for symbol in free_symbols: + if symbol not in range_constraints: + # Placeholders can have symbolic shapes that are derived expressions. + # The above code will record direct range constraints for them + # so that we can do runtime assertions. In addition, for serde checks + # we want to record range constraints for their root symbols. + range_constraints[symbol] = shape_env.var_to_range[symbol] + + return range_constraints diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/verifier.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..65b1b5e514eca103bf8750dad214d1cd53238b22 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_export/verifier.py @@ -0,0 +1,416 @@ +import inspect +import math +import operator +from collections.abc import Iterable +from typing import Any, Dict, final, List, Optional, Tuple, Type + +import torch +from torch._ops import HigherOrderOperator, OpOverload +from torch._subclasses.fake_tensor import FakeTensor +from torch.export.exported_program import ExportedProgram +from torch.export.graph_signature import ( + CustomObjArgument, + InputKind, + SymIntArgument, + TensorArgument, +) +from torch.fx import GraphModule +from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt + + +class SpecViolationError(Exception): + pass + + +def is_functional(op: OpOverload) -> bool: + return not op._schema.is_mutable + + +def _check_has_fake_tensor(node: torch.fx.Node) -> None: + # TODO(angelayi): remove this in favor of _check_val + return _check_val(node) + + +def _check_val(node: torch.fx.Node) -> None: + def _check_correct_val(val): + if val is None: + return True + elif isinstance(val, (int, bool, str, float)): + return True + elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)): + return True + elif isinstance(val, (FakeTensor, torch.Tensor)): # TODO(zhxchen17) Remove Tensor. + return True + elif isinstance(val, (SymInt, SymFloat, SymBool)): + return True + elif isinstance(val, CustomObjArgument): + return True + elif isinstance(val, Iterable): + return all(_check_correct_val(x) for x in val) + return False + + def _no_returns(op): + if not isinstance(op, OpOverload): + return False + return len(op._schema.returns) == 0 + + if "val" not in node.meta: + if node.op == "call_function" and _no_returns(node.target): + return + raise SpecViolationError(f"Node.meta {node.name} is missing val field.") + + val = node.meta["val"] + if not _check_correct_val(val): + raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}") + + +class _VerifierMeta(type): + _registry: Dict[str, Type['Verifier']] = {} + + def __new__(metacls, name, bases, attrs): + if bases: + if "check" in attrs or "_check_graph_module" in attrs: + raise SyntaxError("Overriding method check is not allowed.") + assert "dialect" in attrs and attrs["dialect"] != "ATEN" + else: + assert "check" in attrs + assert "_check_graph_module" in attrs + assert attrs["dialect"] == "ATEN" + + assert isinstance(attrs["dialect"], str) + ret = type.__new__(metacls, name, bases, attrs) + metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment] + return ret + +def getattr_recursive(obj: Any, target: str) -> Any: + target_atoms = target.split('.') + attr_itr = obj + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +class Verifier(metaclass=_VerifierMeta): + dialect = "ATEN" + + def allowed_builtin_ops(self) -> List: + return [ + operator.getitem, + operator.add, + operator.mul, + operator.sub, + operator.truediv, + operator.ge, + operator.le, + operator.gt, + operator.lt, + operator.eq, + operator.ne, + operator.floordiv, + operator.mod, + operator.and_, + operator.or_, + operator.not_, + operator.pow, + operator.neg, + operator.abs, + math.ceil, + math.floor, + ] + + def allowed_op_types(self) -> Tuple[Type[Any], ...]: + return (OpOverload, HigherOrderOperator) + + def allowed_getattr_types(self) -> Tuple[Type[Any], ...]: + return (torch.fx.GraphModule,) + + def check_valid_op(self, op): + pass + + def check_additional(self, gm: GraphModule) -> None: + """ + Additional checks that are specific to some dialects. + """ + pass + + @final + def check(self, ep: ExportedProgram) -> None: + self._check_graph_module(ep.graph_module) + _verify_exported_program_signature(ep) + + @final + def _check_graph_module(self, gm: torch.fx.GraphModule) -> None: + def _allowed_getattr_types() -> Tuple[Type[Any], ...]: + ret = self.allowed_getattr_types() + assert not any(t is object for t in ret) + return ret + + def _check_valid_op(op) -> None: + def _allowed_builtin_ops() -> List: + ret = self.allowed_builtin_ops() + assert all(inspect.isbuiltin(op) for op in ret) + return ret + + def _allowed_op_types() -> Tuple[Type[Any], ...]: + ret = self.allowed_op_types() + assert not any(t is object for t in ret) + return ret + + # TODO Remove this allowlist. + _allowed_torch_functions = ( + torch.autograd.grad_mode.set_grad_enabled, + torch.sym_int, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_not, + torch.sym_sqrt, + # TODO (tmanlaibaatar) + # Predispatch export is able to contain autograd ops. + # These will be modeled as HOO later + torch._C._set_grad_enabled + + ) + + if not isinstance(op, _allowed_op_types()): + if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions: + raise SpecViolationError( + f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n" + f"Valid builtin ops: {_allowed_builtin_ops()}" + f"Valid torch functions: {_allowed_torch_functions}" + ) + + if isinstance(op, OpOverload): + # All ops functional + if not is_functional(op): + raise SpecViolationError( + f"operator '{op}' is not functional" + ) + self.check_valid_op(op) + + for mod in gm.modules(): + if not isinstance(mod, torch.fx.GraphModule): + continue + + mod.graph.lint() + for node in mod.graph.nodes: + # TODO(T140410192): should have fake tensor for all dialects + if node.op in {"call_module", "call_method"}: + raise SpecViolationError( + f"call_module is not valid: got a class '{node.target}' ", + ) + + elif node.op == "call_function": + _check_val(node) + + _check_valid_op(node.target) + + elif node.op == "get_attr": + if not isinstance(node.target, str): + raise SpecViolationError( + f"Expected get_attr target to be string, but got {type(node.target)}" + ) + + attr = getattr_recursive(mod, node.target) + if isinstance(attr, torch.nn.Module): + def _is_type(name, ty): + return isinstance(getattr(attr, name, None), ty) + if type(attr).__name__ == "LoweredBackendModule": + if _is_type("backend_id", str) \ + and _is_type("processed_bytes", bytes) \ + and _is_type("compile_specs", list) \ + and hasattr(attr, "original_module"): + continue + else: + backend_id = getattr(attr, "backend_id", None) + processed_bytes = getattr(attr, "processed_bytes", None) + compile_specs = getattr(attr, "compile_specs", None) + raise SpecViolationError( + f"Invalid get_attr type {type(attr)}. \n" + f"LoweredBackendModule fields: " + f"backend_id(str) : {type(backend_id)}, " + f"processed_bytes(bytes) : {type(processed_bytes)}, " + f"compile_specs(list) : {type(compile_specs)}" + ) + + if not isinstance(attr, _allowed_getattr_types()): + raise SpecViolationError( + f"Invalid get_attr type {type(attr)}. \n" + f"Valid get_attr types: {_allowed_getattr_types()}" + ) + + + elif node.op == "placeholder": + _check_val(node) + # TODO(zhxchen17) + # elif node.op == "output": + # _check_flattened_outputs() + + self.check_additional(gm) + + +def _verify_exported_program_signature(exported_program) -> None: + # Check ExportedProgram signature matches + gs = exported_program.graph_signature + + # Check every node in the signature exists in the graph + input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"] + + if len(input_node_names) != len(gs.input_specs): + raise SpecViolationError( + f"Number of graph inputs ({len(input_node_names)}) " + f"does not match number of inputs in the graph signature ({len(gs.user_inputs)})" + ) + + for input_spec, node in zip(gs.input_specs, input_node_names): + if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)): + if input_spec.arg.name != node: + raise SpecViolationError( + f"Input spec name {input_spec.arg.name} does not match node name {node}" + ) + + if input_spec.kind == InputKind.USER_INPUT: + continue + + elif input_spec.kind == InputKind.PARAMETER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + param = input_spec.target + if param not in exported_program.state_dict: + raise SpecViolationError( + f"Parameter {param} is not in the state dict." + ) + + if not isinstance(exported_program.state_dict[param], torch.nn.Parameter): + raise SpecViolationError( + f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter." + ) + + elif input_spec.kind == InputKind.BUFFER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + buffer = input_spec.target + if input_spec.persistent is None: + raise SpecViolationError( + f"Buffer {buffer} is missing a persistence flag" + ) + + if input_spec.persistent is True and buffer not in exported_program.state_dict: + raise SpecViolationError( + f"Buffer {buffer} is not in the state dict." + ) + + if input_spec.persistent is False and buffer in exported_program.state_dict: + raise SpecViolationError( + f"Non-persistent buffer {buffer} is in the state dict, it should not be." + ) + elif input_spec.kind == InputKind.CONSTANT_TENSOR: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + tensor_const = input_spec.target + if tensor_const not in exported_program.constants: + raise SpecViolationError( + f"Constant tensor {tensor_const} is not in the constants dictionary." + ) + elif input_spec.kind == InputKind.CUSTOM_OBJ: + if not isinstance(input_spec.arg, CustomObjArgument): + raise SpecViolationError( + f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + custom_obj = input_spec.target + if custom_obj not in exported_program.constants: + raise SpecViolationError( + f"Custom object {custom_obj} is not in the constants dictionary." + ) + elif input_spec.kind == InputKind.TOKEN: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + else: + raise SpecViolationError( + f"Unknown InputKind {input_spec.kind}." + ) + + # Check outputs + output_node = list(exported_program.graph.nodes)[-1] + assert output_node.op == "output" + output_nodes = [ + arg.name if isinstance(arg, torch.fx.Node) else arg + for arg in output_node.args[0] + ] + + if len(output_nodes) != len(gs.output_specs): + raise SpecViolationError( + f"Number of output nodes {len(output_nodes)} is different " + "Than the number of outputs specified by the graph signature: \n" + f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n" + f"Number of user outputs: {len(gs.user_outputs)}. \n" + ) + + num_tokens = len(gs.output_tokens) + end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens + mutate_nodes: List[str] = output_nodes[num_tokens:end] + user_output_nodes = output_nodes[end:end + len(gs.user_outputs)] + + for mutation_node in mutate_nodes: + if mutation_node in gs.buffers_to_mutate: + if gs.buffers_to_mutate[mutation_node] not in gs.buffers: + raise SpecViolationError( + f"Buffer output {mutation_node} does not point to a buffer that exists. \n" + f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" + f"Buffer nodes available: {gs.buffers} \n" + ) + elif mutation_node in gs.user_inputs_to_mutate: + if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs: + raise SpecViolationError( + f"User input output {mutation_node} does not point to a user input that exists. \n" + f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n" + f"User input nodes available: {gs.user_inputs} \n") + else: + raise SpecViolationError( + f"Mutation node {mutation_node} is neither a buffer nor a user input. " + f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}" + ) + + for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs): + if user_output_node != user_output_name: + raise SpecViolationError( + f"User output {user_output_node} is not in the correct " + "order or is not found in the " + f"exported program's user_output list: {gs.user_outputs}. " + ) + + +def load_verifier(dialect: str) -> Optional[Type[Verifier]]: + if dialect == "ATEN": + return _VerifierMeta._registry.get(dialect) + return _VerifierMeta._registry[dialect] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/closure.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/closure.py new file mode 100644 index 0000000000000000000000000000000000000000..07f1055ee82783643bf5e57c8713d90aa1d15df6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/closure.py @@ -0,0 +1,134 @@ +import os +import threading +from queue import Empty as EmptyQueue, Queue + +from torch._lazy.device_context import get_device_context + + +class ClosureHandler: + def __init__(self): + pass + + def run(self, closure): + """Run closure function + + Args: + closure: callable function to run + """ + closure() + + def __call__(self, closures): + for closure in closures: + self.run(closure) + + +class AsyncClosureHandler(ClosureHandler): + """Handler for Asynchronous Step Closures + Args: + max_queue_size: The maximum length of the closure queue after which + the training loop will block until closures are evaluated. + By default, a reasonable limit of a maximum of 100 on the queue. + This value can be set using the `XLA_MAX_ASYNC_QUEUE` environment + variable. + """ + + def __init__(self, max_queue_size=100): + super().__init__() + self._closure_queue: Queue = Queue( + int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size)) + ) + self._closure_exception: Queue = Queue() + self._closure_lock = threading.Lock() + self._closure_event_loop_finished = threading.Event() + self._closure_event_loop = None + + def start_event_loop(self): + """Start closure event loop if not started""" + if self._closure_event_loop is None: + + def event_loop(): + # Run loop until closure event is set and closure queue is empty + while True: + try: + closure = self._closure_queue.get(block=True, timeout=3) + closure() + self._closure_queue.task_done() + except EmptyQueue: + with self._closure_lock: + if self._closure_queue.empty(): + self._closure_event_loop_finished.set() + return + except Exception as e: + self._closure_exception.put(e) + return + + self._closure_event_loop = threading.Thread(target=event_loop) + self._closure_event_loop.start() + + def run(self, closure): + with self._closure_lock: + self._closure_queue.put(closure, block=True) + if ( + self._closure_event_loop is None + or not self._closure_event_loop.is_alive() + ): + try: + e = self._closure_exception.get(block=False) + raise RuntimeError( + "Cannot run asynchronous closure due to previously raised exception" + ) from e + except EmptyQueue: + self._closure_event_loop = None + self.start_event_loop() + + +def add_step_closure(closure, args=(), run_async=False): + """Adds a closure to the list of the ones to be run at the end of the step. + Many times during model training there is the need to print/report (print to + console, post to tensorboard, etc...) information which require the content of + intermediary tensors to be inspected. + Inspecting different tensors content in different points of the model code + requires many executions and typically causes performance issues. + Adding a step closure will ensure that it will be run after the barrier, when + all the live tensors will be already materialized to device data. + Live tensors which will include the ones captured by the closure arguments. + So using `add_step_closure()` will ensure a single execution will be + performed, even when multiple closures are queued, requiring multiple tensors + to be inspected. + Step closures will be run sequentially in the order they have been queued. + Note that even though using this API the execution will be optimized, it is + advised to throttle the printing/reporting events once every N steps. + Args: + closure (callable): The function to be called. + args (tuple): The arguments to be passed to the closure. + run_async: If True, run the closure asynchronously. + """ + devctx = get_device_context() + closures_type = "async_step_closures" if run_async else "step_closures" + step_closures = getattr(devctx, closures_type, None) + if step_closures is None: + step_closures = [] + setattr(devctx, closures_type, step_closures) + step_closures.append(lambda a=args: closure(*a)) + + +def run_step_closures(): + devctx = get_device_context() + async_step_closures = getattr(devctx, "async_step_closures", None) + if async_step_closures is not None: + devctx.async_step_closures = [] + async_closure_handler = getattr(devctx, "async_closure_handler", None) + if async_closure_handler is None: + async_closure_handler = AsyncClosureHandler() + devctx.async_closure_handler = async_closure_handler + async_closure_handler(async_step_closures) + + step_closures = getattr(devctx, "step_closures", None) + if step_closures is not None: + devctx.step_closures = [] + closure_handler = getattr(devctx, "closure_handler", None) + if closure_handler is None: + closure_handler = ClosureHandler() + devctx.closure_handler = closure_handler + closure_handler(step_closures) + return devctx diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/computation.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/computation.py new file mode 100644 index 0000000000000000000000000000000000000000..27b73c42e5c0de39e5112f717796cfce5d808bc1 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/computation.py @@ -0,0 +1,26 @@ +import torch._C._lazy +import torch._C._lazy_ts_backend + + +def get_tensors_ts_device_data_node(tensors): + """Return tensor ids and eager tensors for DeviceData nodes in the + IR for the passed in lazy tensors. + + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. + """ + return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors) + + +def get_graph_hash(tensors): + """Return the graph hash for the passed in lazy tensors""" + return torch._C._lazy._get_graph_hash(tensors) + + +def run_cached_graph(hash_str, graph_inputs): + """Running the cached computation graph with the given inputs + + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. + """ + return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/debug.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..286aa049280c9d9555f64042f35b4a5fd57d0059 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/debug.py @@ -0,0 +1,21 @@ +import torch._C._lazy + + +def render_ir_graph(tensors): + """Return a text dump of the LTC IR graph in dot format for the tensors. + The text can be processed by tools like dot to be rendered in pdf,png etc.""" + return torch._C._lazy._get_tensors_dot(tensors) + + +def dump_ir(tensors, ir_format): + """Return a dump of the tensors in the specified format. + Valid format are + - text: for LTC IR + - backend: for the activate backend IR + """ + if ir_format == "text": + return torch._C._lazy._get_tensors_text(tensors) + elif ir_format == "backend": + return torch._C._lazy._get_tensors_backend(tensors) + else: + raise RuntimeError(f"Unrecognized IR format: {ir_format}") diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..184223771932d80274e479a39c829300c9c872a7 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_lazy/ts_backend.py @@ -0,0 +1,6 @@ +import torch._C._lazy_ts_backend + + +def init(): + """Initializes the lazy Torchscript backend""" + torch._C._lazy_ts_backend._init() diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/error.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/error.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcfeb17ff1cf8b9f98565841b17998a4b9c790ea Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/error.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/graphs.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/graphs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5636b9981e6fb56524bc0b8e294dba7daf5c66c2 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/__pycache__/graphs.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..003cca1f0cd4e910d0c042cf94501aa17a1bb2f5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__init__.py @@ -0,0 +1,11 @@ +from .autocast_mode import autocast, custom_bwd, custom_fwd +from .common import amp_definitely_not_available +from .grad_scaler import GradScaler + +__all__ = [ + "amp_definitely_not_available", + "autocast", + "custom_bwd", + "custom_fwd", + "GradScaler", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d325548540f75a40349b3bf18f5c8c5b1d70e409 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/cuda/amp/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b851646f4d72370d9dbe3506a85d4d48c0475a82 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..421ac06726bbc73fa9c36e9ff9344e3a2e4158a6 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..098af08994df2b05632a335bd219854624b2ad0b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..939e8b04174466269b1805ecef4af1d230f64a7c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3017395ee54083280183bb1dfdd89222116149a0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebe5c3e841476ffbb188d3d052e702c40775ad3e Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82a2424d76bf8d4f3035fc43e6d42125d41ef28b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py new file mode 100644 index 0000000000000000000000000000000000000000..0f0d23d0187490834615d67257e8855f26fdbbc5 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -0,0 +1,557 @@ +from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \ + op_mod, op_gt, op_lt, op_neq, op_eq +from torch.fx.tensor_type import TensorType, Dyn + + +class Constraint: + pass + + +class Conj(Constraint): + def __init__(self, conjuncts): + """ + :param conjuncts: Conjunction of constraints + """ + self.conjucts = conjuncts + + def __eq__(self, other): + if isinstance(other, Conj): + return self.conjucts == other.conjucts and self.conjucts == other.conjucts + else: + return False + + def __repr__(self): + return f'And({self.conjucts})' + + +class Disj(Constraint): + def __init__(self, disjuncts): + """ + :param disjuncts: Disjunction of constraints + """ + self.disjuncts = disjuncts + + def __eq__(self, other): + if isinstance(other, Disj): + return self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts + else: + return False + + def __repr__(self): + return f'Or({self.disjuncts})' + + +class Prod(Constraint): + def __init__(self, products): + """ + :param products: lists of dimensions to multiply + """ + self.products = products + + def __eq__(self, other): + if isinstance(other, Prod): + return self.products == other.products and self.products == other.products + else: + return False + + def __repr__(self): + return f'Product({self.products})' + + +class T(Constraint): + """ + True + """ + def __init__(self): + pass + + def __eq__(self, other): + return isinstance(other, T) + + def __repr__(self): + return 'True' + +class F(Constraint): + """ + False + """ + def __init__(self): + pass + + def __eq__(self, other): + return isinstance(other, F) + + def __repr__(self): + return 'False' + + +class BinaryConstraint(Constraint): + """ + Represents all binary operations + """ + def __init__(self, lhs, rhs, op): + """ + :param lhs: lhs of the constraint + :param rhs: rhs of the constraint + :param op: string representing the operation + """ + self.lhs = lhs + self.rhs = rhs + self.op = op + + def __eq__(self, other): + if isinstance(other, BinaryConstraint): + return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op + else: + return False + + def __repr__(self): + return f'({self.lhs} {self.op} {self.rhs})' + + +class BinConstraintT(BinaryConstraint): + """ + Binary constraints about tensors + """ + def __init__(self, lhs, rhs, op): + assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and \ + (isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn) + super().__init__(lhs, rhs, op) + + def __eq__(self, other): + return super().__eq__(other) + + +class BinConstraintD(BinaryConstraint): + """ + Binary constraints about dimensions + """ + def __init__(self, lhs, rhs, op): + assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs) + assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs) + + super().__init__(lhs, rhs, op) + + def __eq__(self, other): + return super().__eq__(other) + + + +class TGreatestUpperBound(Constraint): + """ + Greatest Upper bound for tensors with dynamic type + """ + def __init__(self, res, rhs1, rhs2): + """ + :param res: tensor variable that stores the result of the outout + :param rhs1: tensor or tensor variable + :param rhs2: tensor or tensor variabke + """ + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f'{self.res} = {self.rhs1}⊔*{self.rhs2}' + + def __eq__(self, other): + if isinstance(other, TGreatestUpperBound): + return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 + else: + return False + + +class DGreatestUpperBound(Constraint): + """ + Greatest Upper bound for dimensions + """ + def __init__(self, res, rhs1, rhs2): + """ + :param res: Dimension variable to store the result + :param rhs1: dimension variable 1 + :param rhs2: dimension variable 2 + """ + assert is_dim(res) + assert is_dim(rhs1) + assert is_dim(rhs2) + + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f'{self.res} = {self.rhs1}⊔{self.rhs2}' + + def __eq__(self, other): + if isinstance(other, DGreatestUpperBound): + return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 + else: + return False + + +class CanReshape(Constraint): + """ + can_reshape constraint + """ + def __init__(self, src, target): + """ + :param src: tensor variable + :param target: tensor + """ + self.src = src + self.target = target + + def __repr__(self): + return f'can-reshape({self.src}, {self.target})' + + def __eq__(self, other): + if isinstance(other, CanReshape): + return self.src == other.src and self.target == other.target + else: + return False + + +class IndexSelect(Constraint): + + def __init__(self, tensor_size, input_var, dim_replace, index, output): + """ + Args: + input_var: input to index_select + tensor_size: tensor size we are considering + dim_replace: the dimension of the output at "index" + index: location of the dimensions to replace in the input + output: variable to store the result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(dim_replace, DVar) or dim_replace == Dyn + assert isinstance(index, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.dim_replace = dim_replace + self.index = index + self.output = output + + def __repr__(self): + + return f' {self.output} = ' \ + f'IndexSelect({self.input_var}, ' \ + f'tensor_size: {self.tensor_size}, ' \ + f'{self.dim_replace}, ' \ + f'{self.index})' + + def __eq__(self, other): + if isinstance(other, IndexSelect): + return self.tensor_size == other.tensor_size and \ + self.dim_replace == other.dim_replace and \ + self.index == other.index and \ + self.output == other.output and \ + self.input_var == other.input_var + else: + return False + + +class Transpose(Constraint): + + def __init__(self, tensor_size, input_var, index1, index2, output): + """ + Args: + tensor_size: current tensor size + input_var: variable to hold input + index1: dimension 1 + index2: dimension 2 + output: output that stores result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(index1, int) + assert isinstance(index2, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.index1 = index1 + self.index2 = index2 + self.output = output + + def __repr__(self): + + return f' {self.output} = ' \ + f'Transpose({self.input_var}, ' \ + f'tensor_size: {self.tensor_size}, ' \ + f'{self.index1}, ' \ + f'{self.index2})' + + def __eq__(self, other): + if isinstance(other, Transpose): + return self.tensor_size == other.tensor_size and \ + self.index1 == other.index1 and \ + self.index2 == other.index2 and \ + self.output == other.output and \ + self.input_var == other.input_var + else: + return False + + +class GetItem(Constraint): + + def __init__(self, tensor_size, index, res, input_var): + """ + Constraint for getting item given a tensor size + :param tensor_size: actual number + :param index: actual number representing the index + :param res: dimension variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, DVar) + + self.res = res + self.tensor_size = tensor_size + self.index = index + self.input_var = input_var + + def __repr__(self): + return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})' + + def __eq__(self, other): + if isinstance(other, GetItem): + return self.res == other.res and \ + self.tensor_size == other.tensor_size and \ + self.index == other.index and \ + self.input_var == other.input_var + else: + return False + +class GetItemTensor(Constraint): + + def __init__(self, tensor_size, index_tuple, res, input_var): + """ + Constraint for getting item given a tensor size + However, when the argument is a tuple, we will + expect a tensor + :param tensor_size: actual number representing the rank + :param index_tuple: tuple for indexing + :param res: tensor variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, TVar) + + self.res = res + self.tensor_size = tensor_size + self.index_tuple = index_tuple + self.input_var = input_var + + def __repr__(self): + return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})' + + def __eq__(self, other): + if isinstance(other, GetItemTensor): + return self.res == other.res and \ + self.tensor_size == other.tensor_size and \ + self.index_tuple == other.index_tuple and \ + self.input_var == other.input_var + else: + return False + +class CalcConv(Constraint): + + def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars): + """ + :param conv_result: the convolution result + :param input_var: input to convolution + :param c_out: output chanel type + :param kernel: kernel tuple + """ + self.conv_result = conv_result + self.input_var = input_var + self.c_out = c_out + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return f'{self.conv_result} =' \ + f' calc-conv({self.input_var},' \ + f' {self.c_out}, {self.kernel}, ' \ + f'{self.padding}, {self.stride},' \ + f' {self.dilation})' + + def __eq__(self, other): + if isinstance(other, CalcConv): + return self.conv_result == other.conv_result and self.input_var == other.input_var and \ + self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \ + and self.stride == other.stride and self.dilation == other.dilation \ + and self.matching_constraint == other.matching_constraint + else: + return False + + +class CalcMaxPool(Constraint): + + def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars): + """ + :param maxpool_result: the result of maxpool + :param input_var: input to convolution + :param kernel: kernel tuple + """ + self.maxpool_result = maxpool_result + self.input_var = input_var + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return f'{self.maxpool_result} =' \ + f' calc-maxpool({self.input_var},' \ + f' {self.kernel}, ' \ + f'{self.padding}, {self.stride},' \ + f' {self.dilation})' + + def __eq__(self, other): + if isinstance(other, CalcMaxPool): + return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \ + and self.kernel == other.kernel and self.padding == other.padding \ + and self.stride == other.stride and self.dilation == other.dilation \ + and self.matching_constraint == other.matching_constraint + else: + return False + + +class ApplyBroadcasting(Constraint): + def __init__(self, res1, res2, input1, input2): + """ + :param res1: resulting tensor 1 + :param res2: resulting tensor 2 + :param input1: tensor variable 1 + :param input2: tensor variable 2 + """ + self.res1 = res1 + self.res2 = res2 + self.input1 = input1 + self.input2 = input2 + + def __eq__(self, other): + if isinstance(other, ApplyBroadcasting): + return self.res1 == other.res1 \ + and self.res2 == other.res2 \ + and self.input1 == other.input1 \ + and self.input2 == other.input2 + else: + return False + + def __repr__(self): + return f'{self.res1}, {self.res2} ='f' apply-broadcasting({self.input1},' f' {self.input2})' + + +class CalcProduct(Constraint): + """ + Given correct dimensions, calculate the product for flatten accounting for Dyn + """ + def __init__(self, start, end, flattened, dims_to_flatten): + """ + :param start: start index + :param end: end index + :param flattened: variable to store the product + :param dims_to_flatten: the type which we will flatten + """ + assert isinstance(dims_to_flatten, list) + assert isinstance(flattened, TVar) + assert isinstance(start, int) + assert isinstance(end, int) + + self.start = start + self.end = end + self.dims_to_flatten = dims_to_flatten + self.flattened = flattened + + def __eq__(self, other): + if isinstance(other, CalcProduct): + return self.start == other.start and self.end == other.end and \ + self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened + + else: + return False + + def __repr__(self): + return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})' + + +class TVar: + """ + Tensor variable with no tensor constructor + """ + def __init__(self, tvar): + """ + :param tvar: tensor variable + """ + self.tvar = tvar + + def __repr__(self): + return f'TV({self.tvar})' + + def __eq__(self, other): + if isinstance(other, TVar): + return self.tvar == other.tvar + else: + return False + + +class DVar: + """ + Dimension variable + """ + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f'DV({self.c})' + + def __eq__(self, other): + if isinstance(other, DVar): + return self.c == other.c + else: + return False + + +class BVar: + """ + Boolean variable + """ + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f'BV({self.c})' + + def __eq__(self, other): + if isinstance(other, BVar): + return self.c == other.c + else: + return False + + +def is_algebraic_expression(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod] + else: + return isinstance(constraint, Prod) + + +def is_bool_expr(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_gt, op_lt, op_neq, op_eq] + else: + return isinstance(constraint, (BVar, Conj, Disj)) + +def is_dim(d): + return isinstance(d, (DVar, int)) or d == Dyn diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..031562393edcecf8490a34669d04de01b166e759 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -0,0 +1,1279 @@ +import torch +import operator +import warnings +from typing import Callable, Dict, Iterable + +from torch.fx._symbolic_trace import _assert_is_none +from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \ + Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \ + TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound +from torch.fx.experimental.migrate_gradual_types.operation import \ + op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul +from torch.fx.node import Target, Node +from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \ + gen_bvar + +from torch.fx.tensor_type import Dyn, TensorType +from torch.nn.modules.conv import Conv2d +from torch.nn.modules.batchnorm import BatchNorm2d + +_INFERENCE_RULES: Dict[Target, Callable] = {} + +MAX_TENSOR_RANK = 4 + +def register_inference_rule(call_target): + def register(fn): + if call_target in _INFERENCE_RULES: + raise RuntimeError(f'Inference rule already registered for {call_target}!') + _INFERENCE_RULES[call_target] = fn + return fn + return register + + +def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter): + d, counter = gen_tensor_dims(n, counter) + c1 = BinConstraintT(input, TensorType(d), op_eq) + start_dim = n if start_dim == -1 else abs(start_dim) + end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1 + c2 = CalcProduct(start_dim, end_dim, flattened, d) + nat_constraints = gen_nat_constraints(d) + return Conj([c1, c2, *nat_constraints]), counter + + +@register_inference_rule(getattr) +def get_attr_inference_rule(n: Node, symbols, constraints, counter): + """ + If the attribute is "device" then the tensor shape is preserved + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], str) + output, counter = gen_tvar(counter) + symbols[n] = output + + input = symbols[n.args[0]] + attr = n.args[1] + + if attr == 'device': + return [BinConstraintT(input, output, op_eq)], counter + else: + raise NotImplementedError('Not yet implemented') + +@register_inference_rule(torch.bmm) +def bmm_inference_rule(n: Node, symbols, constraints, counter): + """ + Constraints that match the input to a size 3 tensor + and switch the dimensions according to the rules + of batch multiplication + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + bmm_output, counter = gen_tvar(counter) + symbols[n] = bmm_output + + bmm_input1 = symbols[n.args[0]] + bmm_input2 = symbols[n.args[1]] + + dims_input1, counter = gen_tensor_dims(3, counter) + dims_input2, counter = gen_tensor_dims(3, counter) + + inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_output, Dyn, op_eq)]) + + input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)]) + + input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)]) + + consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)] + + batch_size, counter = gen_dvar(counter) + + inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq), + *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])]) + + return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter + + +@register_inference_rule("index_select") +def index_select_inference_rule(n: Node, symbols, constraints, counter): + """ + We constrain the second argument to a vector or Dyn. + The output replaces the input with the shape of the vector + at the position given by the index (first argument) + """ + # print(n.args) + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], Node) + + + + index_select, counter = gen_tvar(counter) + symbols[n] = index_select + + dims, counter = gen_tensor_dims(1, counter) + + # equality constraint + is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq) + is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq) + + c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select) + for i in range(MAX_TENSOR_RANK)])]) + c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) + for i in range(MAX_TENSOR_RANK)])]) + + return [Disj([c2, c3])], counter + + +@register_inference_rule("expand") +def expand_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the exact constraints as we do for tensor additions but we constraint + the rank of this expression to be equal to len(n.args[1:]) so that only + those cases get considered for the output + """ + assert isinstance(n.args[0], Node) + + # define the output for expand + expand, counter = gen_tvar(counter) + symbols[n] = expand + + # since we do not have two nodes here, we will construct an argument variable + e1 = symbols[n.args[0]] + e2, counter = gen_tvar(counter) + + e2_nat_constraints = [] + for arg in n.args[1:]: + assert isinstance(arg, (Node, int)) + if isinstance(arg, Node): + assert isinstance(symbols[arg], DVar) + e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq)) + + e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq) + + constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand) + + # constraint the output size + dims, counter = gen_tensor_dims(len(n.args[1:]), counter) + nat_constraints = gen_nat_constraints(dims) + c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints] + constraints += c + + return constraints, counter + + +@register_inference_rule(torch.nn.functional.gelu) +@register_inference_rule(torch.nn.functional.dropout) +@register_inference_rule(torch.nn.functional.softmax) +@register_inference_rule("detach") +@register_inference_rule("to") +@register_inference_rule("int") +@register_inference_rule("long") +@register_inference_rule("contiguous") +@register_inference_rule(torch.ones) +@register_inference_rule(torch.zeros) +def equality_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + output, counter = gen_tvar(counter) + symbols[n] = output + + if isinstance(n.args[0], Node): + input = symbols[n.args[0]] + if isinstance(input, TVar): + return [BinConstraintT(input, output, op_eq)], counter + + # then we have dimension variables + else: + for arg in n.args: + assert isinstance(symbols[arg], DVar) + my_size = [symbols[arg] for arg in n.args] + return [BinConstraintT(output, TensorType(my_size), op_eq)], counter + + elif isinstance(n.args[0], tuple): + # then the tuple is the size + assert len(n.args[0]) <= 4 + my_size = [symbols[arg] for arg in n.args[0]] + return [BinConstraintT(output, TensorType(my_size), op_eq)], counter + else: + raise NotImplementedError('Method not yet implemented') + + +@register_inference_rule("transpose") +def transpose_inference_rule(n: Node, symbols, constraints, counter): + """ + Can be considered as a sequence of two index selects, so we generate constraints accordingly + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], int) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + assert isinstance(from_arg, TVar) + + # input and output are dyn + is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)]) + + # or input is a tensor and we actually do the replacement + c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)]) + + return [Disj([is_dyn, c3])], counter + + +@register_inference_rule("type_as") +def type_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + to_arg = symbols[n.args[1]] + + assert isinstance(from_arg, TVar) + assert isinstance(to_arg, TVar) + + return [BinConstraintT(from_arg, to_arg, op_consistency), + BinConstraintT(output, to_arg, op_eq)], counter + +@register_inference_rule("masked_fill_") +def masked_fill_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to addition. For now we implement the constraints when + the argument is a boolean tensor. There is also a case for when + it is a condition. We will leave this out for now. + """ + + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + # We will retrieve the type variables from the symbol table + # and confirm they are tensor variables + + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + if isinstance(e1, TVar) and isinstance(e2, TVar): + masked_fill_tensor, counter = gen_tvar(counter) + symbols[n] = masked_fill_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor) + else: + raise NotImplementedError('Not yet implemented') + + +@register_inference_rule(torch.nn.functional.embedding) +def embedding_inference_rule_functional(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + embedding_dim_weights = symbols[n.args[1]] + + # will treat this as a static shape. So we will not use matching. + weight_dims, counter = gen_tensor_dims(2, counter) + equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq) + embedding_dim = weight_dims[1] + constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter) + return [equality_constraint] + constraints, counter + + +@register_inference_rule(torch.nn.modules.sparse.Embedding) +def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + The output shape differs from the input shape in the last dimension + """ + assert isinstance(n.args[0], Node) + return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter) + + +def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): + + embedding_output, counter = gen_tvar(counter) + symbols[n] = embedding_output + embedding_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(embedding_input, Dyn, op_eq) + output_dyn = BinConstraintT(embedding_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + + for i in range(1, MAX_TENSOR_RANK): + new_dims, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims) + + # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases + c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq), + BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] + + nat_constraints) + c2.append(c_tensor_i) + + return [Disj([c1, Disj(c2)])], counter + + +@register_inference_rule(torch.tensor) +def tensor_inference_rule(n: Node, symbols, constraints, counter): + """ + If the tensor is a scalar, we will skip it since we + do not support scalars yet. We will add support in the future + if it's needed. For our examples so far, scalars are not needed. + """ + return [], counter + + +@register_inference_rule("reshape") +@register_inference_rule("view") +def view_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to reshape but with an extra condition on the strides + """ + assert isinstance(n.args[0], Node) + + # generate the new variable + my_view, counter = gen_tvar(counter) + symbols[n] = my_view + + + src_var = symbols[n.args[0]] + t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]] # target shape + t2_type = [] + num_constraints = [] + + for t in t2: + if t == -1: + var, counter = gen_dvar(counter) + t2_type.append(var) + num_constraints.append(BinConstraintD(var, Dyn, op_neq)) + + else: + num_constraints.append(BinConstraintD(t, Dyn, op_neq)) + t2_type.append(t) + + t2_type = TensorType(t2_type) # type: ignore[assignment] + + c1 = BinConstraintT(my_view, t2_type, op_eq) + c2 = CanReshape(src_var, t2_type) + + # TODO: add the extra check mentioned here: + # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view + + return [c1, c2] + num_constraints, counter # type: ignore[operator] + + +@register_inference_rule("size") +def size_inference_rule(n: Node, symbols, constraints, counter): + """ + The constraint is just lhs = rhs. + Ex: size = input_ids.size() + """ + + + if len(n.args) == 1: + # generate the new variable + size, counter = gen_tvar(counter) + symbols[n] = size + input = symbols[n.args[0]] + c = BinConstraintT(input, size, op_eq) + return [c], counter + + elif len(n.args) == 2: + # TODO: review this rule; should input = dyn; output = dyn be included here? + if isinstance(n.args[1], int): + # generate the new variable + size_index, counter = gen_dvar(counter) + symbols[n] = size_index + input = symbols[n.args[0]] + c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)] + c3 = BinConstraintD(0, size_index, op_leq) + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintD(size_index, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + else: + raise NotImplementedError + + else: + raise NotImplementedError + + +def range_check(i, n): + """ + Checks if an index i is within range of a size n list + Args: + i: index + n: list size + + Returns: Boolean + """ + if i >= 0: + return T() if i < n else F() + else: + return T() if i >= n else F() + + +@register_inference_rule(torch.cumsum) +def cumsum_inference_rule(n: Node, symbols, constraints, counter): + """ + Input and output shapes should be equal + We should verify that the index is valid + """ + assert isinstance(n.args[0], Node) + arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"] + assert isinstance(arg_1, int) + + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims) + + c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq), + BinConstraintT(output, TensorType(new_dims), op_eq)] + + [range_check(arg_1, i)] + nat_constraints) + + c2.append(c_tensor_i) + dyn_or_tensor = Disj([c1, Disj(c2)]) + return [dyn_or_tensor], counter + + +@register_inference_rule(_assert_is_none) +def assert_inference_rule(n: Node, symbols, constraints, counter): + assert len(n.users) == 0 + return [], counter + + +@register_inference_rule(operator.getitem) +def getitem_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # dimension output case + if isinstance(n.args[1], int): + # create and store the new dimension variable + get_item_output, counter = gen_dvar(counter) + symbols[n] = get_item_output + + # retrieve arg variables + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + + # if the input is dynamic, we accept any index and return + # a dynamic dimension as output + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintD(get_item_output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + # if the input is a tensor, + # generate a getItem constraint which will be expanded based on the + # tensor dimension. + + c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)] + + + # since the output is a dimension, we make sure it's a natural number + # added as a conjunction to the disjunction of c2 + c3 = BinConstraintD(0, get_item_output, op_leq) + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + # tensor output case + elif isinstance(n.args[1], tuple): + # create and store the new tensor variable + get_item_output, counter = gen_tvar(counter) + symbols[n] = get_item_output + + # retrieve arg variables + if n.args[0] in symbols: + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment] + c1 = Conj([input_dyn, output_dyn]) + + c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc] + for i in range(MAX_TENSOR_RANK)] + else: + # TODO: we should figure out why there is a key-error here. + return [], counter + + return [Disj([c1, *c2])], counter + + else: + raise RuntimeError('Method not yet implemented') + + +@register_inference_rule(operator.gt) +def gt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + gt_tensor, counter = gen_tvar(counter) + symbols[n] = gt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError('Sort Mismatch') + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + elif isinstance(e1, TVar) and isinstance(e2, int): + # then we made the wrong assumption about the argument being a tensor + # so we should fix the assumption + warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.') + + new_e1, counter = gen_dvar(counter) + symbols[n.args[0]] = new_e1 + symbols[n.args[0]] + + gt_constraint = BinConstraintD(new_e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise NotImplementedError('Method not yet implemented') + + else: + raise NotImplementedError('Method not yet implemented') + + +@register_inference_rule(operator.eq) +def eq_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + eq_tensor, counter = gen_tvar(counter) + symbols[n] = eq_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + eq_constraint = BinConstraintD(e1, e2, op_eq) + + my_eq, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError('Sort Mismatch') + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + eq_constraint = BinConstraintD(e1, e2, op_eq) + + my_eq, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError('Method not yet implemented') + else: + raise NotImplementedError('Method not yet implemented') + +@register_inference_rule(operator.ne) +def neq_inference_rule(n: Node, symbols, constraints, counter): + """ + Translates to inconsistent in gradual types. + To prove inequality, we should prove that + tensors are either different sizes or + disagree on at least one dimension + + This is a WIP (works when the condition + is false. We are working on making this operation work + when the condition is true as well) + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], tuple) + + # implementing for size 3 and 4 + if len(n.args[1]) == 3: + + assert isinstance(n.args[1][0], (Node, int)) + assert isinstance(n.args[1][1], (Node, int)) + assert isinstance(n.args[1][2], (Node, int)) + + lhs = symbols[n.args[0]] + + b, counter = gen_tensor_dims(4, counter) + input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b[0], op_neq) + neq_2 = BinConstraintD(d2, b[1], op_neq) + neq_3 = BinConstraintD(d3, b[2], op_neq) + + # dimensions inconsistent + dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1]) + dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2]) + dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3]) + + dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3]) + + # we are covering size 3 and 4 only for now + ne_constraint = Conj([input_is_size3, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + elif len(n.args[1]) == 4: + + assert isinstance(n.args[1][0], (Node, int)) + assert isinstance(n.args[1][1], (Node, int)) + assert isinstance(n.args[1][2], (Node, int)) + assert isinstance(n.args[1][3], (Node, int)) + + lhs = symbols[n.args[0]] + + b1, counter = gen_dvar(counter) + b2, counter = gen_dvar(counter) + b3, counter = gen_dvar(counter) + b4, counter = gen_dvar(counter) + + input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b1, op_neq) + neq_2 = BinConstraintD(d2, b2, op_neq) + neq_3 = BinConstraintD(d3, b3, op_neq) + neq_4 = BinConstraintD(d4, b4, op_neq) + + # dimensions to inconsistent + dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1]) + dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2]) + dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3]) + dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4]) + + dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4]) + + ne_constraint = Conj([input_is_size4, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + else: + raise NotImplementedError('Method not yet implemented') + + return [equality_constraint], counter + + +@register_inference_rule(operator.lt) +def lt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + lt_tensor, counter = gen_tvar(counter) + symbols[n] = lt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError('Sort Mismatch') + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError('Method not yet implemented') + + else: + raise NotImplementedError('Method not yet implemented') + + +@register_inference_rule(torch.full) +def full_inference_rule(n: Node, symbols, constraints, counter): + full, counter = gen_tvar(counter) + symbols[n] = full + res = [] + + assert isinstance(n.args[0], Iterable) + for arg in n.args[0]: + dim = arg if isinstance(arg, int) else symbols[arg] + res.append(dim) + c = BinConstraintT(full, TensorType(list(res)), op_eq) # type: ignore[arg-type] + return [c], counter + + +# TODO normalize index +@register_inference_rule(torch.arange) +def arange_inference_rule(n: Node, symbols, constraints, counter): + start = 0 + step = 1 + + if len(n.args) == 1: + end = symbols[n.args[0]] + else: + raise NotImplementedError('Not yet implemented') + + # int((end - start) / step) + d1, counter = gen_dvar(counter) + size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq) + arange, counter = gen_tvar(counter) + symbols[n] = arange + + # either the a parameter is a number or it is Dyn + c1 = Disj([BinConstraintD(end, Dyn, op_eq), + BinConstraintD(start, Dyn, op_eq), + BinConstraintD(step, Dyn, op_eq)]) + c2 = BinConstraintD(d1, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + c11 = Conj([BinConstraintD(end, Dyn, op_neq), + BinConstraintD(start, Dyn, op_neq), + BinConstraintD(step, Dyn, op_neq)]) + c22 = BinConstraintD(d1, Dyn, op_neq) + both_numbers = Conj([c11, c22, size_constraint]) + + return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter + +def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): + # additional vars that don't correspond to expressions + e11, counter = gen_tvar(counter) + e22, counter = gen_tvar(counter) + + # generate constraints + c1 = TGreatestUpperBound(output_var, e11, e22) + c2 = ApplyBroadcasting(e11, e22, e1, e2) + c3 = BinConstraintT(e11, e22, op_consistency) + return [c1, c2, c3], counter + + +@register_inference_rule(operator.mul) +@register_inference_rule(torch.ne) +@register_inference_rule("ne") +@register_inference_rule(torch.add) +@register_inference_rule(operator.add) +def broadcasting_inference_rule(n: Node, symbols, constraints, counter): + + op_code = None + if n.target == operator.add or n.target == torch.add: + op_code = op_add + elif n.target == operator.mul: + op_code = op_mul + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output) + else: + raise NotImplementedError('Method not yet implemented') + + elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)): + if isinstance(symbols[n.args[0]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + return [BinConstraintT(my_output, e1, op_eq)], counter + elif isinstance(symbols[n.args[0]], DVar): + my_output, counter = gen_dvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + + # we will propagate the runtime value here since this is regular addition + c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq), + BinConstraintD(0, my_output, op_leq)]) + return [c], counter + + elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)): + if isinstance(symbols[n.args[1]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e2 = symbols[n.args[1]] + return [BinConstraintT(my_output, e2, op_eq)], counter + elif isinstance(symbols[n.args[1]], DVar): + my_output, counter = gen_dvar(counter) + symbols[n] = my_output + e2 = symbols[n.args[1]] + + # we will propagate the runtime value here since this is regular addition + c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq), + BinConstraintD(0, my_output, op_leq)]) + return [c], counter + + else: + raise NotImplementedError('Method not yet implemented') + + else: + # TODO generate add constraints for scalar addition + raise NotImplementedError('Addition not yet implemented') + + +@register_inference_rule(torch.flatten) +def flatten_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + flattened, counter = gen_tvar(counter) + symbols[n] = flattened + + input = symbols[n.args[0]] + + # set the default start and end dims + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + c1 = BinConstraintT(input, Dyn, op_eq) + c2 = BinConstraintT(flattened, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + const = [] + for i in range(1, MAX_TENSOR_RANK + 1): + c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter) + const.append(c) + + return [Disj([both_dyn, *const])], counter + + +@register_inference_rule(torch.nn.functional.layer_norm) +def layer_norm_functional(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + return gen_layer_norm_constraints(n, n.args[1], symbols, counter) + + +@register_inference_rule(torch.nn.LayerNorm) +def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + Input should be consistent with the normalized_shape + """ + assert isinstance(n.args[0], Node) + return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter) + + +def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims_rhs) + + c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq), + BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] + + add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) + + nat_constraints) + c2.append(c_tensor_i) + return [Disj([c1, Disj(c2)])], counter + +@register_inference_rule(torch.nn.Dropout) +@register_inference_rule(torch.nn.ReLU) +def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + """ + assert isinstance(n.args[0], Node) + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + assert isinstance(input, TVar) + return [BinConstraintT(input, output, op_eq)], counter + + +@register_inference_rule(torch.nn.Linear) +def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output sizes should be the same except for the last dimension + If the input is Dyn, then so should the output + """ + assert isinstance(n.args[0], Node) + return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter) + + +@register_inference_rule("dim") # type: ignore[attr-defined] +def torch_dim_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + my_dim, counter = gen_dvar(counter) + symbols[n] = my_dim + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintD(my_dim, Dyn, op_eq) + + c1 = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + + c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintD(my_dim, i, op_eq)]) + c1.append(c_tensor_i) + + return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter + + +@register_inference_rule(torch._C._nn.linear) # type: ignore[attr-defined] +def torch_linear_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + weight_dims, counter = gen_tensor_dims(2, counter) + equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq) + constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter) + return [equality_constraint] + constraints, counter + + +def linear_constraints(n: Node, in_features, out_features, symbols, counter): + linear_output, counter = gen_tvar(counter) + symbols[n] = linear_output + linear_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(linear_input, Dyn, op_eq) + output_dyn = BinConstraintT(linear_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + new_dims_rhs_2, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) + + c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] + + add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) + + nat_constraints) + c2.append(c_tensor_i) + return [Disj([c1, Disj(c2)])], counter + +def add_layer_norm_constraints(input_dim, normalized_dim): + """ + The constraints say that the type has te form: [*, 1024, 1024] + while the normalized_dim have the form [1024, 1024] + Args: + input_dim: Input shape of layer norm + normalized_dim: normalized_dim parameter of the module instance + + """ + + # in this case we return false since there's a pattern mismatch + if len(normalized_dim) > len(input_dim): + return [F()] + + else: + constraints = [] + for i, n in zip(reversed(input_dim), reversed(normalized_dim)): + constraints.append(BinConstraintD(i, n, op_consistency)) + return constraints + + +def add_linear_constraints(dims1, dims2, in_features, out_features): + assert len(dims1) == len(dims2) + constraints = [] + for i in range(len(dims1)): + if i == len(dims1) - 1: + constraints.append(BinConstraintD(dims1[i], in_features, op_consistency)) + constraints.append(BinConstraintD(dims2[i], out_features, op_eq)) + else: + constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq)) + + return constraints + + +@register_inference_rule(torch.reshape) +def reshape_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + my_reshape, counter = gen_tvar(counter) + symbols[n] = my_reshape + + src_var = symbols[n.args[0]] + t2 = n.args[1] + t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) # type: ignore[union-attr] + c1 = BinConstraintT(my_reshape, t2_type, op_eq) # type: ignore[union-attr] + c2 = CanReshape(src_var, t2_type) + + return [c1, c2], counter + + +@register_inference_rule(BatchNorm2d) +def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + batchnorm_output, counter = gen_tvar(counter) + symbols[n] = batchnorm_output + batchnorm_input = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq) + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(torch.nn.AdaptiveAvgPool2d) +def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + avg_pool, counter = gen_tvar(counter) + + symbols[n] = avg_pool + input_var = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq) + + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(Conv2d) +def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + my_conv, counter = gen_tvar(counter) + symbols[n] = my_conv + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + # c1 = Matching(input_var, TensorType([d1, d2, d3, d4])) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + # c2 = DConsistency(module_instance.in_channels, d2) + c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency) + + c3 = CalcConv(my_conv, input_var, + module_instance.out_channels, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, [d1, d2, d3, d4]) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, c3, *nat_constraints], counter + + +@register_inference_rule(torch.nn.MaxPool2d) +def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + maxpool, counter = gen_tvar(counter) + symbols[n] = maxpool + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding, + module_instance.stride, module_instance.dilation, [d1, d2, d3, d4]) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, *nat_constraints], counter + + +class ConstraintGenerator: + def __init__(self, traced, graph=None): + self.traced = traced # traced or tracer.root + self.traced_params = dict(self.traced.named_parameters()) + self.constraints = [] + self.symbol_dict = {} + self.graph = traced.graph if hasattr(traced, 'graph') else graph + + + def generate_constraints(self, counter=0): + """ + Iterate through every node and generate constraints + Effect: self.constraints will be populated with the final constraints + """ + graph = self.graph + + all_constraints = [] + + for n in graph.nodes: + (constraints, counter) = self.generate_constraints_node(n, counter) + all_constraints += constraints + + return Conj(all_constraints), counter + + def generate_constraints_node(self, n: Node, counter): + """ + Generate constraints the given node: + Currently supported operations: + - Reshape + - Add + - conv2d + """ + + if n.op == 'placeholder': + x, counter = gen_tvar(counter) + self.symbol_dict[n] = x + + my_type = n.type + + if n.type != Dyn and (not isinstance(n.type, TensorType)): + if n.type == torch.nn.parameter.Parameter: + # since we have a parameter, the shape must be static + assert 'example_value' in n.meta + my_type = TensorType(n.meta['example_value'].size()) + else: + my_type = Dyn + + c1 = BinConstraintT(my_type, x, op_precision) + c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq) + return [c1, c2], counter + + elif n.op == 'call_function': + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) + else: + raise RuntimeError(f'No inference rule registered for target {n.target}!') + + elif n.op == 'call_module': + + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _INFERENCE_RULES: + return _INFERENCE_RULES[type(module_instance)](n, + module_instance, + self.symbol_dict, + self.constraints, counter) + else: + raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') + + elif n.op == 'call_method': + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) + else: + raise RuntimeError(f'No inference rule registered for target {n.target}!') + + elif n.op == 'get_attr': + t = self.traced_params.get(n.target, None) + + if isinstance(t, torch.Tensor): + if len(t.shape) > 0: + res = list(t.shape) + attr_type = TensorType(res) + output, counter = gen_tvar(counter) + self.symbol_dict[n] = output + return [BinConstraintT(output, attr_type, op_eq)], counter + else: + # scalar? + return [], counter + else: + return [], counter + + elif n.op == 'output': + return [], counter + + else: + raise NotImplementedError(f"Method {n.op} not yet implemented") diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..439e3d6195e654147f5f583b6b13fa9611757372 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py @@ -0,0 +1,1040 @@ +# mypy: ignore-errors +import copy +import itertools +from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK +from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \ + Transpose +from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound +from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound +from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool +from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape +from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect +from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching +from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq +from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod +from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar +from torch.fx.tensor_type import TensorType, Dyn +from typing import Callable, Dict, List + +_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {} + + +def register_transformation_rule(call_target): + def register(fn): + if call_target in _TRANSFORMATION_RULES: + raise RuntimeError(f'Transformation rule already registered for {call_target}!') + _TRANSFORMATION_RULES[call_target] = fn + return fn + return register + + +def valid_index(index, dims): + """ + Given a list of dimensions, checks if an index is valid in the list + """ + try: + dims[index] + return T() + except IndexError: + return F() + + +@register_transformation_rule(Transpose) +def transform_transpose(constraint, counter): + """ + Similar to a sequence of two index-selects + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + is_valid_index1 = valid_index(constraint.index1, dims) + is_valid_index2 = valid_index(constraint.index2, dims) + new_dims = copy.deepcopy(dims) + nat_constraints = gen_nat_constraints(dims) + + if is_valid_index1 == T() and is_valid_index2 == T(): + new_dims[constraint.index1] = dims[constraint.index2] + new_dims[constraint.index2] = dims[constraint.index1] + + transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index1, is_valid_index2, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) + return transformed_constraint, counter + + +@register_transformation_rule(IndexSelect) +def transform_index_select(constraint, counter): + """ + The constraints consider the given tensor size, checks if the index is valid + and if so, generates a constraint for replacing the input dimension + with the required dimension + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + is_valid_index = valid_index(constraint.index, dims) + nat_constraints = gen_nat_constraints(dims) + + # if the index is valid then replace the input dimension with the new dimension + # otherwise the dimension will not be replaced and the clause will contain False + if is_valid_index == T(): + new_dims = copy.deepcopy(dims) + new_dims[constraint.index] = constraint.dim_replace + + transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) + + # print(constraints) + return transformed_constraint, counter + + +@register_transformation_rule(GetItem) +def transform_get_item(constraint, counter): + """ + generate an equality of the form: + t = [a1, ..., an] + then generate constraints that check if the given index is valid + given this particular tensor size. + If the index is valid, generate a constraint to get the item + Note that we already handled the Dyn input case in the previous + step. + Args: + constraint: GetItem which assumes we are getting an item from a tensor (not Dyn) + counter: variable tracking + Returns: simplified constraints for GetItem + + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + nat_constraints = gen_nat_constraints(dims) + + + is_valid_index = valid_index(constraint.index, dims) + + all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index] + + # if the index is valid, we generate a constraint for getting an item + # otherwise this clause will have been UNSAT due to the wrong index + if is_valid_index == T(): + all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq)) + + return Conj(all_constraints), counter + +def valid_index_tensor(index, dims): + """ + if the slice instances exceed the length of the dimensions + then this is a type error so we return False + """ + slice_count = 0 + for s in index: + if isinstance(s, slice): + slice_count += 1 + if slice_count > len(dims): + return F() + else: + return T() + +@register_transformation_rule(GetItemTensor) +def transform_get_item_tensor(constraint, counter): + """ + When the index is a tuple, then the output will be a tensor + TODO: we have to check if this is the case for all HF models + + The cases we are covering here are a tuple with one of: + - slice with default argument + - None + + None appends 1 to the input tensor dimensions + so each occurrence of 'None' increases the rank by 1 + + slice with default arguments does not change the rank + """ + assert isinstance(constraint.index_tuple, tuple) + + + # generate a result tensor of the expected size + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + nat_constraints = gen_nat_constraints(dims) + + # generate a place-holder list of the right rank + # where "slice" does not contribute to the rank and "None" does + none_c = constraint.index_tuple.count(None) + resulting_tensor_dims = (none_c + len(dims)) * [None] + + dim_index = 0 + for i in range(len(constraint.index_tuple)): + + # append 1 to the right location of the resulting tensor + if constraint.index_tuple[i] is None: + resulting_tensor_dims[i] = 1 + + elif constraint.index_tuple[i] == slice(None, None, None): + pass + + else: + raise NotImplementedError('Method not yet implemented') + + # append the remaining dimensions to the right location + dim_index = 0 + for i in range(len(resulting_tensor_dims)): + if resulting_tensor_dims[i] is None: + resulting_tensor_dims[i] = dims[dim_index] + dim_index += 1 + + # check if the index is valid + is_valid_index = valid_index_tensor(constraint.index_tuple, dims) + + # check if the resulting tensor is within bounds + if len(resulting_tensor_dims) > 4: + return F(), counter + + else: + constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq), + *nat_constraints, + is_valid_index] + return Conj(constraints), counter + + +@register_transformation_rule(BinConstraintT) +def generate_binconstraint_t(constraint, counter): + """ + Transform binary constraints for tensors + """ + + # precision constraints + if constraint.op == op_precision: + if constraint.lhs == Dyn: + return T(), counter + elif isinstance(constraint.lhs, TensorType): + is_fully_static = all(d != Dyn for d in constraint.lhs.__args__) + if is_fully_static: + return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter + else: + new_dims = [] + + for _ in range(len(constraint.lhs.__args__)): + dim, counter = gen_dvar(counter) + new_dims.append(dim) + + new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for + new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \ + [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \ + [BinConstraintD(1, new_dim, op_leq) for + new_dim in new_dims] + return Conj(new_dim_constraints), counter + + # matching + elif constraint.op == op_matching: + assert isinstance(constraint.rhs, TensorType) + d1 = constraint.rhs.__args__[0] + d2 = constraint.rhs.__args__[1] + d3 = constraint.rhs.__args__[2] + d4 = constraint.rhs.__args__[3] + + conj = [BinConstraintT(constraint.lhs, Dyn, op_eq), + BinConstraintD(d1, Dyn, op_eq), + BinConstraintD(d2, Dyn, op_eq), + BinConstraintD(d3, Dyn, op_eq), + BinConstraintD(d4, Dyn, op_eq)] + return Disj([Conj(conj), + BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter + + elif constraint.op == op_consistency: + c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)]) + [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter) + + return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter + + elif constraint.op == op_leq: + assert isinstance(constraint.rhs, int) + disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)] + for i in range(1, constraint.rhs + 1): + dims = [] + for j in range(1, i + 1): + dim_var, counter = gen_dvar(counter) + dims.append(dim_var) + disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq)) + return Disj(disj), counter + else: + return constraint, counter + + +@register_transformation_rule(BinConstraintD) +def generate_binconstraint_d(constraint, counter): + """ + Transform binary constraints for dimensions + """ + if constraint.op == op_precision: + if isinstance(constraint.lhs, int): + return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter + elif constraint.lhs == Dyn: + return T(), counter + + elif constraint.op == op_consistency: + return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq), + BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter + + else: + return constraint, counter + + +@register_transformation_rule(Conj) +def generate_conj(constraint, counter): + """ + Transform conjunctions + """ + new = [] + for c in constraint.conjucts: + new_c, counter = transform_constraint(c, counter) + new.append(new_c) + return Conj(new), counter + + +@register_transformation_rule(Disj) +def generate_disj(constraint, counter): + """ + Transform disjunctions + """ + new = [] + for c in constraint.disjuncts: + new_c, counter = transform_constraint(c, counter) + new.append(new_c) + return Disj(new), counter + + +@register_transformation_rule(TGreatestUpperBound) +def generate_gub(constraint, counter): + """ + Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound + on dimensions + """ + c1 = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq), + BinConstraintT(constraint.rhs2, Dyn, op_eq)]), BinConstraintT(constraint.res, Dyn, op_eq)]) + + [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter) + + return Disj([c1, c2, c3, c4, c5]), counter + + +@register_transformation_rule(DGreatestUpperBound) +def generate_d_gub(constraint, counter): + """ + Transform greatest upper bound for dimensions into equality constraints + """ + c1 = Conj([BinConstraintD(constraint.rhs1, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs2, op_eq)]) + c2 = Conj([BinConstraintD(constraint.rhs2, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) + c3 = Conj([BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) + return Disj([c1, c2, c3]), counter + + +@register_transformation_rule(CalcConv) +def generate_calc_conv(constraint, counter): + d, counter = gen_tensor_dims(4, counter) + conv_result = TensorType([d[0], d[1], d[2], d[3]]) + + # the convolution result is a tensor of size 4 + c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq) + + # the second dimension of the output is equal to the output channels + c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)]) + + # the input corresponds to the output in the first dimension of the convolution + c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) + + c4, c5 = calc_last_two_dims(constraint, d) + + leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq)]) + + return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter + + +@register_transformation_rule(CalcMaxPool) +def generate_calc_maxpool(constraint, counter): + """ + Transform maxpool constraints + """ + d, counter = gen_tensor_dims(4, counter) + maxpool_result = TensorType([d[0], d[1], d[2], d[3]]) + + # the maxpool result is a tensor of size 4 + c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq) + + # the input corresponds to the output in the first and second dimension of maxpool + c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq) + c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) + c4, c5 = calc_last_two_dims(constraint, d) + + leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq)]) + + return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter + + +@register_transformation_rule(CalcProduct) +def generate_calc_product(constraint, counter): + """ + Transform flatten constraints + """ + start = constraint.start + end = constraint.end + dims = constraint.dims_to_flatten + flattened = constraint.flattened + n = len(constraint.dims_to_flatten) + + # this will be evaluated right here + boundary_check = (0 <= start and start < end and end <= n) + + c_boundary = T() if boundary_check else F() + + lhs = dims[0:start] + rhs = dims[end:] + mid = dims[start:end] + + all_possibilities = generate_all_int_dyn_dim_possibilities(mid) + + all_constraints = [] + + for p in all_possibilities: + p = list(p) + # this tells us there is a dynamic variable + contains_dyn = not all(constraint.op == op_neq for constraint in p) + if contains_dyn: + mid_var = [Dyn] + total_constraints = lhs + mid_var + rhs + if len(total_constraints) > 4: + all_constraints.append(F()) + else: + all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq)] + p)) + else: + new_var, counter = gen_dvar(counter) + mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)]) + mid_var = [new_var] + total_constraints = lhs + mid_var + rhs + if len(total_constraints) > 4: + all_constraints.append(F()) + else: + all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p)) + + return Conj([Disj(all_constraints), c_boundary]), counter + + +@register_transformation_rule(CanReshape) +def generate_reshape(constraint, counter): + """ + Transform reshape constraints + """ + d, counter = gen_tensor_dims(4, counter) + + d1 = d[0] + d2 = d[1] + d3 = d[2] + d4 = d[3] + + target = constraint.target.__args__ + + is_fully_static = all(d != Dyn for d in target) + + # dynamic tensor + c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq) + c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq) + c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq) + c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq) + c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq) + + d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq) + d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq) + + d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq) + d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq) + + d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq) + d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq) + + d4_eq_dyn = BinConstraintD(d3, Dyn, op_eq) + d4_neq_dyn = BinConstraintD(d3, Dyn, op_neq) + + nat_d1 = BinConstraintD(0, d1, op_leq) + nat_d2 = BinConstraintD(0, d2, op_leq) + nat_d3 = BinConstraintD(0, d3, op_leq) + nat_d4 = BinConstraintD(0, d4, op_leq) + + if is_fully_static: + # size 1 tensor + c3_tensor1 = Disj([d1_eq_dyn, + (Conj([d1_neq_dyn, + BinConstraintD(d1, Prod(target), op_eq)]))]) + all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) + + # size 2 tensor + all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)]) + + # size 3 tensor + all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)]) + + # size 4 tensor + all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)]) + + return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), + nat_d1, nat_d2, nat_d3, nat_d4]), counter + + # then there must be exactly one occurrence of dyn + else: + new_target = [] + + for n in target: + if n != Dyn: + new_target.append(n) + + # tensor 1 + c3_tensor1 = Disj([d1_eq_dyn, + (Conj([d1_neq_dyn, + is_dim_div_by_target(new_target, d1)]))]) + all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) + + # tensor 2 + c21 = Disj([d1_eq_dyn, d2_eq_dyn]) + c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))]) + all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])]) + + # tensor 3 + c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn]) + c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))]) + all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])]) + + # tensor 4 + c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn]) + c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))]) + all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])]) + + return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), + nat_d1, nat_d2, nat_d3, nat_d4]), counter + + +@register_transformation_rule(ApplyBroadcasting) +def generate_broadcasting(constraint, counter): + """ + Transform broadcasting constraints + """ + e11, e12 = constraint.res1, constraint.res2 + e1, e2 = constraint.input1, constraint.input2 + + e1_dyn = BinConstraintT(e1, Dyn, op_eq) + e2_dyn = BinConstraintT(e2, Dyn, op_eq) + + # Introduce dimensions + e1_equal_e11 = BinConstraintT(e1, e11, op_eq) + e2_equal_e12 = BinConstraintT(e2, e12, op_eq) + + # dyn possibility + e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12]) + e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12]) + + # tensor possibility + # generate dimensions to create tensors of size 1 + final_tensor_1_constraint, _, _, nat_dims_1, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter) + + # generate dimensions to create tensors of size 2 + final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \ + final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter) + + # generate dimensions to create tensors of size 3 + final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \ + final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter) + + # generate dimensions to create tensors of size 4 + final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \ + final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter) + + final_result = Disj([ + e1_dyn_constraint, + e2_dyn_constraint, + final_tensor_1_constraint, + final_tensor_2_constraint_no_padding, + final_tensor_2_constraint_padding_arg1, + final_tensor_2_constraint_padding_arg2, + final_tensor_3_constraint_no_padding, + final_tensor_3_constraint_padding_arg1, + final_tensor_3_constraint_padding_arg2, + final_tensor_4_constraint_no_padding, + final_tensor_4_constraint_padding_arg1, + final_tensor_4_constraint_padding_arg2 + ]) + + return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter + + +def transform_constraint(constraint: Constraint, counter: int): + """ + Transforms a constraint into a simpler constraint. + Ex: precision and consistency are transformed to equality + Args: + constraint: constraint to be transformed + counter: for variable tracking + + Returns: Constraint + + """ + if type(constraint) in _TRANSFORMATION_RULES: + return _TRANSFORMATION_RULES[type(constraint)](constraint, counter) + + else: + return constraint, counter + + + + +def calc_last_two_dims(constraint, d: List[DVar]): + """ + Generates constraints for the last two dimensions of a convolution or a maxpool output + Args: + constraint: CalcConv or CalcMaxPool + d: The list of output dimensions + + Returns: Constraints for calculating the last two dimensions of the output + + """ + + assert isinstance(constraint, (CalcConv, CalcMaxPool)) + + b3 = constraint.matching_constraint[2] + b4 = constraint.matching_constraint[3] + + b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)]) + b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)]) + + d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)]) + d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)]) + + # transform parameters into tuples incase they are not already + padding = (constraint.padding, constraint.padding) \ + if isinstance(constraint.padding, int) else constraint.padding + kernel = (constraint.kernel, constraint.kernel) \ + if isinstance(constraint.kernel, int) else constraint.kernel + stride = (constraint.stride, constraint.stride) \ + if isinstance(constraint.stride, int) else constraint.stride + dilation = (constraint.dilation, constraint.dilation) \ + if isinstance(constraint.dilation, int) else constraint.dilation + + f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add) + f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul) + f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div) + f4 = BinConstraintD(f3, 1, op_add) + + c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])]) + + f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add) + f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul) + f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div) + f44 = BinConstraintD(f33, 1, op_add) + + c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])]) + + return c4, c5 + + +def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]): + """ + Generate all possibilities of being equal or not equal to dyn for my_list + Args: + my_list: List of tensor dimensions + + Returns: A list of a list of constraints. Each list of constraints corresponds to + one possibility about the values of the dimension variables + """ + # generate all possibilities of being equal or not equal to dyn for my_list + eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))] + neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))] + d_possibilities = [] + + for i in zip(eq_possibilities, neq_possibilities): + d_possibilities.append(list(i)) + all_possibilities = list(itertools.product(*d_possibilities)) + return all_possibilities + + +def is_target_div_by_dim(target: List[int], dim: List[DVar]): + """ + Generate constraints to check if the target dimensions are divisible by the input dimensions + Args: + target: Target dimensions + dim: Input dimensions + + Returns: Constraints to check divisibility + + """ + return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq) + + +def is_dim_div_by_target(target: List[int], dim: List[DVar]): + """ + Generate constraints to check if the input dimensions is divisible by the target dimensions + Args: + target: Target dimensions + dim: Input dimensions + + Returns: Constraints to check divisibility + + """ + return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq) + + +def gen_all_reshape_possibilities(list_of_dims, target): + """ + Consider all possibilities what the input dimensions could be (number or dynamic) + Then generate the appropriate constraints using multiplication or mod depending on the possibility + The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn + for the input. Target is fixed because at most one dimension could be dyn. + We have different cases for this. + + Args: + list_of_dims: The input list of dimensions + target: The tensor we want to reshape to + + Returns: A disjunction of transformed reshape constraints + + """ + all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims) + + all_constraints = [] + + for p in all_possibilities: + to_multiply = [] + + p = list(p) + + for constraint in p: + assert isinstance(constraint, BinConstraintD) + if constraint.op == op_neq: + to_multiply.append(constraint.lhs) + + if not to_multiply: + all_constraints.append(Conj(p)) + + elif len(to_multiply) < len(list_of_dims): + all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))])) + else: + all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims), + Prod(target), op_eq)])) + + return Disj(all_constraints) + + +def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False): + """ + Apply broadcasting to the 'index' dimension of tensor_input1. + Args: + tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1 + tensor_input2: represents the second input + res1: broadcasted result 1 + res2: broadcasted result 2 + index: the index to broadcast + padding: If padding was used, then tensor_input1[index] does not exist + + Returns: + + """ + if tensor_input1[index] is None: + assert padding + + + if not padding: + # then the inputs are the same length so they all have dimensions at "index" + return Conj([BinConstraintD(tensor_input1[index], 1, op_eq), + BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq)]) + + else: + # we don't set the input dimension to 1, since it doesn't exist. + return Conj([BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq)]) + + +def apply_padding(e1_var: TVar, + e11: BinConstraintT, + e2: BinConstraintT, + e12: BinConstraintT, + d2: List[DVar], + d11: List[DVar], + d12: List[DVar], + counter: int): + """ + We are considering the possibility where one input has less dimensions than + another input, so we apply padding to the broadcasted results + + Args: + e1_var: Variable representing the first input where padding will be + e11: constraint of the form e11 = Tensortype[d1, ..., dn] + e2: constraint of the form e2 = Tensortype[d1, ..., dn] + e12: constraint of the form e11 = Tensortype[d1, ..., dn] + d2: Tensor variables for the second input + d11: Tensor variables for the broadcasted first input + d12: Tensor variables for the broadcasted second input + counter: variable tracking + + Returns: A new constraint whose goal is to apply padding to the broadcasted result + + """ + + res = [] + + # pad the shorter input with None so we can pass it to the broadcasting helper function + for i in range(1, len(d2)): + + d1, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12) + + e1 = BinConstraintT(e1_var, TensorType(d1), op_eq) + + simulate_padding = [None] * (len(d2) - i) + + assert len(simulate_padding + d1) == len(d2) + + broadcast_padding = [] + + # for every padding size, we also consider broadcasting + for j in range(len(d2) - i): + broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True)) + + # we consider the possibilities for broadcasting for every dimension. Since we already + # padded d1, we do not consider it while broadcasting + all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1, + d2[(len(d2) - i):], + d11[(len(d2) - i):], + d12[(len(d2) - i):]) + # combine all constraints into a conjunction + c = Conj([e1, e11, e2, e12, + *broadcast_padding, + all_broadcasting_possibilities, + *nat_constraints + ]) + res.append(c) + + return Disj(res), counter + + +def no_broadcast_dim_with_index(d1: List[DVar], + d2: List[DVar], + d3: List[DVar], + d4: List[DVar], + i: int): + """ + Args: + d1: input 1 + d2: input 2 + d3: simulated broadcasting for input 1 + d4: simulated broadcasting for input 2 + i: the rank of the resulting tensor addition + + Returns: Constraints for when no broadcasting occurs + """ + return Conj([ + Disj([ + Conj([BinConstraintD(d1[i], 1, op_eq), + BinConstraintD(d2[i], 1, op_eq)]), + + Conj([BinConstraintD(d1[i], 1, op_neq), + BinConstraintD(d2[i], 1, op_neq)])]), + + BinConstraintD(d1[i], d3[i], op_eq), + BinConstraintD(d2[i], d4[i], op_eq)]) + + + +def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): + """ + Generate lists of DVar to represent tensor dimensions + Args: + num_tensors: the required number of tensors + dim_size: the number of dimensions for each tensor + counter: variable tracking + + Returns: A list of a list of tensor dimensions + + """ + res = [] + + for _ in range(num_tensors): + dims, counter = gen_tensor_dims(dim_size, counter) + res.append(dims) + + return res, counter + + +def create_equality_constraints_for_broadcasting(e1: TVar, + e2: TVar, + e11: TVar, + e12: TVar, + d1: List[DVar], + d2: List[DVar], + d11: List[DVar], + d12: List[DVar]): + """ + Create equality constraints for when no broadcasting occurs + Args: + e1: Input 1 + e2: Input 2 + e11: Broadcasted input 1 + e12: Broadcasted input 2 + d1: Variables that store dimensions for e1 + d2: Variables that store dimensions for e2 + d11: Variables that store dimensions for e11 + d12: Variables that store dimensions for e22 + + Returns: Four equality constraints + + """ + + e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq) + e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq) + e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq) + e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq) + return [e1_tensor, e11_tensor, e2_tensor, e12_tensor] + + +def gen_consistency_constraints(constraint: Constraint, counter: int): + """ + Args: + constraint: Consistency constraint on tensors + counter: for variable tracking + + Returns: Equality and consistency constraints on dimensions + + """ + + all_constraints = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + new_dims_rhs_2, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) + + c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] + + [BinConstraintD(d1, d2, op_consistency) for + d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints) + + all_constraints.append(c_tensor_i) + + return all_constraints, counter + + +def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): + """ + Args: + constraint: Greatest upper bound on tensors + counter: variable tracking + + Returns: A set of equality constraints and DGreatestUpperBound constraints + + """ + + all_constraints = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + c = [] + dims1, counter = gen_tensor_dims(i, counter) + c1tensor = TensorType(dims1) + + dims2, counter = gen_tensor_dims(i, counter) + c2tensor = TensorType(dims2) + + dims3, counter = gen_tensor_dims(i, counter) + c3tensor = TensorType(dims3) + + c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq), + BinConstraintT(constraint.rhs2, c2tensor, op_eq), + BinConstraintT(constraint.res, c3tensor, op_eq)] + \ + gen_nat_constraints(dims1 + dims2 + dims3) + + assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__) + for i in range(len(c3tensor.__args__)): + c.append(DGreatestUpperBound(c3tensor.__args__[i], + c1tensor.__args__[i], + c2tensor.__args__[i])) + + all_constraints.append(Conj(c)) + return all_constraints, counter + + +def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]): + """ + Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension. + We look at all combinations for all dimensions in d1 and d2 + Args: + d1: input1 dimensions + d2: input2 dimensions + d11: broadcasted input1 dimensions + d12: broadcasted input2 dimensions + + Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions + + """ + + size = len(d1) + + res2 = [] + + for i in range(size): + t1 = broadcast_dim(d1, d2, d11, d12, i) + t2 = broadcast_dim(d2, d1, d12, d11, i) + t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i) + + res2.append(Disj([t1, t2, t3])) + + return Conj(res2) + + +def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int): + """ + Simulates broadcasting on e1 and e2 and returns the results + respectively in e11 and e12. Because of gradual types, + e1 and e2 may not be equal. Similarly, e11 and e12 may not + be equal. e11 and e12 should be guaranteed to be consistent + as they represent the shapes of the tensors to be added after + broadcasting. + Args: + e1: TVar representing the type of input 1 + e2: TVar representing the type of input 2 + e11: TVar representing the representing broadcasted input 1 + e12: TVar representing the representing broadcasted input 2 + i: The rank of the resulting type of addition + counter: for variable tracking + + Returns: Simplified broadcasting constraints + + """ + dims, counter = gen_lists_of_dims(4, i, counter) + [d1, d2, d3, d4] = dims + nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims))) + + initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12, + d1, d2, d3, d4) + + [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints + + # without padding, broadcast all possibilities for tensors of size i + final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints, + generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)]) + + # with padding, broadcast all possibilities for tensors of size i + final_tensor_constraint_padding_arg1, counter = \ + apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter) + + final_tensor_constraint_padding_arg2, counter = \ + apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter) + + return final_tensor_constraint_no_padding, \ + final_tensor_constraint_padding_arg1, \ + final_tensor_constraint_padding_arg2, nat_dims_i, counter diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py new file mode 100644 index 0000000000000000000000000000000000000000..15af0241ec5b083d5e61847b611f1d5c66c3e02d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py @@ -0,0 +1,348 @@ +from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr +from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar +from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator +from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint +from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt +from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod +from torch.fx.tensor_type import TensorType, Dyn + +try: + import z3 # type: ignore[import] + from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D + HAS_Z3 = True + + def transform_to_z3(constraint, counter, dimension_dict): + if isinstance(constraint, Conj): + conjuncts = [] + for c in constraint.conjucts: + new_c, counter = transform_to_z3(c, counter, dimension_dict) + conjuncts.append(new_c) + return z3.And(conjuncts), counter + + elif isinstance(constraint, Disj): + disjuncts = [] + for c in constraint.disjuncts: + new_c, counter = transform_to_z3(c, counter, dimension_dict) + disjuncts.append(new_c) + return z3.Or(disjuncts), counter + + elif isinstance(constraint, T): + return True, counter + + elif isinstance(constraint, F): + return False, counter + + elif isinstance(constraint, BinConstraintT): + if constraint.op == op_eq: + lhs, counter = transform_var(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_var(constraint.rhs, counter, dimension_dict) + return (lhs == rhs), counter + + else: + raise NotImplementedError('Method not yet implemented') + + elif isinstance(constraint, BinConstraintD): + if constraint.op == op_eq: + + if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs): + transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict) + transformed_lhs = z3.Bool(constraint.lhs.c) + return transformed_lhs == transformed_rhs, counter + + elif is_dim(constraint.lhs) and is_dim(constraint.rhs): + # with dimension transformations we consider the encoding + lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) + return lhs == rhs, counter + + else: + # then we have an algebraic expression which means that we disregard the + # first element of the encoding + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs == rhs, counter + + # The assumption here is that the LHS and RHS must be dimensions + elif constraint.op == op_neq: + assert is_dim(constraint.lhs) + assert is_dim(constraint.rhs) + lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) + if constraint.rhs == Dyn or constraint.lhs == Dyn: + if constraint.rhs == Dyn: + return lhs.arg(0) == 1, counter + elif constraint.lhs == Dyn: + return rhs.arg(0) == 1, counter + + # if one of the instances is a number + elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int): + if isinstance(constraint.lhs, int): + return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter + + elif isinstance(constraint.rhs, int): + return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter + + else: + return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]), + z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]), + z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter + + + elif constraint.op == op_leq: + # if the dimensions are not dyn, this will come into effect + # there would have been another constraint specifying if a given dimension + # is dyn or not + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs <= rhs, counter + + elif constraint.op == op_gt: + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs > rhs, counter + + elif constraint.op == op_lt: + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs < rhs, counter + + else: + raise NotImplementedError('operation not yet implemented') + + else: + raise NotImplementedError('Operation not yet implemented') + + + def transform_var(tensor, counter, dimension_dict): + """ + Transforms tensor variables to a format understood by z3 + Args: + tensor: Tensor variable or a tensor type potentially with variable dimensions + Returns: Transformed variable to a z3 format + + """ + if isinstance(tensor, TensorType): + res = [] + for t in tensor.__args__: + transformed, counter = transform_dimension(t, counter, dimension_dict) + res.append(transformed) + + assert len(res) <= 4 + if len(tensor.__args__) == 1: + return tensor_type.tensor1(res[0]), counter + elif len(tensor.__args__) == 2: + return tensor_type.tensor2(res[0], res[1]), counter + elif len(tensor.__args__) == 3: + return tensor_type.tensor3(res[0], res[1], res[2]), counter + elif len(tensor.__args__) == 4: + return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter + + elif tensor == Dyn: + return z3_dyn, counter + + elif isinstance(tensor, TVar): + return z3.Const(tensor.tvar, tensor_type), counter + + def transform_dimension(dimension, counter, dimension_dict): + """ + Takes a dimension variable or a number and transforms it to a tuple + according to our scheme + Args: + dimension: The dimension to be transformed + counter: variable tracking + + Returns: tuple and the current counter + + """ + if dimension == Dyn: + counter += 1 + return D(0, z3.Int(counter)), counter + elif isinstance(dimension, int): + return D(1, dimension), counter + elif isinstance(dimension, DVar): + if dimension.c in dimension_dict: + return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter + else: + counter += 1 + dimension_dict[dimension.c] = counter + return D(z3.Int(counter), z3.Int(dimension.c)), counter + + + def transform_algebraic_expression(expr, counter, dimension_dict): + """ + Transforms an algebraic expression to z3 format + Args: + expr: An expression is either a dimension variable or an algebraic-expression + + + Returns: the transformed expression + + """ + assert is_algebraic_expression(expr) or is_dim(expr) + + if is_dim(expr): + transformed, counter = transform_dimension(expr, counter, dimension_dict) + return transformed.arg(1), counter + + elif isinstance(expr, Prod): + + dims = [] + for dim in expr.products: + assert is_dim(dim) + d, counter = transform_dimension(dim, counter, dimension_dict) + dims.append(d.arg(1)) + return z3.Product(dims), counter + + elif is_algebraic_expression(expr): + + lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict) + + if expr.op == op_sub: + c = lhs - rhs + + elif expr.op == op_add: + c = lhs + rhs + + elif expr.op == op_div: + c = lhs / rhs + + elif expr.op == op_mul: + c = lhs * rhs + + elif expr.op == op_mod: + c = lhs % rhs + + else: + raise NotImplementedError('operation not yet implemented') + + return c, counter + + else: + raise RuntimeError + + + def transform_all_constraints(traced, counter=0): + """ + Given a trace, generates constraints and transforms them to z3 format + + """ + dimension_dict = {} # type: ignore[var-annotated] + + generator = ConstraintGenerator(traced) + new_constraints, counter = generator.generate_constraints(counter) + + # print(new_constraints.conjucts[0]) + # print(*new_constraints.conjucts, sep='\n') + + # transform precision, matching, consistency till obtaining a fixed point + new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) + # print(new_constraints) + # print(new_constraints.conjucts) + # new_constraints.conjucts = new_constraints.conjucts[:-1] + # print(*new_constraints.conjucts, sep='\n') + + transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) + # print(transformed) + return transformed + + def iterate_till_fixed_point(constraints, counter): + """ + Transform constraints till reaching a fixed point + """ + old_c = None + while old_c != constraints: + old_c = constraints + constraints, counter = transform_constraint(constraints, counter) + return constraints, counter + + def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): + """ + Takes a node and a graph and generates two sets of constraints. + One set constraints the node's constraints and another set + constraints the negation of the node's constraints + Args: + tracer_root: the root for getting the module instances + graph: the graph so far in the tracing process + node: node that represents a conditional + counter: variable tracking + + Returns: Two sets of constraints. One with a conjunction with the + the conditional constraint and the other with a conjunction with + its negation. + + """ + dimension_dict = {} # type: ignore[var-annotated] + + generator = ConstraintGenerator(tracer_root, graph) + new_constraints, counter = generator.generate_constraints(counter) + + condition_constraint = new_constraints.conjucts[-1] + + # we know the constraint is a conjunction where the last constraint is about the conditional + # so remove the last constraint + new_constraints.conjucts = new_constraints.conjucts[:-1] + + # transform precision, matching, consistency till obtaining a fixed point + new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) + + + # since the function returns a list of one element, we get the first element + # we are only interested in the RHS in this case because the LHS just stores + # the result + + # we make sure the constraint is of the form: + # c = b where b is a boolean expression + # and we consider b (constraint.rhs) for transformation + assert isinstance(condition_constraint.lhs, BVar) + assert is_bool_expr(condition_constraint.rhs) + condition_constraint_rhs = condition_constraint.rhs + + # transform the condition constraint + condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter) + + transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) + + transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict) + + negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint) + + return z3.And([transformed, transformed_condition_constraint]), \ + z3.And([transformed, negation_transformed_condition_constraint]) + + + def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None): + """ + Given an IR and a node representing a conditional, evaluate the conditional + and its negation + Args: + tracer_root: Tracer root for module instances + node: The node to be evaluated + + Returns: the results of evaluating the condition and the negation with + the rest of the constraints + + """ + + transformed_positive, transformed_negative = \ + transform_all_constraints_trace_time(tracer_root, graph, node, counter) + + s = z3.Solver() + s.add(transformed_positive) + if user_constraints is not None: + s.add(user_constraints) + condition = s.check() + + s = z3.Solver() + s.add(transformed_negative) + if user_constraints is not None: + s.add(user_constraints) + negation = s.check() + return condition, negation + +except ImportError: + HAS_Z3 = False diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f7f13991ea217f4fa3d2208bfe99421e0224397 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..509cbf0951be1bce6dc7b26340e89f169f0c50df Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3e82cba7bab1be29c2432ea151f69b09c6f8f52 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd00b6665f8d69b702c68a1e5ac1affe40505a70 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4b5ec2ed63152e240ccb94935c96b25ad8b66093 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py @@ -0,0 +1,125 @@ +from collections import OrderedDict + +__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] + +def raises(err, lamda): + try: + lamda() + return False + except err: + return True + + +def expand_tuples(L): + """ + >>> expand_tuples([1, (2, 3)]) + [(1, 2), (1, 3)] + >>> expand_tuples([1, 2]) + [(1, 2)] + """ + if not L: + return [()] + elif not isinstance(L[0], tuple): + rest = expand_tuples(L[1:]) + return [(L[0],) + t for t in rest] + else: + rest = expand_tuples(L[1:]) + return [(item,) + t for t in rest for item in L[0]] + + +# Taken from theano/theano/gof/sched.py +# Avoids licensing issues because this was written by Matthew Rocklin +def _toposort(edges): + """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) + inputs: + edges - a dict of the form {a: {b, c}} where b and c depend on a + outputs: + L - an ordered list of nodes that satisfy the dependencies of edges + >>> _toposort({1: (2, 3), 2: (3, )}) + [1, 2, 3] + >>> # Closely follows the wikipedia page [2] + >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", + >>> # Communications of the ACM + >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms + """ + incoming_edges = reverse_dict(edges) + incoming_edges = OrderedDict((k, set(val)) + for k, val in incoming_edges.items()) + S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) + L = [] + + while S: + n, _ = S.popitem() + L.append(n) + for m in edges.get(n, ()): + assert n in incoming_edges[m] + incoming_edges[m].remove(n) + if not incoming_edges[m]: + S[m] = None + if any(incoming_edges.get(v, None) for v in edges): + raise ValueError("Input has cycles") + return L + + +def reverse_dict(d): + """Reverses direction of dependence dict + >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} + >>> reverse_dict(d) # doctest: +SKIP + {1: ('a',), 2: ('a', 'b'), 3: ('b',)} + :note: dict order are not deterministic. As we iterate on the + input dict, it make the output of this function depend on the + dict order. So this function output order should be considered + as undeterministic. + """ + result = OrderedDict() # type: ignore[var-annotated] + for key in d: + for val in d[key]: + result[val] = result.get(val, tuple()) + (key, ) + return result + + +# Taken from toolz +# Avoids licensing issues because this version was authored by Matthew Rocklin +def groupby(func, seq): + """ Group a collection by a key function + >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] + >>> groupby(len, names) # doctest: +SKIP + {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} + >>> iseven = lambda x: x % 2 == 0 + >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP + {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} + See Also: + ``countby`` + """ + + d = OrderedDict() # type: ignore[var-annotated] + for item in seq: + key = func(item) + if key not in d: + d[key] = list() + d[key].append(item) + return d + + +def typename(type): + """Get the name of `type`. + Parameters + ---------- + type : Union[Type, Tuple[Type]] + Returns + ------- + str + The name of `type` or a tuple of the names of the types in `type`. + Examples + -------- + >>> typename(int) + 'int' + >>> typename((int, float)) + '(int, float)' + """ + try: + return type.__name__ + except AttributeError: + if len(type) == 1: + return typename(*type) + return f"({', '.join(map(typename, type))})" diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d46b7043dedda733833089640f8857cde8f1547c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73c71781f6700ad5e98594e15b9d21052c31269c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..0399cef526205f8f82a0c53555bc16fdab67a550 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py @@ -0,0 +1,44 @@ +import operator + +import torch + + +def annotate_getitem_nodes(graph: torch.fx.Graph) -> None: + """ + Annotate the type of getitem nodes, inferred from the type of sequence node. + If sequence node is not annotated with a type, do nothing. + Currently support getitem nodes from Tuple, List, and NamedTuple sequence node. + + This is helpful since annotations on local names within function are lost during FX transforms. + Adding back known type annotation for getitem nodes to improve jit scriptability. + + Args: + graph (Graph): The graph to be annotated + """ + for node in graph.nodes: + if node.target == operator.getitem: + sequence_node, index_node = node.args + if not sequence_node.type: + continue + # container types + if hasattr(sequence_node.type, "_name"): + parameterized_types = sequence_node.type.__args__ + if sequence_node.type._name == "Tuple": + if len(parameterized_types) == 2 and isinstance( + parameterized_types[1], type(...) + ): + node.type = parameterized_types[0] + else: + assert len(parameterized_types) > index_node + node_type = parameterized_types[index_node] + node.type = node_type + elif sequence_node.type._name == "List": + assert len(parameterized_types) == 1 + node.type = parameterized_types[0] + # NamedTuple type + elif hasattr(sequence_node.type, "__annotations__"): + if sequence_node.type == torch.Tensor: + continue + sequence_node_field_types = sequence_node.type.__annotations__ + field_name = sequence_node.type._fields[index_node] + node.type = sequence_node_field_types[field_name] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a18b8bfd1e2303863bb50a9ef04b688050cb7961 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/cse_pass.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/cse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..dc95a70a22a7da599880d962b40c6a0a25aa5634 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/cse_pass.py @@ -0,0 +1,112 @@ +from typing import Dict, Tuple, Any + +import torch +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.utils._pytree import tree_flatten + +from torch.fx import GraphModule, Graph +from torch.fx import Node + +aten = torch.ops.aten + + +# stateful ops are banned from CSE +rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501,B950 + +inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501 + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def get_CSE_banned_ops(): + return rand_ops.union(inplace_ops) + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +class CSEPass(PassBase): + + def __init__(self, banned_ops=None): + """ + This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node. + + For functional dialects, user would only need to specify the random ops in ban list. + + Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects. + If your dialect contains stateful operators, please customized the banned_ops. + + """ + if banned_ops is None: + banned_ops = set() + self.banned_ops = banned_ops + super().__init__() + + def call(self, graph_module: GraphModule) -> PassResult: + """ + Return a new copy of torch.fx.GraphModule with CSE applied to the input graph + + Example usage: + + from torch.fx.experimental.proxy_tensor import make_fx + def f(a): + b = a * a + c = a * a + return b+c + + p = CSEPass() + traced_graph = make_fx(f)(torch.tensor(1)) + print(traced_graph) + result = p(traced_graph) + print(result.graph_module) + """ + def get_aten_target(node): + if hasattr(node.target, 'overloadpacket'): + return node.target.overloadpacket + return node.target + + modified = False + new_graph = Graph() + env: Dict[Node, Node] = {} # map from node in the old graph to node in the new graph + hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {} # map from hash to a node in the new graph + token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {} # map from hash to token + for n in graph_module.graph.nodes: + # The placeholder, output, and get_attr nodes are copied to the new graph without change + # do not CSE away random operations + if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops: + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' + # substitute args and kwargs members to their mapping in env if exists + # specs can be used to reconstruct nested list/dictionaries + def substitute(arg_list): + arg_list, spec = tree_flatten(arg_list) + for i in range(len(arg_list)): + v = arg_list[i] + if isinstance(v, Node) and v in env: + arg_list[i] = env[v] + return tuple(arg_list), spec + args, args_spec = substitute(n.args) + kwargs, kwargs_spec = substitute(n.kwargs) + + # each token corresponds to a unique node + # nodes with the same token can be substituted + token = {"target": n.target, "args": args, "args_spec": args_spec, + "kwargs": kwargs, "kwargs_spec": kwargs_spec} + + # hash substituted args to a number, do not hash specs because specs are not hashable + hash_arg = hash((args, kwargs)) + hash_val = (n.target, hash_arg) + + # check if a node has a substitute and can be eliminated + hash_val_in_hash_env = hash_val in hash_env + if hash_val_in_hash_env and token_map[hash_val] == token: + modified = True # substitution happens and the graph is modified + env[n] = hash_env[hash_val] + continue + + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + if not hash_val_in_hash_env: + hash_env[hash_val] = new_node + token_map[hash_val] = token + + csed_gm = GraphModule(graph_module, new_graph) + return PassResult(csed_gm, modified) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..a31953ca6e7917755f12af07784777c81438fdb6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py @@ -0,0 +1,73 @@ +from typing import Optional + +import torch.fx +from torch.fx import Node +from torch.fx._compatibility import compatibility +from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor +from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake +from torch.fx.node import map_aggregate + +__all__ = ['FakeTensorProp'] + +@compatibility(is_backward_compatible=False) +class FakeTensorProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node and record a fake tensor representing + the metadata for the node. Unlike ShapeProp, (1) this propagation + is cheap--it does the propagation with meta tensors which do not actually + store data, and (2) the fake tensors have much more fine grained information, + e.g., they have accurate alias information that can be consulted by looking + at the storages. + + Args: + module (GraphModule): The module to be executed + mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node. + """ + def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None): + super().__init__(module) + if mode is None: + mode = FakeTensorMode() + self._mode = mode + + def run_node(self, n: Node): + import sympy + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + result = super().run_node(n) + sym = None + if ( + 'val' in n.meta and + isinstance(v := n.meta['val'], torch.SymInt) and + isinstance(v.node.expr, sympy.Symbol) and free_unbacked_symbols(v) + ): + sym = v + + def extract_val(obj): + if isinstance(obj, FakeTensor): + return snapshot_fake(obj) + elif isinstance(obj, torch.Tensor): + # TODO: How is it possible that we get a non fake tensor? We + # should be running under the mode... + return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True)) + elif isinstance(obj, py_sym_types): + return obj + else: + return None + + meta = map_aggregate(result, extract_val) + if meta is not None: + n.meta['val'] = meta + if sym is not None: + torch._check(meta == v) + return result + + def propagate(self, *args): + fake_args = [ + self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a + for a in args + ] + return self.propagate_dont_convert_inputs(*fake_args) + + def propagate_dont_convert_inputs(self, *args): + with self._mode: + return super().run(*args) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_drawer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_drawer.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfb430245257391718d6ec22bfa858655d965fd --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_drawer.py @@ -0,0 +1,421 @@ + +import hashlib +import torch +import torch.fx +from typing import Any, Dict, Optional, TYPE_CHECKING +from torch.fx.node import _get_qualified_name, _format_arg +from torch.fx.graph import _parse_stack_trace +from torch.fx.passes.shape_prop import TensorMetadata +from torch.fx._compatibility import compatibility +from itertools import chain + +__all__ = ['FxGraphDrawer'] +try: + import pydot + HAS_PYDOT = True +except ImportError: + HAS_PYDOT = False + +_COLOR_MAP = { + "placeholder": '"AliceBlue"', + "call_module": "LemonChiffon1", + "get_param": "Yellow2", + "get_attr": "LightGrey", + "output": "PowderBlue", +} + +_HASH_COLOR_MAP = [ + "CadetBlue1", + "Coral", + "DarkOliveGreen1", + "DarkSeaGreen1", + "GhostWhite", + "Khaki1", + "LavenderBlush1", + "LightSkyBlue", + "MistyRose1", + "MistyRose2", + "PaleTurquoise2", + "PeachPuff1", + "Salmon", + "Thistle1", + "Thistle3", + "Wheat1", +] + +_WEIGHT_TEMPLATE = { + "fillcolor": "Salmon", + "style": '"filled,rounded"', + "fontcolor": "#000000", +} + +if HAS_PYDOT: + @compatibility(is_backward_compatible=False) + class FxGraphDrawer: + """ + Visualize a torch.fx.Graph with graphviz + Basic usage: + g = FxGraphDrawer(symbolic_traced, "resnet18") + g.get_dot_graph().write_svg("a.svg") + """ + + def __init__( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool = False, + ignore_parameters_and_buffers: bool = False, + skip_node_names_in_args: bool = True, + parse_stack_trace: bool = False, + dot_graph_shape: Optional[str] = None, + ): + self._name = name + self.dot_graph_shape = ( + dot_graph_shape if dot_graph_shape is not None else "record" + ) + _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape + + self._dot_graphs = { + name: self._to_dot( + graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace + ) + } + + for node in graph_module.graph.nodes: + if node.op != "call_module": + continue + + leaf_node = self._get_leaf_node(graph_module, node) + + if not isinstance(leaf_node, torch.fx.GraphModule): + continue + + + self._dot_graphs[f"{name}_{node.target}"] = self._to_dot( + leaf_node, + f"{name}_{node.target}", + ignore_getattr, + ignore_parameters_and_buffers, + skip_node_names_in_args, + parse_stack_trace, + ) + + def get_dot_graph(self, submod_name=None) -> pydot.Dot: + """ + Visualize a torch.fx.Graph with graphviz + Example: + >>> # xdoctest: +REQUIRES(module:pydot) + >>> # define module + >>> class MyModule(torch.nn.Module): + >>> def __init__(self): + >>> super().__init__() + >>> self.linear = torch.nn.Linear(4, 5) + >>> def forward(self, x): + >>> return self.linear(x).clamp(min=0.0, max=1.0) + >>> module = MyModule() + >>> # trace the module + >>> symbolic_traced = torch.fx.symbolic_trace(module) + >>> # setup output file + >>> import ubelt as ub + >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir() + >>> fpath = dpath / 'linear.svg' + >>> # draw the graph + >>> g = FxGraphDrawer(symbolic_traced, "linear") + >>> g.get_dot_graph().write_svg(fpath) + """ + if submod_name is None: + return self.get_main_dot_graph() + else: + return self.get_submod_dot_graph(submod_name) + + def get_main_dot_graph(self) -> pydot.Dot: + return self._dot_graphs[self._name] + + def get_submod_dot_graph(self, submod_name) -> pydot.Dot: + return self._dot_graphs[f"{self._name}_{submod_name}"] + + def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]: + return self._dot_graphs + + def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]: + + template = { + "shape": self.dot_graph_shape, + "fillcolor": "#CAFFE3", + "style": '"filled,rounded"', + "fontcolor": "#000000", + } + if node.op in _COLOR_MAP: + template["fillcolor"] = _COLOR_MAP[node.op] + else: + # Use a random color for each node; based on its name so it's stable. + target_name = node._pretty_print_target(node.target) + target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16) + template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)] + return template + + def _get_leaf_node( + self, module: torch.nn.Module, node: torch.fx.Node + ) -> torch.nn.Module: + py_obj = module + assert isinstance(node.target, str) + atoms = node.target.split(".") + for atom in atoms: + if not hasattr(py_obj, atom): + raise RuntimeError( + str(py_obj) + " does not have attribute " + atom + "!" + ) + py_obj = getattr(py_obj, atom) + return py_obj + + def _typename(self, target: Any) -> str: + if isinstance(target, torch.nn.Module): + ret = torch.typename(target) + elif isinstance(target, str): + ret = target + else: + ret = _get_qualified_name(target) + + # Escape "{" and "}" to prevent dot files like: + # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc + # which triggers `Error: bad label format (...)` from dot + return ret.replace("{", r"\{").replace("}", r"\}") + + # shorten path to avoid drawing long boxes + # for full path = '/home/weif/pytorch/test.py' + # return short path = 'pytorch/test.py' + def _shorten_file_name( + self, + full_file_name: str, + truncate_to_last_n: int = 2, + ): + splits = full_file_name.split('/') + if len(splits) >= truncate_to_last_n: + return '/'.join(splits[-truncate_to_last_n:]) + return full_file_name + + + def _get_node_label( + self, + module: torch.fx.GraphModule, + node: torch.fx.Node, + skip_node_names_in_args: bool, + parse_stack_trace: bool, + ) -> str: + def _get_str_for_args_kwargs(arg): + if isinstance(arg, tuple): + prefix, suffix = r"|args=(\l", r",\n)\l" + arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg] + elif isinstance(arg, dict): + prefix, suffix = r"|kwargs={\l", r",\n}\l" + arg_strs_list = [ + f"{k}: {_format_arg(v, max_list_len=8)}" + for k, v in arg.items() + ] + else: # Fall back to nothing in unexpected case. + return "" + + # Strip out node names if requested. + if skip_node_names_in_args: + arg_strs_list = [a for a in arg_strs_list if "%" not in a] + if len(arg_strs_list) == 0: + return "" + arg_strs = prefix + r",\n".join(arg_strs_list) + suffix + if len(arg_strs_list) == 1: + arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "") + return arg_strs.replace("{", r"\{").replace("}", r"\}") + + + label = "{" + f"name=%{node.name}|op_code={node.op}\n" + + if node.op == "call_module": + leaf_module = self._get_leaf_node(module, node) + label += r"\n" + self._typename(leaf_module) + r"\n|" + extra = "" + if hasattr(leaf_module, "__constants__"): + extra = r"\n".join( + [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr] + ) + label += extra + r"\n" + else: + label += f"|target={self._typename(node.target)}" + r"\n" + if len(node.args) > 0: + label += _get_str_for_args_kwargs(node.args) + if len(node.kwargs) > 0: + label += _get_str_for_args_kwargs(node.kwargs) + label += f"|num_users={len(node.users)}" + r"\n" + + tensor_meta = node.meta.get('tensor_meta') + label += self._tensor_meta_to_label(tensor_meta) + + # for original fx graph + # print buf=buf0, n_origin=6 + buf_meta = node.meta.get('buf_meta', None) + if buf_meta is not None: + label += f"|buf={buf_meta.name}" + r"\n" + label += f"|n_origin={buf_meta.n_origin}" + r"\n" + + # for original fx graph + # print file:lineno code + if parse_stack_trace and node.stack_trace is not None: + parsed_stack_trace = _parse_stack_trace(node.stack_trace) + fname = self._shorten_file_name(parsed_stack_trace.file) + label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n" + + + return label + "}" + + def _tensor_meta_to_label(self, tm) -> str: + if tm is None: + return "" + elif isinstance(tm, TensorMetadata): + return self._stringify_tensor_meta(tm) + elif isinstance(tm, list): + result = "" + for item in tm: + result += self._tensor_meta_to_label(item) + return result + elif isinstance(tm, dict): + result = "" + for v in tm.values(): + result += self._tensor_meta_to_label(v) + return result + elif isinstance(tm, tuple): + result = "" + for item in tm: + result += self._tensor_meta_to_label(item) + return result + else: + raise RuntimeError(f"Unsupported tensor meta type {type(tm)}") + + def _stringify_tensor_meta(self, tm: TensorMetadata) -> str: + result = "" + if not hasattr(tm, "dtype"): + print("tm", tm) + result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n" + result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n" + result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n" + result += "|" + "stride" + "=" + str(tm.stride) + r"\n" + if tm.is_quantized: + assert tm.qparams is not None + assert "qscheme" in tm.qparams + qscheme = tm.qparams["qscheme"] + if qscheme in { + torch.per_tensor_affine, + torch.per_tensor_symmetric, + }: + result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" + result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" + elif qscheme in { + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, + }: + result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n" + result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" + result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n" + else: + raise RuntimeError(f"Unsupported qscheme: {qscheme}") + result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" + return result + + def _get_tensor_label(self, t: torch.Tensor) -> str: + return str(t.dtype) + str(list(t.shape)) + r"\n" + + # when parse_stack_trace=True + # print file:lineno code + def _to_dot( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool, + ignore_parameters_and_buffers: bool, + skip_node_names_in_args: bool, + parse_stack_trace: bool, + ) -> pydot.Dot: + """ + Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph. + If ignore_parameters_and_buffers is True, the parameters and buffers + created with the module will not be added as nodes and edges. + """ + + # "TB" means top-to-bottom rank direction in layout + dot_graph = pydot.Dot(name, rankdir="TB") + + + buf_name_to_subgraph = {} + + for node in graph_module.graph.nodes: + if ignore_getattr and node.op == "get_attr": + continue + + style = self._get_node_style(node) + dot_node = pydot.Node( + node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style + ) + + current_graph = dot_graph + + buf_meta = node.meta.get('buf_meta', None) + if buf_meta is not None and buf_meta.n_origin > 1: + buf_name = buf_meta.name + if buf_name not in buf_name_to_subgraph: + buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name) + current_graph = buf_name_to_subgraph.get(buf_name) + + current_graph.add_node(dot_node) + + def get_module_params_or_buffers(): + for pname, ptensor in chain( + leaf_module.named_parameters(), leaf_module.named_buffers() + ): + pname1 = node.name + "." + pname + label1 = ( + pname1 + "|op_code=get_" + "parameter" + if isinstance(ptensor, torch.nn.Parameter) + else "buffer" + r"\l" + ) + dot_w_node = pydot.Node( + pname1, + label="{" + label1 + self._get_tensor_label(ptensor) + "}", + **_WEIGHT_TEMPLATE, + ) + dot_graph.add_node(dot_w_node) + dot_graph.add_edge(pydot.Edge(pname1, node.name)) + + if node.op == "call_module": + leaf_module = self._get_leaf_node(graph_module, node) + + if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule): + get_module_params_or_buffers() + + for subgraph in buf_name_to_subgraph.values(): + subgraph.set('color', 'royalblue') + subgraph.set('penwidth', '2') + dot_graph.add_subgraph(subgraph) + + for node in graph_module.graph.nodes: + if ignore_getattr and node.op == "get_attr": + continue + + for user in node.users: + dot_graph.add_edge(pydot.Edge(node.name, user.name)) + + return dot_graph + +else: + if not TYPE_CHECKING: + @compatibility(is_backward_compatible=False) + class FxGraphDrawer: + def __init__( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool = False, + ignore_parameters_and_buffers: bool = False, + skip_node_names_in_args: bool = True, + parse_stack_trace: bool = False, + dot_graph_shape: Optional[str] = None, + ): + raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install ' + 'pydot through your favorite Python package manager.') diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e4487af5dfe8703f7249a0327ed7445a948d3e4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..a8a861be0f43531819fcc17ecdf817bce8194950 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py @@ -0,0 +1,329 @@ +from torch.fx.passes.utils.fuser_utils import fuse_by_partitions +import collections +import itertools +import logging + +from copy import copy +from typing import Dict, Iterable, List, Optional, Sequence, Set + +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node, _get_qualified_name +from torch.fx.passes.operator_support import OperatorSupportBase + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +class Partition: + def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None): + self.id = id + self.nodes: Set[Node] = set(nodes) if nodes is not None else set() + + def __repr__(self) -> str: + return str(self.nodes) + + def add_node(self, node: Node): + self.nodes.add(node) + + def remove_node(self, node: Node): + self.nodes.remove(node) + + def size(self): + return len(self.nodes) + +class _DependencyViewer: + def __init__(self, graph_module: GraphModule): + self.upstreams = collections.defaultdict(set) + self.downstreams = collections.defaultdict(set) + + for node in graph_module.graph.nodes: + for input_node in node.all_input_nodes: + # add input_node and input_node's upstream dependency + self.upstreams[node].add(input_node) + self.upstreams[node].update(self.upstreams[input_node]) + + for node in reversed(graph_module.graph.nodes): + for output_node in node.users: + # add output_node and output_node's downstream dependency + self.downstreams[node].add(output_node) + self.downstreams[node].update(self.downstreams[output_node]) + + def downstreams_of(self, node: Node) -> Set[Node]: + return self.downstreams[node] + + def upstreams_of(self, node: Node) -> Set[Node]: + return self.upstreams[node] + +class CapabilityBasedPartitioner: + + def __init__(self, + graph_module: GraphModule, + operator_support: OperatorSupportBase, + allows_single_node_partition: bool = False, + non_compute_ops: Optional[Sequence[str]] = None, + allowed_single_node_partition_ops: Optional[Sequence[str]] = None, + ) -> None: + self.graph_module = graph_module + self.operator_support = operator_support + self.allows_single_node_partition = allows_single_node_partition + self.non_compute_ops = non_compute_ops if non_compute_ops is not None else [] + self.allowed_single_node_partition_ops = ( + allowed_single_node_partition_ops + if allowed_single_node_partition_ops is not None + else [] + ) + self.dependency_viewer = _DependencyViewer(graph_module) + + def __is_node_supported(self, node: Node) -> bool: + return ( + self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node) + ) + + def propose_partitions(self) -> List[Partition]: + # partition_map is a mapping from partition id to a set of partition id's. + # The value set contains all the partition ids that can be reached by doing a + # DFS starting from the partition id in the key. + partition_map : Dict[int, Set] = collections.defaultdict(set) + + # assumptions: nodes in candidate list is sorted in topological order + assignment: Dict[Node, int] = {} # mapping from node to partition_id + partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition + new_partition_id = itertools.count() + + # try to merge partition other_id into partition self_id + # merge only happens if the end graph doesn't contain cyclic dependency + # returns `True` when merge happens, `False` otherwise. + def maybe_merge_partition(self_id: int, other_id: int): + # merged_nodes is the union of nodes in two partition to-be-merged + merged_nodes = copy(partitions_by_id[self_id].nodes) + merged_nodes.update(partitions_by_id[other_id].nodes) + + def dfs_iter_find_cycle(all_user_nodes: List[Node]): + for user_node in all_user_nodes: + visited_partition_ids = set() + + for path_node in self.dependency_viewer.downstreams_of(user_node): + # If any of the nodes in the dfs path of this node are in the merged_nodes + # list then there is a cycle in the graph. + if path_node in merged_nodes: + return True + + # If any of the nodes in the dfs path of this node are in the assignment + # map then we have to make sure that the partitions that these nodes belong + # to do not form a cycle with the current partitions being merged. This means + # iterating through all the nodes in all the parititons that are traversed in + # the dfs path and checking if they are in the merged_nodes list. + if path_node in assignment: + partition_id = assignment[path_node] + # If the partition id has already been visited then we know that it doesn't + # form a cycle with the current partitions being merged. + if partition_id in visited_partition_ids: + continue + p_map = partition_map[partition_id] + if self_id in p_map or other_id in p_map: + return True + + visited_partition_ids.add(partition_id) + + return False + + # check if merge would create cyclic dependency. + all_user_nodes = [] + for node in merged_nodes: + for user_node in node.users: + if user_node not in merged_nodes: + all_user_nodes.append(user_node) + + if dfs_iter_find_cycle(all_user_nodes): + # return false indicating cyclic dependency found and + # merge is aborted + return False + + # no cyclic dependency found, move forward with the merge + # updating partition nodes + partitions_by_id[self_id].nodes = merged_nodes + # updating assignment map + for node in partitions_by_id[other_id].nodes: + assignment[node] = self_id + # delete other partition + del partitions_by_id[other_id] + + partition_map[self_id] = partition_map[self_id].union(partition_map[other_id]) + del partition_map[other_id] + + return True + + def merge_single_node(node: Node, id: Optional[int]): + def _update_partition_map(node: Node, id: int): + # Iterate through all the downstream nodes of this node and update the partition map + # to indicate that there is a path from the partition id of this node to the target + # partition id. + downstream_nodes = self.dependency_viewer.downstreams_of(node) + for curr_node in downstream_nodes: + target_id = assignment.get(curr_node, None) + if target_id is not None: + partition_map[id].add(target_id) + + # Iterate through all the upstream nodes of this node and update the partition map + # to indicate that there is a path from the partition id of the upstream node to the + # current node's partition id. + upstream_nodes = self.dependency_viewer.upstreams_of(node) + for curr_node in upstream_nodes: + source_id = assignment.get(curr_node, None) + if source_id is not None: + partition_map[source_id].add(id) + + if node in assignment: + partitions_by_id[assignment[node]].remove_node(node) + + if id is None: + assignment.pop(node) + elif id not in partitions_by_id: + assignment[node] = id + partitions_by_id[id] = Partition(id=id, nodes=[node]) + _update_partition_map(node, id) + else: + assignment[node] = id + partitions_by_id[id].add_node(node) + _update_partition_map(node, id) + + logger.debug("Proposing partitions...") + + for node in reversed(self.graph_module.graph.nodes): + # use Dict as an ordered set to ensure deterministic partitioning result, don't care value + merge_candidates: Dict[int, None] = {} + + # Note a limited horizontal fusion is enabled: + # when `node` is not supported, the code below attempts to fuse consumer of `node`. + # + # I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut + # the fusion by adding an `else` block here to skip horizontal fusion. + if self.__is_node_supported(node) and node not in assignment: + partition_id = next(new_partition_id) + merge_single_node(node, partition_id) + merge_candidates[partition_id] = None + + # merge all possible partitions + for node in assignment: + merge_candidates[assignment[node]] = None + + merge_candidates_list = list(merge_candidates.keys()) + if len(merge_candidates_list) > 1: + self_id = merge_candidates_list[0] + for other_id in merge_candidates_list[1:]: + # note: merge partition `other_id` into partition `self_id` if + # it doesn't create cyclic dependency in the graph, otherwise, + # this is a no-op + maybe_merge_partition(self_id, other_id) + + # post processing to re-assign "getitem" nodes into upstream partition + logger.debug("Reassigning getitem nodes to its producer node's partition...") + nodes_reassignment: Dict[Node, int] = {} + for node in self.graph_module.graph.nodes: + is_tuple_output = True + for user in node.users: + if user.op != "call_function" or \ + _get_qualified_name(user.target) != "_operator.getitem": # type: ignore[arg-type] + is_tuple_output = False + break + + # node has tuple outputs, re-assign all following getitem node into node's partition + if is_tuple_output: + id = assignment.get(node, None) # type: ignore[arg-type] + for user in node.users: + if assignment.get(user, None) != id: # type: ignore[arg-type] + nodes_reassignment[user] = id # type: ignore[assignment] + for node, id in nodes_reassignment.items(): + merge_single_node(node, id) + + # filter out single node partitions + if not self.allows_single_node_partition: + logger.debug("Filtering out single node partitions...") + default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"} + non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops)) + partitions_to_remove: List[int] = [] + for id, partition in partitions_by_id.items(): + compute_node_count = 0 + for node in partition.nodes: + if node.op == "call_function": + assert callable(node.target) + if _get_qualified_name(node.target) not in non_compute_ops: + compute_node_count += 1 + if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops: + compute_node_count += 1 + if compute_node_count <= 1: + partitions_to_remove.append(id) + for id in partitions_to_remove: + del partitions_by_id[id] + + logger.debug("Partitions proposed:") + for id, partition in partitions_by_id.items(): + logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes]) + + return list(partitions_by_id.values()) + + def fuse_partitions(self, partitions: List[Partition]) -> GraphModule: + logger.debug("Fusing partitions...") + # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ] + return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions]) + + # remove non-compute-ops that sits at the boundary of a partition. + def remove_bookend_non_compute_ops(self, partitions: List[Partition]): + non_compute_ops = set(self.non_compute_ops) + + def is_non_compute_node(node: Node): + return node.op == "call_function" and \ + _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type] + + # cache transparent nodes + transparent_input_nodes: Dict[Node, bool] = {} + transparent_output_nodes: Dict[Node, bool] = {} + + def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]): + if node.op == "placeholder" or (node not in partition) or (node in removed_nodes): + return True + if node in transparent_input_nodes: + return transparent_input_nodes[node] + if is_non_compute_node(node): + for input_n in node.all_input_nodes: + if not is_transparent_input_node(input_n, partition, removed_nodes): + transparent_input_nodes[node] = False + return False + transparent_input_nodes[node] = True + return True + transparent_input_nodes[node] = False + return False + + def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]): + if node.op == "placeholder" or (node not in partition) or (node in removed_nodes): + return True + if node in transparent_output_nodes: + return transparent_output_nodes[node] + if is_non_compute_node(node): + for output_n in node.users: + if not is_transparent_output_node(output_n, partition, removed_nodes): + transparent_output_nodes[node] = False + return False + transparent_output_nodes[node] = True + return True + transparent_output_nodes[node] = False + return False + + for partition in partitions: + # Note it's ok to use `set` here, since we are only query if a node + # has been removed. We are NEVER going to iterate on nodes inside + # the set. + remove_node: Set[Node] = set() + for node in partition.nodes: + if is_non_compute_node(node) and \ + (is_transparent_input_node(node, partition.nodes, remove_node) or + is_transparent_output_node(node, partition.nodes, remove_node)): + remove_node.add(node) + + if len(remove_node) != 0: + partition.nodes = partition.nodes - remove_node + + def partition_and_fuse(self) -> GraphModule: + partitions = self.propose_partitions() + fused_gm = self.fuse_partitions(partitions) + return fused_gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/operator_support.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/operator_support.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0eab25c476c5a061adbe4794bdfc2e09bcdbd6 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/operator_support.py @@ -0,0 +1,217 @@ +import abc +import typing as t + +import torch +import torch.fx +from torch.fx._compatibility import compatibility +from .shape_prop import TensorMetadata +from .tools_common import get_node_target, CALLABLE_NODE_OPS + + +__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain'] + +# fx.Node.target typename, as returned by `get_node_target()` +TargetTypeName = str + +# Arguments' dtypes for a given node, see `OperatorSupport` +SupportedArgumentDTypes = t.Optional[ + t.Tuple[ + t.Sequence[t.Sequence[torch.dtype]], + t.Dict[str, t.Sequence[torch.dtype]], + ] +] + +SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes] + + +@compatibility(is_backward_compatible=False) +class OperatorSupportBase(abc.ABC): + """Interface for determining if a fx.Node is supported by a backend""" + @abc.abstractmethod + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + raise NotImplementedError() + + +@compatibility(is_backward_compatible=False) +class OperatorSupport(OperatorSupportBase): + """ + `_support_dict` maps node.target typename to supported inputs dtypes. + + node.target typename is retrieved using helper function `get_node_target()` + + If supported inputs dtypes is None, it means any dtype is supported, else + we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}). + + The first tuple ([dtypes], ...) indicates what dtypes are supported for + inputs in node.args and the second dict {"name": [dtypes], ...} indicates + what dtypes are supported for inputs in node.kwargs. + + For inputs in args, if we don't want to check it, we can put None there, + e.g. (None, [torch.float]) indicates that we don't care about the type of + the first input in args. And for inputs in kwargs, if not listed, will not + be checked. + """ + + _support_dict: SupportDict + + def __init__( + self, + support_dict: t.Optional[SupportDict] = None + ): + self._support_dict = support_dict or {} + + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + """ + Args: + `submodules`: mapping from module name to the module. This can be + retrieved by calling model.named_modules(). + + `node`: a Fx node that we want to determine whether it's supported. + + Returns: + `is_supported`: whether the arg `node` is supported. + """ + if node.op not in CALLABLE_NODE_OPS: + return True + + target = get_node_target(submodules, node) + + # Target not found in _support_dict meaning that we don't support this op at all + if target not in self._support_dict: + return False + + # The rule for target is None meaning that we accept any dtype + if self._support_dict[target] is None: + return True + + args_dtypes, kwargs_dtypes = self._support_dict[target] # type: ignore[misc] + + # Check args dtypes + for i, dtypes in enumerate(args_dtypes): + if len(node.args) <= i: + break + + # None indicates we don't care about the dtype of args[i] + if dtypes is None: + continue + + # If arg is not a node then we don't check it + if not isinstance(node.args[i], torch.fx.Node): + continue + + arg_dtype = _get_arg_dtype(node.args[i]) # type: ignore[arg-type] + if arg_dtype not in dtypes: + return False + + # Check kwargs dtypes + for k, dtypes in kwargs_dtypes.items(): + if k not in node.kwargs: + continue + + # If arg is not a node then we don't check it + if not isinstance(node.kwargs[k], torch.fx.Node): + continue + + kwarg_dtype = _get_arg_dtype(node.kwargs[k]) # type: ignore[arg-type] + if kwarg_dtype not in dtypes: + return False + + return True + + +# ====================================================================== +# Functional interfaces and utils for defining basic operator support logic +# and composing them into more complex ones +# ====================================================================== + +IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool] + + +@compatibility(is_backward_compatible=False) +def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase: + """Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance + + `IsNodeSupported` has the same call signature as + `OperatorSupportBase.is_node_supported` + """ + class FunctionalOperatorSupport(OperatorSupportBase): + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + return is_node_supported(submodules, node) + return FunctionalOperatorSupport() + + +@compatibility(is_backward_compatible=False) +def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: + """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase` + instance by evaluating each input `OperatorSupportBase` instance, and returns False if + any of it reports False. + """ + def _chain(submods, node) -> bool: + return all( + x.is_node_supported(submods, node) + for x in op_support + ) + return create_op_support(_chain) + + +@compatibility(is_backward_compatible=False) +def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: + """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase` + instance by evaluating each input `OperatorSupportBase` instance, and returns True if + any of it reports True. + """ + def _any_chain(submods, node) -> bool: + return any( + x.is_node_supported(submods, node) + for x in op_support + ) + return create_op_support(_any_chain) + + +@compatibility(is_backward_compatible=False) +class OpSupports: + """A set of atomic `OperatorSupportBase` instances that can be combined together + to form more complex operator support logic. + """ + @classmethod + def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase: + """Report a node as non-supported, if any of its arguments is of dtype""" + + def _decline_if_input_dtype( + submodules: t.Mapping[str, torch.nn.Module], + node: torch.fx.Node, + ) -> bool: + for arg in node.all_input_nodes: + arg_dtype = _get_arg_dtype(arg) + if arg_dtype == dtype: + return False + return True + return create_op_support(_decline_if_input_dtype) + + @classmethod + def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase: + """ + If a node has a name that is in the disallow set, reported it as non-supported. + """ + def _decline_if_node_in_names( + submodules: t.Mapping[str, torch.nn.Module], + node: torch.fx.Node, + ) -> bool: + if node.name in disallow_set: + return False + else: + return True + return create_op_support(_decline_if_node_in_names) + + +def _get_arg_dtype(arg: torch.fx.Node) -> t.Any: + assert isinstance(arg, torch.fx.Node) + tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr] + dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"] + return dtype diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/param_fetch.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/param_fetch.py new file mode 100644 index 0000000000000000000000000000000000000000..5979e29fcc6b2650a1f73be4845e2ad3dcda0920 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/param_fetch.py @@ -0,0 +1,66 @@ +from torch.fx.graph_module import GraphModule +from typing import Any, Callable, Dict, List, Tuple, Type +import torch +import torch.nn as nn + +from torch.fx._compatibility import compatibility + +__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes'] + +# Matching method matches the attribute name of current version to the attribute name of `target_version` +@compatibility(is_backward_compatible=False) +def default_matching(name: str, target_version: int) -> str: + """Default matching method + """ + return name + +# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering. +# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list. +# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module. +module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = { + torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching), + torch.nn.modules.conv.Conv2d: ( + 1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching + ), + torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching), + torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching), + torch.nn.modules.pooling.MaxPool2d: ( + 1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching + ), + torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching), +} + +@compatibility(is_backward_compatible=False) +def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]: + """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book` + after checking module's version is compatible with the `module_fetch_book`. + """ + attrs_for_lowering: Dict[str, Any] = {} + attrs_for_lowering["name"] = torch.typename(mod) + + if type(mod) in module_fetch_book: + version, param_to_fetch, matching_method = module_fetch_book[type(mod)] + if version < mod._version: + raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " + "please upgrade the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly.") + for attr in param_to_fetch: + attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version)) + else: + raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, " + "please add it to the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly.") + return attrs_for_lowering + +@compatibility(is_backward_compatible=False) +def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: + """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module. + """ + submodules = dict(fx_module.named_modules()) + + for node in fx_module.graph.nodes: + if node.op == "call_module": + if isinstance(submodules[node.target], GraphModule): + lift_lowering_attrs_to_nodes(submodules[node.target]) + else: + node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target]) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/shape_prop.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/shape_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..3da0fdd76dcf226c7b1ec83c2ec51bb74bb6a451 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/shape_prop.py @@ -0,0 +1,195 @@ +# mypy: ignore-errors + +import torch +import torch.fx +import traceback + +from torch._dispatch.python import enable_python_dispatcher +from torch.fx.node import Node, map_aggregate +from typing import Any, Tuple, NamedTuple, Optional, Dict +from torch.fx._compatibility import compatibility +from torch._guards import detect_fake_mode + +__all__ = ['TensorMetadata', 'ShapeProp'] + +@compatibility(is_backward_compatible=True) +class TensorMetadata(NamedTuple): + # TensorMetadata is a structure containing pertinent information + # about a tensor within a PyTorch program. + + # General Tensor metadata + shape : torch.Size + dtype : torch.dtype + requires_grad : bool + stride : Tuple[int, ...] + memory_format : Optional[torch.memory_format] + + # Quantization metadata + is_quantized : bool + qparams: Dict[str, Any] + +def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata: + """ + Extract a TensorMetadata NamedTuple describing `result`. + """ + shape = result.shape + dtype = result.dtype + requires_grad = result.requires_grad + stride = result.stride() + + memory_format = None + + if include_contiguity: + memory_formats = { + torch.contiguous_format, + torch.channels_last, + torch.channels_last_3d, + } + for query_format in memory_formats: + if result.is_contiguous(memory_format=query_format): + memory_format = query_format + break + + is_quantized = result.is_quantized + qparams: Dict[str, Any] = {} + if is_quantized: + qscheme = result.qscheme() + qparams["qscheme"] = qscheme + if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + qparams["scale"] = result.q_scale() # type: ignore[assignment] + qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] + elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}: + # In this branch, scale and zero_point are expected to be tensors, + # we store the values as immutable_list in TensorMetadata for + # easier serialization downstream + qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment] + qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] + qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] + + return TensorMetadata( + shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) + +@compatibility(is_backward_compatible=True) +class ShapeProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node and + record the shape and type of the result + into the corresponding node. + + Example: + In this example, we record the shape + and data type of a module given + an example input ``torch.randn(50, D_in)``. + We print the name, shape and dtype of each node. + + class TwoLayerNet(torch.nn.Module): + def __init__(self, D_in, H, D_out): + super().__init__() + self.linear1 = torch.nn.Linear(D_in, H) + self.linear2 = torch.nn.Linear(H, D_out) + def forward(self, x): + h_relu = self.linear1(x).clamp(min=0) + y_pred = self.linear2(h_relu) + return y_pred + N, D_in, H, D_out = 64, 1000, 100, 10 + x = torch.randn(N, D_in) + y = torch.randn(N, D_out) + model = TwoLayerNet(D_in, H, D_out) + gm = torch.fx.symbolic_trace(model) + sample_input = torch.randn(50, D_in) + ShapeProp(gm).propagate(sample_input) + + for node in gm.graph.nodes: + print(node.name, node.meta['tensor_meta'].dtype, + node.meta['tensor_meta'].shape) + + The output of this code is: + + x torch.float32 torch.Size([50, 1000]) + linear1 torch.float32 torch.Size([50, 100]) + clamp_1 torch.float32 torch.Size([50, 100]) + linear2 torch.float32 torch.Size([50, 10]) + output torch.float32 torch.Size([50, 10]) + + Args: + module (GraphModule): The module to be executed + fake_mode (FakeTensorMode): A fake mode for copying the gm + + """ + def __init__(self, gm, fake_mode=None): + super().__init__(gm) + if fake_mode is None: + fake_mode = detect_fake_mode() + if fake_mode is not None: + from torch._dynamo.utils import deepcopy_to_fake_tensor + # Note: + # We need fake execution cause the inputs are fake, however, we cannot fakify the module + # - because we need to write to the tensor_meta of the real module. So we fakify to + # produce a result (L131 below), to extract tensor meta, and then keep going. + # + # If we were to fakify, we would write to the wrong node, and then downstream fusion + # would be missing the tensor_meta. + # + # See torch/_inductor/overrides.py for where this is called upstream of fusion. + self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode) + self.fake_mode = fake_mode + else: + self.fake_module = None + self.fake_mode = None + + self.real_module = self.module + + def run_node(self, n : Node) -> Any: + try: + if self.fake_module is not None: + # Hacky swap. Alternatively, we could do this with overriding + # call_module and get_attr. + self.module = self.fake_module + try: + if self.fake_mode is not None: + with self.fake_mode, enable_python_dispatcher(): + result = super().run_node(n) + else: + result = super().run_node(n) + finally: + self.module = self.real_module + except Exception as e: + traceback.print_exc() + raise RuntimeError( + f"ShapeProp error for: node={n.format_node()} with " + f"meta={n.meta}" + ) from e + + found_tensor = False + + def extract_tensor_meta(obj): + if isinstance(obj, torch.Tensor): + nonlocal found_tensor + found_tensor = True + return _extract_tensor_metadata(obj) + else: + return obj + + meta = map_aggregate(result, extract_tensor_meta) + if found_tensor: + n.meta['tensor_meta'] = meta + + n.meta['type'] = type(result) + return result + + def propagate(self, *args): + """ + Run `module` via interpretation and return the result and + record the shape and type of each node. + + Args: + *args (Tensor): the sample input. + + Returns: + Any: The value returned from executing the Module + """ + if self.fake_mode is not None: + fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args] + else: + fake_args = args + return super().run(*fake_args) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1282081af67b533eb3eb54adc7c3200d10cc11a9 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/split_utils.py @@ -0,0 +1,302 @@ +import copy +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.graph import map_arg +from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module + +from .tools_common import NodeList + +__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"] + + +@compatibility(is_backward_compatible=False) +def getattr_recursive(obj, name): + for layer in name.split("."): + if hasattr(obj, layer): + obj = getattr(obj, layer) + else: + return None + return obj + + +@compatibility(is_backward_compatible=False) +def setattr_recursive(obj, attr, value): + if "." not in attr: + setattr(obj, attr, value) + else: + layer = attr.split(".") + setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value) + + +@compatibility(is_backward_compatible=False) +@dataclass +class Component: + """ + A component serves as a container for a subgraph we want to create afterwards. + """ + + graph: torch.fx.Graph + order: int + name: str + + # Stores the placeholder nodes in `graph`. + input_placeholders: List = field(default_factory=list) + + # Store the nodes in original graph that are placeholder in `graph`. + orig_inputs: List = field(default_factory=list) + + # Store the nodes in original graph that are outputs in `graph`. + orig_outputs: List = field(default_factory=list) + + # Mapping from get_attr node in original graph to get_attr node in `graph`. + getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict) + constructor_args: List[str] = field(default_factory=list) + gm: Optional[torch.fx.GraphModule] = None + + +@compatibility(is_backward_compatible=False) +def split_by_tags( + gm: torch.fx.GraphModule, + tags: List[str], + return_fqn_mapping: bool = False, + return_tuple: bool = False, + GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule, +) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]: + """ + Splits a GraphModule using tags on its graph nodes. We honor the order of + tags. For example, we have tags = ["a", "b", "c"], the function will create + the initial submodules in the order of "a", "b", "c". + + To set a tag: + gm.graph.nodes[idx].tag = "mytag" + + This will result in all nodes with the same tag being extracted and placed in their + own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder + and output nodes are created when needed while get_attr nodes get copied to submodules + where they are used. + + Given the following module def: + + class SimpleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(...) + self.linear2 = torch.nn.Linear(...) + self.linear3 = torch.nn.Linear(...) + + def forward(self, in1, in2): + r1 = self.linear1(in1) + r2 = self.linear2(in2) + r3 = torch.cat([r1, r2]) + return self.linear3(r3) + + Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split: + + ro: + def forward(self, in1): + self = self.root + linear1 = self.linear1(in1) + return linear1 + + main: + def forward(self, in2, linear1): + self = self.root + linear2 = self.linear2(in2) + cat_1 = torch.cat([linear1, linear2]) + linear3 = self.linear3(cat_1) + return linear3 + + main: + def forward(self, in1, in2): + self = self.root + ro_0 = self.ro_0(in1) + main_1 = self.main_1(in2, ro_0) + return main_1 + + Returns: + split_gm: torch fx graph after split + orig_to_split_fqn_mapping: a map between the original fqn and the fqn + after split for call_module and get_attr. + """ + + def flatten(x: torch.fx.node.Argument) -> NodeList: + """ + Stores nodes in x to a list and returns the list. + """ + r: NodeList = [] + map_arg(x, r.append) + return r + + # Mapping from node in original module to node in created submodule. + node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + + # Mapping from node in original module or created submodules to + # corresponding component. + node_to_component: Dict[torch.fx.Node, Component] = {} + + # Mapping from tag to the corresponding component. + tag_to_component: Dict[str, Component] = {} + + # Stores all components. + all_components: List[Component] = [] + + # Stores nodes that will be used in main graph. + used_in_main: Dict[torch.fx.Node, None] = {} + + # Main graph after split. + main_g = torch.fx.Graph() + + # Mapping from node in original module to node in main graph after split. + main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + + # Output node of original module. + output_node: Optional[torch.fx.Node] = None + + # Create a component for each tag, we don't expect to create other components afterwards. + for tag in tags: + comp = Component(torch.fx.Graph(), len(all_components), f"{tag}") + all_components.append(comp) + tag_to_component[tag] = comp + + # Traverse the nodes in original graph and take care of them. + for node in gm.graph.nodes: + if node.op == "output": + if output_node is not None: + raise RuntimeError("Multiple output nodes in graph!") + output_node = node + continue + + # Placeholders in the original graph get copied to main graph. + if node.op == "placeholder": + main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type) + main_remapping[node].meta = copy.copy(node.meta) + continue + + # Get_attr nodes are ignored because we are not tagging them. + # Instead, we copy them directly to the submodules use them afterwards. + if node.op == "get_attr": + continue + + # Now we process callable nodes which are nodes with op of call_module, + # call_function or call_method. Every callable nodes should be tagged. + assert hasattr(node, "tag") + + upstream_components = [ + node_to_component[x] + for x in flatten(node.args) + flatten(node.kwargs) + if x.op not in {"placeholder", "get_attr"} + ] + + comp = tag_to_component[node.tag] + node_to_component[node] = comp + + # Max order of upperstream components. + mx = max((c.order for c in upstream_components), default=0) + + # Expect the component for `node` has higher order then its upstream components. + assert comp.order >= mx + + # Map a input of `node` to nodes in the component's graph. + def remap_func(x): + # If input is a get_attr node, copy it to current component's graph. + # Returns the get_attr node in current component's graph. + if x.op == "get_attr": + if x not in comp.getattr_maps: + comp.getattr_maps[x] = comp.graph.get_attr( + x.target, type_expr=x.type + ) + return comp.getattr_maps[x] + + # If input is not a placeholder, it should have been put into a component + # already. If it's the current component then we return the corresponding + # node in the component. + if x.op != "placeholder" and node_to_component[x] == comp: + return node_remapping[x] + + # If input is a placeholder or it's in other components, we want to make it + # as a placeholder in current component's graph. + if x not in comp.orig_inputs: + comp.orig_inputs.append(x) + placeholder = comp.graph.placeholder(x.name, type_expr=x.type) + placeholder.meta = copy.copy(x.meta) + comp.input_placeholders.append(placeholder) + used_in_main[x] = None + + return comp.input_placeholders[comp.orig_inputs.index(x)] + + n = comp.graph.node_copy(node, remap_func) + n.tag = node.tag # type: ignore[attr-defined] + node_remapping[node] = n + node_to_component[n] = comp + + if output_node is None: + raise RuntimeError("Graph had no output node!") + + for x in flatten(output_node.args[0]): + if x.op == "get_attr": + # We don't need components mapping for nodes of type "get_attr" + # that are consumed by the output. Only need to make sure we create + # corresponding counterparts in the resulting graph. + main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type) + else: + # All component results consumed by the output node should be + # marked as "used in main". + used_in_main[x] = None + + # If a node is used in main graph then we mark it as an output in the component + # it belongs to. + for n in used_in_main: + if n.op != "placeholder": + node_to_component[n].orig_outputs.append(n) + + # Now we create a graphmodule for each component. + orig_to_split_fqn_mapping: Dict[str, str] = {} + for comp in all_components: + outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs)) + + if return_tuple: + comp.graph.output(outs) + else: + # Take care of the args of FX output node. If there's a single + # output then the output node args is like (output_single), else + # if there're multiple outputs then the output node args is like + # ((output_0, output_1, ...)). + comp.graph.output(outs[0] if len(outs) == 1 else outs) + + comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module( + gm, subgraph=comp.graph, comp_name=comp.name + ) + orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping) + + # Create a call_module node in main graph. + main_node = main_g.call_module( + comp.name, + args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)), + kwargs=None, + ) + + if len(outs) == 1 and not return_tuple: + main_remapping[comp.orig_outputs[0]] = main_node + else: + for i, o in enumerate(comp.orig_outputs): + # Use Proxy to record getitem access. + main_remapping[o] = torch.fx.Proxy(main_node)[i].node # type: ignore[index] + + main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__)) + main_root = HolderModule({comp.name: comp.gm for comp in all_components}) + main_g._codegen = gm.graph._codegen + + # If the output nodes consumes get_attr directly in the original graph, + # then we need to make sure get_attr is copied to the new graph. + for x in flatten(output_node.args[0]): + if x.op == "get_attr": + setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type] + + result_gm = GraphModuleCls(main_root, main_g) + if return_fqn_mapping: + return result_gm, orig_to_split_fqn_mapping + + return result_gm diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e5a779e36a9cd745b54a96f3250fbfaed438543 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c30407ffd8d55d4c8c2b017158ee9947e72e2e63 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/common.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd030337df4a1672e5e937366116a2c5c36b6c8 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/common.py @@ -0,0 +1,95 @@ +from typing import Dict, Tuple + +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph + +from torch.fx.graph_module import GraphModule +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher +from torch.nn import Module + + +__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"] + + +@compatibility(is_backward_compatible=False) +class HolderModule(Module): + """ + HolderModule is used to copy all the attributes from original module to submodules + that uses the attributes + """ + + def __init__(self, d): + super().__init__() + for k, v in d.items(): + self.add_module(k, v) + + +@compatibility(is_backward_compatible=False) +def lift_subgraph_as_module( + gm: GraphModule, + subgraph: Graph, + comp_name: str = "", + class_name: str = "GraphModule", +) -> Tuple[GraphModule, Dict[str, str]]: + """ + Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module. + + Args: + gm (GraphModule): parent graph module + + subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph + + comp_name (str): name for the new component + + class_name (str): name for the submodule + + """ + + # Loop through all module calls (call_module) and param fetches (get_attr) + # in this component, creating HolderModules as necessary to match the path. + # e.g. if in the original module there's a get_attr node fetches "conv.weight". + # We create a HolderModule as root -> add a HolderModule named "conv" -> + # make "weight" a attribute of "conv" HolderModule and point to conv.weight in + # the original module. + submodule = HolderModule({}) + orig_to_split_fqn_mapping: Dict[str, str] = {} + for n in subgraph.nodes: + if n.op not in ("call_module", "get_attr"): + continue + + target = n.target + assert isinstance(target, str) + target_name_parts = target.split(".") + curr = submodule + orig_gm = gm + + for name in target_name_parts[:-1]: + if not hasattr(curr, name): + curr.add_module(name, HolderModule({})) + + curr = getattr(curr, name) + orig_gm = getattr(orig_gm, name) + + leaf_node_name = target_name_parts[-1] + leaf_node = getattr(orig_gm, leaf_node_name) + + orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}" + # Relies on custom __setattr__ magic. + setattr(curr, leaf_node_name, leaf_node) + + return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping + + +@compatibility(is_backward_compatible=False) +def compare_graphs(left: Graph, right: Graph) -> bool: + """ + Return True if two graphs are identical, i.e they + - have the same number of outputs in the same order + - have the same number of inputs in the same order + - have the same set of nodes, and identical connectivity + """ + + matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True) + matches = matcher.match(right) + + return len(matches) > 0 diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..00415d10fee75b350d958ddc9c1ba56e3decc40c --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/fx/passes/utils/matcher_utils.py @@ -0,0 +1,400 @@ +from dataclasses import dataclass, field +from collections import defaultdict +import copy +import torch +from torch.fx import ( + Node, + Graph, +) +from torch.fx._compatibility import compatibility +from typing import Dict, List, Set, Any, Union, Tuple +import logging +import os + +__all__ = ['SubgraphMatcher', 'InternalMatch'] + +# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs +def _init_logger(): + logger = logging.getLogger(__name__) + + level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper() + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter("%(filename)s > %(message)s") + console.setFormatter(formatter) + console.setLevel(level) + # add the handlers to the logger + logger.addHandler(console) + logger.propagate = False + return logger + +logger = _init_logger() + +@compatibility(is_backward_compatible=False) +@dataclass +class InternalMatch: + # Nodes from which the match was found + anchors: List[Node] + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] = field(default_factory=dict) + + # nodes in target graph that are matched placeholder in pattern + placeholder_nodes: List[Node] = field(default_factory=list) + + # nodes in matched subgraph returned by output + returning_nodes: List[Node] = field(default_factory=list) + + # map from a string name to a node in the target graph + # only available if the matcher is `SubgraphMatcherWithNameNodesMap` + name_node_map: Dict[str, Node] = field(default_factory=dict) + + def __copy__(self): + return InternalMatch(anchors=self.anchors, nodes_map=self.nodes_map.copy(), + placeholder_nodes=self.placeholder_nodes.copy(), + returning_nodes=self.returning_nodes.copy()) + +@compatibility(is_backward_compatible=False) +class SubgraphMatcher: + def __init__(self, pattern: Graph, + match_output: bool = False, + match_placeholder: bool = False, + remove_overlapping_matches: bool = True, + ignore_literals: bool = False) -> None: + """ + Args: + pattern: the targeted matching pattern, represented in fx.Graph. + match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern. + If False, output node is ignored during match. + match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of + the targeted pattern. If False, placeholder nodes will be used a wildcard. + remove_overlapping_matches: If True, in the case of overlapping matches, only the first match + will be returned. + ignore_literals: If True, will not check if literals are equal and + will instead treat them as wildcards. + """ + + self.pattern = pattern + self.match_output = match_output + self.match_placeholder = match_placeholder + self.remove_overlapping_matches = remove_overlapping_matches + self.ignore_literals = ignore_literals + + if len(pattern.nodes) == 0: + raise ValueError("SubgraphMatcher cannot be initialized with an empty pattern") + + for node in pattern.nodes: + if node.op != "output": + assert len(node.users) > 0, \ + "SubgraphMatcher cannot be initialized with an pattern with dead code" + + # TODO: assert pattern is a connected graph + + self.pattern_placeholder_nodes = [n for n in pattern.nodes if n.op == "placeholder"] + output_node = next(iter(reversed(pattern.nodes))) + # nodes returned by outputs + self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes + + self.pattern_anchors: List[Node] = [] + if match_output: + self.pattern_anchors = [output_node] + else: + # If a node has output_node as the ONLY user, then this node is a graph sink, + # and should be matched against as an anchor + self.pattern_anchors = [n for n in output_node.all_input_nodes if len(n.users) == 1] + + def _match_attributes(self, pn: Node, gn: Node) -> bool: + # Attributes matching is complicated. Right now we only support matching constant tensor + assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string." + assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string." + + # TODO(tmanlaibaatar) should probably make this actual API + def _getattr(model: torch.fx.GraphModule, attr_name: str): + *prefix, field = attr_name.split(".") + t = model + for item in prefix: + t = getattr(t, item, None) # type: ignore[assignment] + assert t is not None + + return getattr(t, field) + + pn_value = _getattr(pn.graph.owning_module, pn.target) + gn_value = _getattr(gn.graph.owning_module, gn.target) + + if type(pn_value) != type(gn_value): + return False + + # Don't require exact match on tensor values. + if isinstance(pn_value, torch.Tensor): + return isinstance(gn_value, torch.Tensor) + else: + raise RuntimeError(f"Unsupported type {pn_value} when matching attributes") + return False + + def _nodes_are_equal(self, pn: Node, gn: Node) -> bool: + # if exact match for placeholder is not required, then use placeholder as a wildcard + if not self.match_placeholder and pn.op == "placeholder": + return True + + if pn.op == gn.op: + if pn.op == "placeholder" or pn.op == "output": + return True + elif pn.op == "get_attr": + return self._match_attributes(pn, gn) + return pn.target == gn.target + return False + + def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool: + # `lookup` represents all the nodes in `original_graph` + # that are part of `pattern` + + # Placeholders can be used by other nodes in the graphs + lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items() if pn.op != "placeholder"} + + for gn, pn in lookup.items(): + # nodes returned by output are allowed to be used in other areas of the graph + if pn in self.pattern_returning_nodes: + continue + + for user in gn.users: + # If this node has users that were not in `lookup`, then it must leak out of the + # pattern subgraph + if user not in lookup: + return False + return True + + def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[InternalMatch]: + non_overlapping_matches: List[InternalMatch] = list() + nodes_matched: Set[Node] = set() + + for match in matches: + found_overlap = False + for pn, gn in match.nodes_map.items(): + if pn.op not in {"placeholder", "output"} and gn in nodes_matched: + found_overlap = True + break + + if not found_overlap: + non_overlapping_matches.append(match) + for pn, gn in match.nodes_map.items(): + if pn.op not in {"placeholder", "output"}: + nodes_matched.add(gn) + return non_overlapping_matches + + def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool: + assert not (isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node" + + if isinstance(pn, Node) and not isinstance(gn, Node): + if pn.op == "placeholder": + # Check if we've already matched these nodes in the current + # traversal + if pn in match.nodes_map: + return match.nodes_map[pn] == gn + + match.nodes_map[pn] = gn + return True + else: + return False + elif not isinstance(pn, Node) and isinstance(gn, Node): + return False + else: + return type(gn) == type(pn) and gn == pn + + def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool: + logger.info(" matching %s to %s", pn, gn) + + assert isinstance(pn, Node) and isinstance(gn, Node), str(f"pn and gn must be Node, pn: {pn}, gn: {gn}") + + # Check if we've already matched these nodes in the current + # traversal + if pn in match.nodes_map: + return match.nodes_map[pn] == gn + + # TODO: use a more efficient way to check if gn is matched before: two-way dict + if gn in match.nodes_map.values(): + return False + + if not self._nodes_are_equal(pn, gn): + return False + + # Optimistically mark `pn` as a match for `gn`, and save a local copy of match + saved_match = copy.copy(match) + match.nodes_map[pn] = gn + + # Placeholder is a wildcard and can be matched with any python object + # (including list/tuple) + if pn.op == "placeholder": + return True + + # Recursively traverse upwards to check if `pn` is a true + # match for `gn` + match_found = True + + def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool: + if len(args1) != len(args2): + return False + + for a1, a2 in zip(args1, args2): + if isinstance(a1, Node) and isinstance(a2, Node): + matched = self._match_nodes(a1, a2, match) + elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)): + matched = _match_args(a1, a2) + else: + matched = self._match_literals(a1, a2, match) or self.ignore_literals + + if not matched: + return False + + return True + + # Flatten all args/kwargs into 1 list of args + pn_args, gn_args = None, None + if ( + (len(pn.args) != len(gn.args) or list(pn.kwargs.keys()) != list(gn.kwargs.keys())) and + pn.op == "call_function" and + isinstance(pn.target, torch._ops.OpOverload) + ): + args_schema = pn.target._schema.arguments + + def get_all_arguments(orig_args, orig_kwargs): + all_args = [] + for i, schema in enumerate(args_schema): + if schema.name in orig_kwargs: + all_args.append(orig_kwargs[schema.name]) + elif not schema.kwarg_only and i < len(orig_args): + all_args.append(orig_args[i]) + else: + all_args.append(schema.default_value) + return all_args + + pn_args = get_all_arguments(pn.args, pn.kwargs) + gn_args = get_all_arguments(gn.args, gn.kwargs) + + elif len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list(gn.kwargs.keys()): + pn_args = list(pn.args) + gn_args = list(gn.args) + pn_args.extend(list(pn.kwargs.values())) + gn_args.extend(list(gn.kwargs.values())) + else: + match_found = False + + match_found = ( + match_found and + pn_args is not None and + gn_args is not None and + _match_args(pn_args, gn_args) + ) + + if not match_found: + # revert to saved_match before matching with current node + match = copy.copy(saved_match) + return False + + return True + + def match(self, graph: Graph) -> List[InternalMatch]: + """ + Returns: + The matched subgraphs. + Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder + and nodes returned by output) can only be consumed by nodes within the matched subgraph. + + Subgraph pattern matcher is implemented with the backtracking style in the following steps: + + 1. We first identify all the anchor nodes in the pattern graph. The anchor nodes + are the "sinks" (nodes with no user other than the output node) of the pattern graph. + One pattern graph could have multiple anchors if it has multiple return values. + + 2. In the target graph, we identify the potential candidate nodes that can be matched + with each anchor. These anchor-candidate pairs are the starting points for + pairwise per-node matching. + + 3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both + pattern and target graphs. For every pattern nodes along traversal path, we compare it + against the target nodes. In case any comparison failed, the match for this anchor-candidate + pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes` + for more details. + + 4. In the case of multiple anchors, every anchor will need to find a match using step 3. + In addition, the matches found between anchors need to have a common intersection node + in order for the match to be valid. This is implemented with backtracking. See `backtracking` + for more details. + + Notice: graph traversal must be done in the reverser order because a tensor can have multiple + consumers, but can only have a single producer. Only with reverser order, we can we jointly + traverse the pattern and target graph in a deterministic path. + + Warning: In theory, this backtracking algorithm have an **exponential** time complexity. However, + in practice, it's unlikely to blow up. + + """ + from torch.fx.passes.utils.fuser_utils import validate_partition + + # find candidate nodes to match with pattern anchors + match_candidates: Dict[Node, List[Node]] = defaultdict(list) + for pattern_anchor in self.pattern_anchors: + for node in graph.nodes: + if self._nodes_are_equal(pattern_anchor, node): + match_candidates[pattern_anchor].append(node) + match_candidates_list = list(match_candidates.items()) + + logger.info("Initial match_candidates_list: %s\n", match_candidates_list) + + matches: List[InternalMatch] = [] + + def backtracking(anchor_index, match): + if anchor_index == len(match_candidates_list): + match.placeholder_nodes = [match.nodes_map[pn] for pn in self.pattern_placeholder_nodes] + match.returning_nodes = [match.nodes_map[pn] for pn in self.pattern_returning_nodes] + matches.append(match) + + logger.info("Found a match: %s\n", match) + return + + pattern_anchor, candidate_nodes = match_candidates_list[anchor_index] + saved_match = copy.copy(match) + + for node in candidate_nodes: + logger.info("Trying to match anchor %s to %s", pattern_anchor, node) + + match_found = self._match_nodes(pattern_anchor, node, match) + if match_found: + # match next anchor + backtracking(anchor_index + 1, match) + else: + logger.info("Failed to match anchor %s to %s\n", pattern_anchor, node) + + # revert to saved_match before matching with current anchor + match = copy.copy(saved_match) + + match = InternalMatch(anchors=self.pattern_anchors) + if match_candidates_list: + backtracking(0, match) + + # filter out the matches where the subgraph is not fully_contained + before = len(matches) + matches = [match for match in matches if self._is_contained(match.nodes_map)] + after = len(matches) + if before != after: + logger.info("Filtered out %s matches because they are not fully contained", before - after) + + # filter out the matches that form a cycle if the subgraph is fused + valid_matches = [] + for match in matches: + matched_compute_nodes = \ + [gn for pn, gn in match.nodes_map.items() if pn.op not in {"placeholder", "output"}] + if validate_partition(matched_compute_nodes): + valid_matches.append(match) + if len(valid_matches) != len(matches): + logger.info("Filtered out %s matches because \ + matched subgraph would form a cycle if fused", len(matches) - len(valid_matches)) + + if self.remove_overlapping_matches: + before = len(valid_matches) + matches = self._remove_overlapping_matches(valid_matches) + after = len(matches) + if before != after: + logger.info("Filtered out %s matches because matched subgraphs are overlapping", before - after) + + logger.info("Matches returned: %s", matches) + + return matches diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/pool.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/pool.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd2870ca86b4e592f20e0ed25117ad1652be83b4 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/multiprocessing/__pycache__/pool.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d681233e3e204618f4406800277f41264cffa2be Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd83d88a3e3e72385726851b1fdd5fc09086a473 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__init__.py @@ -0,0 +1,87 @@ +from .quantize import * # noqa: F403 +from .observer import * # noqa: F403 +from .qconfig import * # noqa: F403 +from .fake_quantize import * # noqa: F403 +from .fuse_modules import fuse_modules +from .stubs import * # noqa: F403 +from .quant_type import * # noqa: F403 +from .quantize_jit import * # noqa: F403 + +# from .quantize_fx import * +from .quantization_mappings import * # noqa: F403 +from .fuser_method_mappings import * # noqa: F403 + + +def default_eval_fn(model, calib_data): + r""" + Default evaluation function takes a torch.utils.data.Dataset or a list of + input Tensors and run the model on the dataset + """ + for data, target in calib_data: + model(data) + + +__all__ = [ + "QuantWrapper", + "QuantStub", + "DeQuantStub", + # Top level API for eager mode quantization + "quantize", + "quantize_dynamic", + "quantize_qat", + "prepare", + "convert", + "prepare_qat", + # Top level API for graph mode quantization on TorchScript + "quantize_jit", + "quantize_dynamic_jit", + "_prepare_ondevice_dynamic_jit", + "_convert_ondevice_dynamic_jit", + "_quantize_ondevice_dynamic_jit", + # Top level API for graph mode quantization on GraphModule(torch.fx) + # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx + # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', + "QuantType", # quantization type + # custom module APIs + "get_default_static_quant_module_mappings", + "get_static_quant_module_class", + "get_default_dynamic_quant_module_mappings", + "get_default_qat_module_mappings", + "get_default_qconfig_propagation_list", + "get_default_compare_output_module_list", + "get_quantized_operator", + "get_fuser_method", + # Sub functions for `prepare` and `swap_module` + "propagate_qconfig_", + "add_quant_dequant", + "swap_module", + "default_eval_fn", + # Observers + "ObserverBase", + "WeightObserver", + "HistogramObserver", + "observer", + "default_observer", + "default_weight_observer", + "default_placeholder_observer", + "default_per_channel_weight_observer", + # FakeQuantize (for qat) + "default_fake_quant", + "default_weight_fake_quant", + "default_fixed_qparams_range_neg1to1_fake_quant", + "default_fixed_qparams_range_0to1_fake_quant", + "default_per_channel_weight_fake_quant", + "default_histogram_fake_quant", + # QConfig + "QConfig", + "default_qconfig", + "default_dynamic_qconfig", + "float16_dynamic_qconfig", + "float_qparams_weight_only_qconfig", + # QAT utilities + "default_qat_qconfig", + "prepare_qat", + "quantize_qat", + # module transformations + "fuse_modules", +] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..862fd0e145251e9b3827f52073be020958276877 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/_numeric_suite_fx.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3533b4c3b428f77518e45dfb0afad62ed48d282c Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/fuse_modules.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quant_type.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quant_type.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcc3a73b6f607c2760fd1471220ff636974e312b Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quant_type.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4dcd02e8961a18f7c336d164929d615a56b0f94 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/__pycache__/quantization_mappings.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/_numeric_suite_fx.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/_numeric_suite_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..55cd7085740d0ce8de79491acbfc4888ebba21f8 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/_numeric_suite_fx.py @@ -0,0 +1,26 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/ns/_numeric_suite_fx.py`, while adding an import statement +here. +""" + +from torch.ao.ns._numeric_suite_fx import ( + _add_loggers_impl, + _add_loggers_one_model, + _add_shadow_loggers_impl, + _extract_logger_info_one_model, + _extract_weights_impl, + _extract_weights_one_model, + add_loggers, + add_shadow_loggers, + extend_logger_results_with_comparison, + extract_logger_info, + extract_shadow_logger_info, + extract_weights, + NSTracer, + OutputLogger, + RNNReturnType, +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/_quantized_conversions.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/_quantized_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7670ea48026f22d040b7d1c73e9330ee9ece3e --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/_quantized_conversions.py @@ -0,0 +1,132 @@ +import torch + + +# Pack pairs of int4 values into int8, in row major order; first int4 +# value goes into lower order bits, and second int4 value into higher +# order bits of resulting int8 value. +def pack_int4_to_int8(weight): + assert weight.dim() == 2 + assert weight.shape[1] % 2 == 0 + assert weight.dtype == torch.int8 + return ((weight[:, 1::2] & 0xF) << 4) | (weight[:, 0::2] & 0xF) + + +# Unpack quandruples of bits in int8 values into int4 values, in row +# major order; lower 4 bits go into first int4 value goes, and upper 4 +# bits go into second int4 value. +def unpack_int8_to_int4(weight): + assert weight.dim() == 2 + assert weight.dtype == torch.int8 + return torch.stack((weight & 0xF, (weight >> 4) & 0xF), dim=2).view( + weight.shape[0], 2 * weight.shape[1] + ) + + +# Transpose the weight matrix, and then reorder its elements according +# to underlying requirements of CUTLASS library, so that it could be +# used for CUTLASS-based mixed datatypes linear operation. +def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass( + weight, dtypeq, transpose=False +): + assert weight.dim() == 2 + assert weight.dtype == torch.int8 + assert dtypeq == torch.int8 or dtypeq == torch.quint4x2 + assert weight.device.type == "cuda" + + device = weight.device + + # subbyte_transpose + if not transpose: + if dtypeq == torch.int8: + outp = weight.T + elif dtypeq == torch.quint4x2: + outp = pack_int4_to_int8(unpack_int8_to_int4(weight.view(torch.int8)).T) + else: + outp = weight + + ncols, nrows = outp.shape # type: ignore[possibly-undefined] + assert nrows % (32 if dtypeq == torch.quint4x2 else 64) == 0 + assert ncols % 64 == 0 + + # permute_B_rows_for_mixed_gemm + # (permute cols actually, as transpose is applied first here) + if dtypeq == torch.quint4x2: + cols_permuted = ( + torch.tensor( + [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15], + device=device, + ) + + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand( + nrows // 16, 16 + ) + ).view(-1) + else: + cols_permuted = ( + torch.tensor( + [0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15], + device=device, + ) + + (torch.arange(0, nrows // 16, device=device).reshape(-1, 1) * 16).expand( + nrows // 16, 16 + ) + ).view(-1) + outp = outp.index_copy(1, cols_permuted, outp) + + # interleave_column_major_tensor + magic0 = 4 if dtypeq == torch.quint4x2 else 2 + magic1 = 32 // magic0 + + tmp0 = ( + (torch.arange(0, ncols // magic0, device=device) * (nrows // 4 * magic0)) + .view(-1, 1) + .repeat(1, nrows // 4 * magic0) + .view(-1) + ) + tmp1 = ( + (torch.arange(0, nrows // 4 // magic1, device=device) * (magic0 * magic1)) + .view(-1, 1) + .repeat(1, magic1) + .view(-1) + .repeat(ncols) + ) + tmp2 = ( + (torch.arange(0, magic0, device=device) * magic1) + .view(-1, 1) + .repeat(1, nrows // 4) + .view(-1) + .repeat(ncols // magic0) + ) + tmp3 = torch.arange(0, magic1, device=device).repeat(nrows // 4 * ncols // magic1) + + outp_offsets = tmp0 + tmp1 + tmp2 + tmp3 + + tmp = outp.view(-1).view(torch.int32) + outp = torch.zeros_like(tmp) + outp.scatter_(0, outp_offsets, tmp) + outp = outp.view(weight.dtype) + + # add_bias_and_interleave_quantized_tensor_inplace + tmp = outp.view(-1) + + outp = torch.empty_like(tmp) + if dtypeq == torch.int8: + tmp = (tmp.to(torch.int) + 128).to(tmp.dtype) + outp[0::4] = tmp[0::4] + outp[1::4] = tmp[2::4] + outp[2::4] = tmp[1::4] + outp[3::4] = tmp[3::4] + elif dtypeq == torch.quint4x2: + tmp0 = ((tmp & 0xF) + 8) & 0xF + tmp0 = (tmp0[1::2] << 4) | tmp0[0::2] + tmp1 = (((tmp >> 4) & 0xF) + 8) & 0xF + tmp1 = (tmp1[1::2] << 4) | tmp1[0::2] + outp[0::4] = tmp0[0::2] + outp[1::4] = tmp0[1::2] + outp[2::4] = tmp1[0::2] + outp[3::4] = tmp1[1::2] + + if dtypeq == torch.quint4x2: + nrows *= 2 + ncols //= 2 + + return outp.view(nrows, ncols).view(torch.uint8) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fuser_method_mappings.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fuser_method_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb13ac96271fa7b926cc703918984760e6ede15 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fuser_method_mappings.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/fuser_method_mappings.py`, while adding an import statement +here. +""" +from torch.ao.quantization.fuser_method_mappings import ( + _DEFAULT_OP_LIST_TO_FUSER_METHOD, + fuse_conv_bn, + fuse_conv_bn_relu, + fuse_linear_bn, + get_fuser_method, +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__init__.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c01cbd457374c27e40b07daca5ae1644a701767d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__init__.py @@ -0,0 +1,15 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.convert import convert +from torch.ao.quantization.fx.fuse import fuse + +# omitting files that's unlikely to be used right now, for example +# the newly added lower_to_fbgemm etc. +from torch.ao.quantization.fx.prepare import prepare diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/__init__.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10001c14c6fd373a26332090b5f713b8829fac63 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/__init__.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/convert.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/convert.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44fe81583bafe296db053984dd028802ea5c773d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/convert.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/fuse.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/fuse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13a84111724dfe88f727dd6b996df2b39599f3ae Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/fuse.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/graph_module.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/graph_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73fb165ac41dfaba1437de339ea70512484d93a5 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/graph_module.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/match_utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/match_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2cec084624d8eaa012c8096dd66931d0fab63ca Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/match_utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/prepare.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/prepare.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfac8911f796ce4019a150c3f45d3010aa3a02e0 Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/prepare.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/quantization_types.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/quantization_types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..835201a1b3cc5e1a22c8bc6be6cf172c4edc8f3a Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/quantization_types.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/utils.cpython-311.pyc b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..128e10c6a06ee98a8ce43d81aecc64e764c8eb5d Binary files /dev/null and b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/__pycache__/utils.cpython-311.pyc differ diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/convert.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6ac350602bb7a97c773a3a09fec0780483379f --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/convert.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.convert import convert diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/graph_module.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..a71e980a57ba141bdc5bbe9b283d69582eb8fd82 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/graph_module.py @@ -0,0 +1,17 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.graph_module import ( + _is_observed_module, + _is_observed_standalone_module, + FusedGraphModule, + GraphModule, + ObservedGraphModule, + ObservedStandaloneGraphModule, + QuantizedGraphModule, +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/match_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/match_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8b49f7c645d8d1bc3a154d62a1295a90b155f986 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/match_utils.py @@ -0,0 +1,14 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.match_utils import ( + _find_matches, + _is_match, + _MatchResult, + MatchAllNode, +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/pattern_utils.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/pattern_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..26954833bb48eb5a807ac31cc558c5282cb63201 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/pattern_utils.py @@ -0,0 +1,34 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.pattern_utils import ( + _register_fusion_pattern, + _register_quant_pattern, + get_default_fusion_patterns, + get_default_output_activation_post_process_map, + get_default_quant_patterns, + QuantizeHandler, +) + +# QuantizeHandler.__module__ = _NAMESPACE +_register_fusion_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_fusion_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils" +_register_quant_pattern.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_quant_patterns.__module__ = "torch.ao.quantization.fx.pattern_utils" +get_default_output_activation_post_process_map.__module__ = ( + "torch.ao.quantization.fx.pattern_utils" +) + +# __all__ = [ +# "QuantizeHandler", +# "_register_fusion_pattern", +# "get_default_fusion_patterns", +# "_register_quant_pattern", +# "get_default_quant_patterns", +# "get_default_output_activation_post_process_map", +# ] diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/prepare.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..ca65dcc04dd0021f0065892ca86e209a1c218473 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/prepare.py @@ -0,0 +1,9 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.prepare import prepare diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/quantization_patterns.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/quantization_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..34ee88a4713c5d7016d8a50193555b6ec7c3dfe2 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/fx/quantization_patterns.py @@ -0,0 +1,47 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +appropriate files under `torch/ao/quantization/fx/`, while adding an import statement +here. +""" +from torch.ao.quantization.fx.quantize_handler import ( + BatchNormQuantizeHandler, + BinaryOpQuantizeHandler, + CatQuantizeHandler, + ConvReluQuantizeHandler, + CopyNodeQuantizeHandler, + CustomModuleQuantizeHandler, + DefaultNodeQuantizeHandler, + EmbeddingQuantizeHandler, + FixedQParamsOpQuantizeHandler, + GeneralTensorShapeOpQuantizeHandler, + LinearReLUQuantizeHandler, + QuantizeHandler, + RNNDynamicQuantizeHandler, + StandaloneModuleQuantizeHandler, +) + +QuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +BinaryOpQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +CatQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +ConvReluQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +LinearReLUQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +BatchNormQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +EmbeddingQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +RNNDynamicQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +DefaultNodeQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +FixedQParamsOpQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +CopyNodeQuantizeHandler.__module__ = "torch.ao.quantization.fx.quantization_patterns" +CustomModuleQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +GeneralTensorShapeOpQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) +StandaloneModuleQuantizeHandler.__module__ = ( + "torch.ao.quantization.fx.quantization_patterns" +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/observer.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/observer.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6c7c1917c83433fc19f016140b25d060284535 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/observer.py @@ -0,0 +1,36 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/observer.py`, while adding an import statement +here. +""" +from torch.ao.quantization.observer import ( + _is_activation_post_process, + _is_per_channel_script_obs_instance, + _ObserverBase, + _PartialWrapper, + _with_args, + _with_callable_args, + ABC, + default_debug_observer, + default_dynamic_quant_observer, + default_float_qparams_observer, + default_histogram_observer, + default_observer, + default_per_channel_weight_observer, + default_placeholder_observer, + default_weight_observer, + get_observer_state_dict, + HistogramObserver, + load_observer_state_dict, + MinMaxObserver, + MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver, + NoopObserver, + ObserverBase, + PerChannelMinMaxObserver, + PlaceholderObserver, + RecordingObserver, +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/quantize_fx.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/quantize_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..649142c7a7eee9885d96b37f70e582f3ea9a9f8d --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/quantize_fx.py @@ -0,0 +1,26 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quantize_fx.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.fx.graph_module import ObservedGraphModule +from torch.ao.quantization.quantize_fx import ( + _check_is_graph_module, + _convert_fx, + _convert_standalone_module_fx, + _fuse_fx, + _prepare_fx, + _prepare_standalone_module_fx, + _swap_ff_with_fxff, + convert_fx, + fuse_fx, + prepare_fx, + prepare_qat_fx, + QuantizationTracer, + Scope, + ScopeContextManager, +) diff --git a/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/quantize_jit.py b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/quantize_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..aa627dc7bb51ef7ea1fde7e2e5da283c9f6c8900 --- /dev/null +++ b/tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/quantization/quantize_jit.py @@ -0,0 +1,26 @@ +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quantize_jit.py`, while adding an import statement +here. +""" + +from torch.ao.quantization.quantize_jit import ( + _check_forward_method, + _check_is_script_module, + _convert_jit, + _prepare_jit, + _prepare_ondevice_dynamic_jit, + _quantize_jit, + convert_dynamic_jit, + convert_jit, + fuse_conv_bn_jit, + prepare_dynamic_jit, + prepare_jit, + quantize_dynamic_jit, + quantize_jit, + script_qconfig, + script_qconfig_dict, +)