diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24e45b6e61094cf1ab0965ab8fb94fe17ba1d874 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0c0d41ef74c85ca33d47cbaf40dce7c947666d1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b608d05f7205e7450cfb1252af2609b00019b6b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a8a0601d49b7eb2c2e15f36d72aa8e6ad56f1bf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14f8c05ece7c2170afc2c2b6a7baf2642feb0216 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..639b357bae2ae65a4ccb06b30a5c3735c60b5c7e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb73af4e3535c6255c279ff1db714d910f0d87fa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06edebf4517e4ad93343ba6c5ad9decfcc44aeed Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..666663417fb2fc75ef1960ebb363a85e1f3547f1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py new file mode 100644 index 0000000000000000000000000000000000000000..4693a62de24025c7bb1029297255141bf48cb07f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -0,0 +1,558 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \ + op_mod, op_gt, op_lt, op_neq, op_eq +from torch.fx.tensor_type import TensorType, Dyn + + +class Constraint: + pass + + +class Conj(Constraint): + def __init__(self, conjuncts): + """ + :param conjuncts: Conjunction of constraints + """ + self.conjucts = conjuncts + + def __eq__(self, other): + if isinstance(other, Conj): + return self.conjucts == other.conjucts and self.conjucts == other.conjucts + else: + return False + + def __repr__(self): + return f'And({self.conjucts})' + + +class Disj(Constraint): + def __init__(self, disjuncts): + """ + :param disjuncts: Disjunction of constraints + """ + self.disjuncts = disjuncts + + def __eq__(self, other): + if isinstance(other, Disj): + return self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts + else: + return False + + def __repr__(self): + return f'Or({self.disjuncts})' + + +class Prod(Constraint): + def __init__(self, products): + """ + :param products: lists of dimensions to multiply + """ + self.products = products + + def __eq__(self, other): + if isinstance(other, Prod): + return self.products == other.products and self.products == other.products + else: + return False + + def __repr__(self): + return f'Product({self.products})' + + +class T(Constraint): + """ + True + """ + def __init__(self) -> None: + pass + + def __eq__(self, other): + return isinstance(other, T) + + def __repr__(self): + return 'True' + +class F(Constraint): + """ + False + """ + def __init__(self) -> None: + pass + + def __eq__(self, other): + return isinstance(other, F) + + def __repr__(self): + return 'False' + + +class BinaryConstraint(Constraint): + """ + Represents all binary operations + """ + def __init__(self, lhs, rhs, op): + """ + :param lhs: lhs of the constraint + :param rhs: rhs of the constraint + :param op: string representing the operation + """ + self.lhs = lhs + self.rhs = rhs + self.op = op + + def __eq__(self, other): + if isinstance(other, BinaryConstraint): + return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op + else: + return False + + def __repr__(self): + return f'({self.lhs} {self.op} {self.rhs})' + + +class BinConstraintT(BinaryConstraint): + """ + Binary constraints about tensors + """ + def __init__(self, lhs, rhs, op): + assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and \ + (isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn) + super().__init__(lhs, rhs, op) + + def __eq__(self, other): + return super().__eq__(other) + + +class BinConstraintD(BinaryConstraint): + """ + Binary constraints about dimensions + """ + def __init__(self, lhs, rhs, op): + assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs) + assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs) + + super().__init__(lhs, rhs, op) + + def __eq__(self, other): + return super().__eq__(other) + + + +class TGreatestUpperBound(Constraint): + """ + Greatest Upper bound for tensors with dynamic type + """ + def __init__(self, res, rhs1, rhs2): + """ + :param res: tensor variable that stores the result of the outout + :param rhs1: tensor or tensor variable + :param rhs2: tensor or tensor variabke + """ + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f'{self.res} = {self.rhs1}\u2294*{self.rhs2}' + + def __eq__(self, other): + if isinstance(other, TGreatestUpperBound): + return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 + else: + return False + + +class DGreatestUpperBound(Constraint): + """ + Greatest Upper bound for dimensions + """ + def __init__(self, res, rhs1, rhs2): + """ + :param res: Dimension variable to store the result + :param rhs1: dimension variable 1 + :param rhs2: dimension variable 2 + """ + assert is_dim(res) + assert is_dim(rhs1) + assert is_dim(rhs2) + + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f'{self.res} = {self.rhs1}\u2294{self.rhs2}' + + def __eq__(self, other): + if isinstance(other, DGreatestUpperBound): + return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 + else: + return False + + +class CanReshape(Constraint): + """ + can_reshape constraint + """ + def __init__(self, src, target): + """ + :param src: tensor variable + :param target: tensor + """ + self.src = src + self.target = target + + def __repr__(self): + return f'can-reshape({self.src}, {self.target})' + + def __eq__(self, other): + if isinstance(other, CanReshape): + return self.src == other.src and self.target == other.target + else: + return False + + +class IndexSelect(Constraint): + + def __init__(self, tensor_size, input_var, dim_replace, index, output): + """ + Args: + input_var: input to index_select + tensor_size: tensor size we are considering + dim_replace: the dimension of the output at "index" + index: location of the dimensions to replace in the input + output: variable to store the result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(dim_replace, DVar) or dim_replace == Dyn + assert isinstance(index, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.dim_replace = dim_replace + self.index = index + self.output = output + + def __repr__(self): + + return f' {self.output} = ' \ + f'IndexSelect({self.input_var}, ' \ + f'tensor_size: {self.tensor_size}, ' \ + f'{self.dim_replace}, ' \ + f'{self.index})' + + def __eq__(self, other): + if isinstance(other, IndexSelect): + return self.tensor_size == other.tensor_size and \ + self.dim_replace == other.dim_replace and \ + self.index == other.index and \ + self.output == other.output and \ + self.input_var == other.input_var + else: + return False + + +class Transpose(Constraint): + + def __init__(self, tensor_size, input_var, index1, index2, output): + """ + Args: + tensor_size: current tensor size + input_var: variable to hold input + index1: dimension 1 + index2: dimension 2 + output: output that stores result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(index1, int) + assert isinstance(index2, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.index1 = index1 + self.index2 = index2 + self.output = output + + def __repr__(self): + + return f' {self.output} = ' \ + f'Transpose({self.input_var}, ' \ + f'tensor_size: {self.tensor_size}, ' \ + f'{self.index1}, ' \ + f'{self.index2})' + + def __eq__(self, other): + if isinstance(other, Transpose): + return self.tensor_size == other.tensor_size and \ + self.index1 == other.index1 and \ + self.index2 == other.index2 and \ + self.output == other.output and \ + self.input_var == other.input_var + else: + return False + + +class GetItem(Constraint): + + def __init__(self, tensor_size, index, res, input_var): + """ + Constraint for getting item given a tensor size + :param tensor_size: actual number + :param index: actual number representing the index + :param res: dimension variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, DVar) + + self.res = res + self.tensor_size = tensor_size + self.index = index + self.input_var = input_var + + def __repr__(self): + return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})' + + def __eq__(self, other): + if isinstance(other, GetItem): + return self.res == other.res and \ + self.tensor_size == other.tensor_size and \ + self.index == other.index and \ + self.input_var == other.input_var + else: + return False + +class GetItemTensor(Constraint): + + def __init__(self, tensor_size, index_tuple, res, input_var): + """ + Constraint for getting item given a tensor size + However, when the argument is a tuple, we will + expect a tensor + :param tensor_size: actual number representing the rank + :param index_tuple: tuple for indexing + :param res: tensor variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, TVar) + + self.res = res + self.tensor_size = tensor_size + self.index_tuple = index_tuple + self.input_var = input_var + + def __repr__(self): + return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})' + + def __eq__(self, other): + if isinstance(other, GetItemTensor): + return self.res == other.res and \ + self.tensor_size == other.tensor_size and \ + self.index_tuple == other.index_tuple and \ + self.input_var == other.input_var + else: + return False + +class CalcConv(Constraint): + + def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars): + """ + :param conv_result: the convolution result + :param input_var: input to convolution + :param c_out: output chanel type + :param kernel: kernel tuple + """ + self.conv_result = conv_result + self.input_var = input_var + self.c_out = c_out + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return f'{self.conv_result} =' \ + f' calc-conv({self.input_var},' \ + f' {self.c_out}, {self.kernel}, ' \ + f'{self.padding}, {self.stride},' \ + f' {self.dilation})' + + def __eq__(self, other): + if isinstance(other, CalcConv): + return self.conv_result == other.conv_result and self.input_var == other.input_var and \ + self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \ + and self.stride == other.stride and self.dilation == other.dilation \ + and self.matching_constraint == other.matching_constraint + else: + return False + + +class CalcMaxPool(Constraint): + + def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars): + """ + :param maxpool_result: the result of maxpool + :param input_var: input to convolution + :param kernel: kernel tuple + """ + self.maxpool_result = maxpool_result + self.input_var = input_var + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return f'{self.maxpool_result} =' \ + f' calc-maxpool({self.input_var},' \ + f' {self.kernel}, ' \ + f'{self.padding}, {self.stride},' \ + f' {self.dilation})' + + def __eq__(self, other): + if isinstance(other, CalcMaxPool): + return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \ + and self.kernel == other.kernel and self.padding == other.padding \ + and self.stride == other.stride and self.dilation == other.dilation \ + and self.matching_constraint == other.matching_constraint + else: + return False + + +class ApplyBroadcasting(Constraint): + def __init__(self, res1, res2, input1, input2): + """ + :param res1: resulting tensor 1 + :param res2: resulting tensor 2 + :param input1: tensor variable 1 + :param input2: tensor variable 2 + """ + self.res1 = res1 + self.res2 = res2 + self.input1 = input1 + self.input2 = input2 + + def __eq__(self, other): + if isinstance(other, ApplyBroadcasting): + return self.res1 == other.res1 \ + and self.res2 == other.res2 \ + and self.input1 == other.input1 \ + and self.input2 == other.input2 + else: + return False + + def __repr__(self): + return f'{self.res1}, {self.res2} ='f' apply-broadcasting({self.input1},' f' {self.input2})' + + +class CalcProduct(Constraint): + """ + Given correct dimensions, calculate the product for flatten accounting for Dyn + """ + def __init__(self, start, end, flattened, dims_to_flatten): + """ + :param start: start index + :param end: end index + :param flattened: variable to store the product + :param dims_to_flatten: the type which we will flatten + """ + assert isinstance(dims_to_flatten, list) + assert isinstance(flattened, TVar) + assert isinstance(start, int) + assert isinstance(end, int) + + self.start = start + self.end = end + self.dims_to_flatten = dims_to_flatten + self.flattened = flattened + + def __eq__(self, other): + if isinstance(other, CalcProduct): + return self.start == other.start and self.end == other.end and \ + self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened + + else: + return False + + def __repr__(self): + return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})' + + +class TVar: + """ + Tensor variable with no tensor constructor + """ + def __init__(self, tvar): + """ + :param tvar: tensor variable + """ + self.tvar = tvar + + def __repr__(self): + return f'TV({self.tvar})' + + def __eq__(self, other): + if isinstance(other, TVar): + return self.tvar == other.tvar + else: + return False + + +class DVar: + """ + Dimension variable + """ + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f'DV({self.c})' + + def __eq__(self, other): + if isinstance(other, DVar): + return self.c == other.c + else: + return False + + +class BVar: + """ + Boolean variable + """ + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f'BV({self.c})' + + def __eq__(self, other): + if isinstance(other, BVar): + return self.c == other.c + else: + return False + + +def is_algebraic_expression(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod] + else: + return isinstance(constraint, Prod) + + +def is_bool_expr(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_gt, op_lt, op_neq, op_eq] + else: + return isinstance(constraint, (BVar, Conj, Disj)) + +def is_dim(d): + return isinstance(d, (DVar, int)) or d == Dyn diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..952dde662f2ab8cb2a0613b7163266dc2a421758 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -0,0 +1,1281 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import torch +import operator +import warnings +from typing import Callable, Dict, Iterable + +from torch.fx._symbolic_trace import _assert_is_none +from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \ + Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \ + TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound +from torch.fx.experimental.migrate_gradual_types.operation import \ + op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul +from torch.fx.node import Target, Node +from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \ + gen_bvar + +from torch.fx.tensor_type import Dyn, TensorType +from torch.nn.modules.conv import Conv2d +from torch.nn.modules.batchnorm import BatchNorm2d + +_INFERENCE_RULES: Dict[Target, Callable] = {} + +MAX_TENSOR_RANK = 4 + +def register_inference_rule(call_target): + def register(fn): + if call_target in _INFERENCE_RULES: + raise RuntimeError(f'Inference rule already registered for {call_target}!') + _INFERENCE_RULES[call_target] = fn + return fn + return register + + +def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter): + d, counter = gen_tensor_dims(n, counter) + c1 = BinConstraintT(input, TensorType(d), op_eq) + start_dim = n if start_dim == -1 else abs(start_dim) + end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1 + c2 = CalcProduct(start_dim, end_dim, flattened, d) + nat_constraints = gen_nat_constraints(d) + return Conj([c1, c2, *nat_constraints]), counter + + +@register_inference_rule(getattr) +def get_attr_inference_rule(n: Node, symbols, constraints, counter): + """ + If the attribute is "device" then the tensor shape is preserved + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], str) + output, counter = gen_tvar(counter) + symbols[n] = output + + input = symbols[n.args[0]] + attr = n.args[1] + + if attr == 'device': + return [BinConstraintT(input, output, op_eq)], counter + else: + raise NotImplementedError('Not yet implemented') + +@register_inference_rule(torch.bmm) +def bmm_inference_rule(n: Node, symbols, constraints, counter): + """ + Constraints that match the input to a size 3 tensor + and switch the dimensions according to the rules + of batch multiplication + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + bmm_output, counter = gen_tvar(counter) + symbols[n] = bmm_output + + bmm_input1 = symbols[n.args[0]] + bmm_input2 = symbols[n.args[1]] + + dims_input1, counter = gen_tensor_dims(3, counter) + dims_input2, counter = gen_tensor_dims(3, counter) + + inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_output, Dyn, op_eq)]) + + input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)]) + + input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)]) + + consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)] + + batch_size, counter = gen_dvar(counter) + + inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq), + *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])]) + + return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter + + +@register_inference_rule("index_select") +def index_select_inference_rule(n: Node, symbols, constraints, counter): + """ + We constrain the second argument to a vector or Dyn. + The output replaces the input with the shape of the vector + at the position given by the index (first argument) + """ + # print(n.args) + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], Node) + + + + index_select, counter = gen_tvar(counter) + symbols[n] = index_select + + dims, counter = gen_tensor_dims(1, counter) + + # equality constraint + is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq) + is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq) + + c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select) + for i in range(MAX_TENSOR_RANK)])]) + c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) + for i in range(MAX_TENSOR_RANK)])]) + + return [Disj([c2, c3])], counter + + +@register_inference_rule("expand") +def expand_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the exact constraints as we do for tensor additions but we constraint + the rank of this expression to be equal to len(n.args[1:]) so that only + those cases get considered for the output + """ + assert isinstance(n.args[0], Node) + + # define the output for expand + expand, counter = gen_tvar(counter) + symbols[n] = expand + + # since we do not have two nodes here, we will construct an argument variable + e1 = symbols[n.args[0]] + e2, counter = gen_tvar(counter) + + e2_nat_constraints = [] + for arg in n.args[1:]: + assert isinstance(arg, (Node, int)) + if isinstance(arg, Node): + assert isinstance(symbols[arg], DVar) + e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq)) + + e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq) + + constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand) + + # constraint the output size + dims, counter = gen_tensor_dims(len(n.args[1:]), counter) + nat_constraints = gen_nat_constraints(dims) + c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints] + constraints += c + + return constraints, counter + + +@register_inference_rule(torch.nn.functional.gelu) +@register_inference_rule(torch.nn.functional.dropout) +@register_inference_rule(torch.nn.functional.softmax) +@register_inference_rule("detach") +@register_inference_rule("to") +@register_inference_rule("int") +@register_inference_rule("long") +@register_inference_rule("contiguous") +@register_inference_rule(torch.ones) +@register_inference_rule(torch.zeros) +def equality_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + output, counter = gen_tvar(counter) + symbols[n] = output + + if isinstance(n.args[0], Node): + input = symbols[n.args[0]] + if isinstance(input, TVar): + return [BinConstraintT(input, output, op_eq)], counter + + # then we have dimension variables + else: + for arg in n.args: + assert isinstance(symbols[arg], DVar) + my_size = [symbols[arg] for arg in n.args] + return [BinConstraintT(output, TensorType(my_size), op_eq)], counter + + elif isinstance(n.args[0], tuple): + # then the tuple is the size + assert len(n.args[0]) <= 4 + my_size = [symbols[arg] for arg in n.args[0]] + return [BinConstraintT(output, TensorType(my_size), op_eq)], counter + else: + raise NotImplementedError('Method not yet implemented') + + +@register_inference_rule("transpose") +def transpose_inference_rule(n: Node, symbols, constraints, counter): + """ + Can be considered as a sequence of two index selects, so we generate constraints accordingly + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], int) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + assert isinstance(from_arg, TVar) + + # input and output are dyn + is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)]) + + # or input is a tensor and we actually do the replacement + c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)]) + + return [Disj([is_dyn, c3])], counter + + +@register_inference_rule("type_as") +def type_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + to_arg = symbols[n.args[1]] + + assert isinstance(from_arg, TVar) + assert isinstance(to_arg, TVar) + + return [BinConstraintT(from_arg, to_arg, op_consistency), + BinConstraintT(output, to_arg, op_eq)], counter + +@register_inference_rule("masked_fill_") +def masked_fill_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to addition. For now we implement the constraints when + the argument is a boolean tensor. There is also a case for when + it is a condition. We will leave this out for now. + """ + + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + # We will retrieve the type variables from the symbol table + # and confirm they are tensor variables + + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + if isinstance(e1, TVar) and isinstance(e2, TVar): + masked_fill_tensor, counter = gen_tvar(counter) + symbols[n] = masked_fill_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor) + else: + raise NotImplementedError('Not yet implemented') + + +@register_inference_rule(torch.nn.functional.embedding) +def embedding_inference_rule_functional(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + embedding_dim_weights = symbols[n.args[1]] + + # will treat this as a static shape. So we will not use matching. + weight_dims, counter = gen_tensor_dims(2, counter) + equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq) + embedding_dim = weight_dims[1] + constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter) + return [equality_constraint] + constraints, counter + + +@register_inference_rule(torch.nn.modules.sparse.Embedding) +def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + The output shape differs from the input shape in the last dimension + """ + assert isinstance(n.args[0], Node) + return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter) + + +def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): + + embedding_output, counter = gen_tvar(counter) + symbols[n] = embedding_output + embedding_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(embedding_input, Dyn, op_eq) + output_dyn = BinConstraintT(embedding_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + + for i in range(1, MAX_TENSOR_RANK): + new_dims, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims) + + # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases + c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq), + BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] + + nat_constraints) + c2.append(c_tensor_i) + + return [Disj([c1, Disj(c2)])], counter + + +@register_inference_rule(torch.tensor) +def tensor_inference_rule(n: Node, symbols, constraints, counter): + """ + If the tensor is a scalar, we will skip it since we + do not support scalars yet. We will add support in the future + if it's needed. For our examples so far, scalars are not needed. + """ + return [], counter + + +@register_inference_rule("reshape") +@register_inference_rule("view") +def view_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to reshape but with an extra condition on the strides + """ + assert isinstance(n.args[0], Node) + + # generate the new variable + my_view, counter = gen_tvar(counter) + symbols[n] = my_view + + + src_var = symbols[n.args[0]] + t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]] # target shape + t2_type = [] + num_constraints = [] + + for t in t2: + if t == -1: + var, counter = gen_dvar(counter) + t2_type.append(var) + num_constraints.append(BinConstraintD(var, Dyn, op_neq)) + + else: + num_constraints.append(BinConstraintD(t, Dyn, op_neq)) + t2_type.append(t) + + t2_type = TensorType(t2_type) # type: ignore[assignment] + + c1 = BinConstraintT(my_view, t2_type, op_eq) + c2 = CanReshape(src_var, t2_type) + + # TODO: add the extra check mentioned here: + # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view + + return [c1, c2] + num_constraints, counter # type: ignore[operator] + + +@register_inference_rule("size") +def size_inference_rule(n: Node, symbols, constraints, counter): + """ + The constraint is just lhs = rhs. + Ex: size = input_ids.size() + """ + + + if len(n.args) == 1: + # generate the new variable + size, counter = gen_tvar(counter) + symbols[n] = size + input = symbols[n.args[0]] + c = BinConstraintT(input, size, op_eq) + return [c], counter + + elif len(n.args) == 2: + # TODO: review this rule; should input = dyn; output = dyn be included here? + if isinstance(n.args[1], int): + # generate the new variable + size_index, counter = gen_dvar(counter) + symbols[n] = size_index + input = symbols[n.args[0]] + c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)] + c3 = BinConstraintD(0, size_index, op_leq) + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintD(size_index, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + else: + raise NotImplementedError + + else: + raise NotImplementedError + + +def range_check(i, n): + """ + Checks if an index i is within range of a size n list + Args: + i: index + n: list size + + Returns: Boolean + """ + if i >= 0: + return T() if i < n else F() + else: + return T() if i >= n else F() + + +@register_inference_rule(torch.cumsum) +def cumsum_inference_rule(n: Node, symbols, constraints, counter): + """ + Input and output shapes should be equal + We should verify that the index is valid + """ + assert isinstance(n.args[0], Node) + arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"] + assert isinstance(arg_1, int) + + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims) + + c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq), + BinConstraintT(output, TensorType(new_dims), op_eq)] + + [range_check(arg_1, i)] + nat_constraints) + + c2.append(c_tensor_i) + dyn_or_tensor = Disj([c1, Disj(c2)]) + return [dyn_or_tensor], counter + + +@register_inference_rule(_assert_is_none) +def assert_inference_rule(n: Node, symbols, constraints, counter): + assert len(n.users) == 0 + return [], counter + + +@register_inference_rule(operator.getitem) +def getitem_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # dimension output case + if isinstance(n.args[1], int): + # create and store the new dimension variable + get_item_output, counter = gen_dvar(counter) + symbols[n] = get_item_output + + # retrieve arg variables + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + + # if the input is dynamic, we accept any index and return + # a dynamic dimension as output + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintD(get_item_output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + # if the input is a tensor, + # generate a getItem constraint which will be expanded based on the + # tensor dimension. + + c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)] + + + # since the output is a dimension, we make sure it's a natural number + # added as a conjunction to the disjunction of c2 + c3 = BinConstraintD(0, get_item_output, op_leq) + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + # tensor output case + elif isinstance(n.args[1], tuple): + # create and store the new tensor variable + get_item_output, counter = gen_tvar(counter) + symbols[n] = get_item_output + + # retrieve arg variables + if n.args[0] in symbols: + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment] + c1 = Conj([input_dyn, output_dyn]) + + c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc] + for i in range(MAX_TENSOR_RANK)] + else: + # TODO: we should figure out why there is a key-error here. + return [], counter + + return [Disj([c1, *c2])], counter + + else: + raise RuntimeError('Method not yet implemented') + + +@register_inference_rule(operator.gt) +def gt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + gt_tensor, counter = gen_tvar(counter) + symbols[n] = gt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError('Sort Mismatch') + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + elif isinstance(e1, TVar) and isinstance(e2, int): + # then we made the wrong assumption about the argument being a tensor + # so we should fix the assumption + warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.') + + new_e1, counter = gen_dvar(counter) + symbols[n.args[0]] = new_e1 + symbols[n.args[0]] + + gt_constraint = BinConstraintD(new_e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise NotImplementedError('Method not yet implemented') + + else: + raise NotImplementedError('Method not yet implemented') + + +@register_inference_rule(operator.eq) +def eq_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + eq_tensor, counter = gen_tvar(counter) + symbols[n] = eq_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + eq_constraint = BinConstraintD(e1, e2, op_eq) + + my_eq, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError('Sort Mismatch') + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + eq_constraint = BinConstraintD(e1, e2, op_eq) + + my_eq, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError('Method not yet implemented') + else: + raise NotImplementedError('Method not yet implemented') + +@register_inference_rule(operator.ne) +def neq_inference_rule(n: Node, symbols, constraints, counter): + """ + Translates to inconsistent in gradual types. + To prove inequality, we should prove that + tensors are either different sizes or + disagree on at least one dimension + + This is a WIP (works when the condition + is false. We are working on making this operation work + when the condition is true as well) + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], tuple) + + # implementing for size 3 and 4 + if len(n.args[1]) == 3: + + assert isinstance(n.args[1][0], (Node, int)) + assert isinstance(n.args[1][1], (Node, int)) + assert isinstance(n.args[1][2], (Node, int)) + + lhs = symbols[n.args[0]] + + b, counter = gen_tensor_dims(4, counter) + input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b[0], op_neq) + neq_2 = BinConstraintD(d2, b[1], op_neq) + neq_3 = BinConstraintD(d3, b[2], op_neq) + + # dimensions inconsistent + dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1]) + dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2]) + dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3]) + + dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3]) + + # we are covering size 3 and 4 only for now + ne_constraint = Conj([input_is_size3, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + elif len(n.args[1]) == 4: + + assert isinstance(n.args[1][0], (Node, int)) + assert isinstance(n.args[1][1], (Node, int)) + assert isinstance(n.args[1][2], (Node, int)) + assert isinstance(n.args[1][3], (Node, int)) + + lhs = symbols[n.args[0]] + + b1, counter = gen_dvar(counter) + b2, counter = gen_dvar(counter) + b3, counter = gen_dvar(counter) + b4, counter = gen_dvar(counter) + + input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b1, op_neq) + neq_2 = BinConstraintD(d2, b2, op_neq) + neq_3 = BinConstraintD(d3, b3, op_neq) + neq_4 = BinConstraintD(d4, b4, op_neq) + + # dimensions to inconsistent + dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1]) + dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2]) + dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3]) + dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4]) + + dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4]) + + ne_constraint = Conj([input_is_size4, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + else: + raise NotImplementedError('Method not yet implemented') + + return [equality_constraint], counter + + +@register_inference_rule(operator.lt) +def lt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + lt_tensor, counter = gen_tvar(counter) + symbols[n] = lt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError('Sort Mismatch') + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError('Method not yet implemented') + + else: + raise NotImplementedError('Method not yet implemented') + + +@register_inference_rule(torch.full) +def full_inference_rule(n: Node, symbols, constraints, counter): + full, counter = gen_tvar(counter) + symbols[n] = full + res = [] + + assert isinstance(n.args[0], Iterable) + for arg in n.args[0]: + dim = arg if isinstance(arg, int) else symbols[arg] + res.append(dim) + c = BinConstraintT(full, TensorType(list(res)), op_eq) # type: ignore[arg-type] + return [c], counter + + +# TODO normalize index +@register_inference_rule(torch.arange) +def arange_inference_rule(n: Node, symbols, constraints, counter): + start = 0 + step = 1 + + if len(n.args) == 1: + end = symbols[n.args[0]] + else: + raise NotImplementedError('Not yet implemented') + + # int((end - start) / step) + d1, counter = gen_dvar(counter) + size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq) + arange, counter = gen_tvar(counter) + symbols[n] = arange + + # either the a parameter is a number or it is Dyn + c1 = Disj([BinConstraintD(end, Dyn, op_eq), + BinConstraintD(start, Dyn, op_eq), + BinConstraintD(step, Dyn, op_eq)]) + c2 = BinConstraintD(d1, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + c11 = Conj([BinConstraintD(end, Dyn, op_neq), + BinConstraintD(start, Dyn, op_neq), + BinConstraintD(step, Dyn, op_neq)]) + c22 = BinConstraintD(d1, Dyn, op_neq) + both_numbers = Conj([c11, c22, size_constraint]) + + return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter + +def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): + # additional vars that don't correspond to expressions + e11, counter = gen_tvar(counter) + e22, counter = gen_tvar(counter) + + # generate constraints + c1 = TGreatestUpperBound(output_var, e11, e22) + c2 = ApplyBroadcasting(e11, e22, e1, e2) + c3 = BinConstraintT(e11, e22, op_consistency) + return [c1, c2, c3], counter + + +@register_inference_rule(operator.mul) +@register_inference_rule(torch.ne) +@register_inference_rule("ne") +@register_inference_rule(torch.add) +@register_inference_rule(operator.add) +def broadcasting_inference_rule(n: Node, symbols, constraints, counter): + + op_code = None + if n.target == operator.add or n.target == torch.add: + op_code = op_add + elif n.target == operator.mul: + op_code = op_mul + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output) + else: + raise NotImplementedError('Method not yet implemented') + + elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)): + if isinstance(symbols[n.args[0]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + return [BinConstraintT(my_output, e1, op_eq)], counter + elif isinstance(symbols[n.args[0]], DVar): + my_output, counter = gen_dvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + + # we will propagate the runtime value here since this is regular addition + c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq), + BinConstraintD(0, my_output, op_leq)]) + return [c], counter + + elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)): + if isinstance(symbols[n.args[1]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e2 = symbols[n.args[1]] + return [BinConstraintT(my_output, e2, op_eq)], counter + elif isinstance(symbols[n.args[1]], DVar): + my_output, counter = gen_dvar(counter) + symbols[n] = my_output + e2 = symbols[n.args[1]] + + # we will propagate the runtime value here since this is regular addition + c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq), + BinConstraintD(0, my_output, op_leq)]) + return [c], counter + + else: + raise NotImplementedError('Method not yet implemented') + + else: + # TODO generate add constraints for scalar addition + raise NotImplementedError('Addition not yet implemented') + + +@register_inference_rule(torch.flatten) +def flatten_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + flattened, counter = gen_tvar(counter) + symbols[n] = flattened + + input = symbols[n.args[0]] + + # set the default start and end dims + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + c1 = BinConstraintT(input, Dyn, op_eq) + c2 = BinConstraintT(flattened, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + const = [] + for i in range(1, MAX_TENSOR_RANK + 1): + c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter) + const.append(c) + + return [Disj([both_dyn, *const])], counter + + +@register_inference_rule(torch.nn.functional.layer_norm) +def layer_norm_functional(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + return gen_layer_norm_constraints(n, n.args[1], symbols, counter) + + +@register_inference_rule(torch.nn.LayerNorm) +def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + Input should be consistent with the normalized_shape + """ + assert isinstance(n.args[0], Node) + return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter) + + +def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims_rhs) + + c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq), + BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] + + add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) + + nat_constraints) + c2.append(c_tensor_i) + return [Disj([c1, Disj(c2)])], counter + +@register_inference_rule(torch.nn.Dropout) +@register_inference_rule(torch.nn.ReLU) +def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + """ + assert isinstance(n.args[0], Node) + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + assert isinstance(input, TVar) + return [BinConstraintT(input, output, op_eq)], counter + + +@register_inference_rule(torch.nn.Linear) +def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output sizes should be the same except for the last dimension + If the input is Dyn, then so should the output + """ + assert isinstance(n.args[0], Node) + return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter) + + +@register_inference_rule("dim") # type: ignore[attr-defined] +def torch_dim_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + my_dim, counter = gen_dvar(counter) + symbols[n] = my_dim + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintD(my_dim, Dyn, op_eq) + + c1 = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + + c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintD(my_dim, i, op_eq)]) + c1.append(c_tensor_i) + + return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter + + +@register_inference_rule(torch._C._nn.linear) # type: ignore[attr-defined] +def torch_linear_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + weight_dims, counter = gen_tensor_dims(2, counter) + equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq) + constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter) + return [equality_constraint] + constraints, counter + + +def linear_constraints(n: Node, in_features, out_features, symbols, counter): + linear_output, counter = gen_tvar(counter) + symbols[n] = linear_output + linear_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(linear_input, Dyn, op_eq) + output_dyn = BinConstraintT(linear_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + new_dims_rhs_2, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) + + c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] + + add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) + + nat_constraints) + c2.append(c_tensor_i) + return [Disj([c1, Disj(c2)])], counter + +def add_layer_norm_constraints(input_dim, normalized_dim): + """ + The constraints say that the type has te form: [*, 1024, 1024] + while the normalized_dim have the form [1024, 1024] + Args: + input_dim: Input shape of layer norm + normalized_dim: normalized_dim parameter of the module instance + + """ + + # in this case we return false since there's a pattern mismatch + if len(normalized_dim) > len(input_dim): + return [F()] + + else: + constraints = [] + for i, n in zip(reversed(input_dim), reversed(normalized_dim)): + constraints.append(BinConstraintD(i, n, op_consistency)) + return constraints + + +def add_linear_constraints(dims1, dims2, in_features, out_features): + assert len(dims1) == len(dims2) + constraints = [] + for i in range(len(dims1)): + if i == len(dims1) - 1: + constraints.append(BinConstraintD(dims1[i], in_features, op_consistency)) + constraints.append(BinConstraintD(dims2[i], out_features, op_eq)) + else: + constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq)) + + return constraints + + +@register_inference_rule(torch.reshape) +def reshape_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + my_reshape, counter = gen_tvar(counter) + symbols[n] = my_reshape + + src_var = symbols[n.args[0]] + t2 = n.args[1] + t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) # type: ignore[union-attr] + c1 = BinConstraintT(my_reshape, t2_type, op_eq) # type: ignore[union-attr] + c2 = CanReshape(src_var, t2_type) + + return [c1, c2], counter + + +@register_inference_rule(BatchNorm2d) +def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + batchnorm_output, counter = gen_tvar(counter) + symbols[n] = batchnorm_output + batchnorm_input = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq) + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(torch.nn.AdaptiveAvgPool2d) +def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + avg_pool, counter = gen_tvar(counter) + + symbols[n] = avg_pool + input_var = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq) + + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(Conv2d) +def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + my_conv, counter = gen_tvar(counter) + symbols[n] = my_conv + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + # c1 = Matching(input_var, TensorType([d1, d2, d3, d4])) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + # c2 = DConsistency(module_instance.in_channels, d2) + c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency) + + c3 = CalcConv(my_conv, input_var, + module_instance.out_channels, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, [d1, d2, d3, d4]) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, c3, *nat_constraints], counter + + +@register_inference_rule(torch.nn.MaxPool2d) +def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + maxpool, counter = gen_tvar(counter) + symbols[n] = maxpool + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding, + module_instance.stride, module_instance.dilation, [d1, d2, d3, d4]) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, *nat_constraints], counter + + +class ConstraintGenerator: + def __init__(self, traced, graph=None): + self.traced = traced # traced or tracer.root + self.traced_params = dict(self.traced.named_parameters()) + self.constraints = [] + self.symbol_dict = {} + self.graph = traced.graph if hasattr(traced, 'graph') else graph + + + def generate_constraints(self, counter=0): + """ + Iterate through every node and generate constraints + Effect: self.constraints will be populated with the final constraints + """ + graph = self.graph + + all_constraints = [] + + for n in graph.nodes: + (constraints, counter) = self.generate_constraints_node(n, counter) + all_constraints += constraints + + return Conj(all_constraints), counter + + def generate_constraints_node(self, n: Node, counter): + """ + Generate constraints the given node: + Currently supported operations: + - Reshape + - Add + - conv2d + """ + + if n.op == 'placeholder': + x, counter = gen_tvar(counter) + self.symbol_dict[n] = x + + my_type = n.type + + if n.type != Dyn and (not isinstance(n.type, TensorType)): + if n.type == torch.nn.parameter.Parameter: + # since we have a parameter, the shape must be static + assert 'example_value' in n.meta + my_type = TensorType(n.meta['example_value'].size()) + else: + my_type = Dyn + + c1 = BinConstraintT(my_type, x, op_precision) + c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq) + return [c1, c2], counter + + elif n.op == 'call_function': + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) + else: + raise RuntimeError(f'No inference rule registered for target {n.target}!') + + elif n.op == 'call_module': + + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _INFERENCE_RULES: + return _INFERENCE_RULES[type(module_instance)](n, + module_instance, + self.symbol_dict, + self.constraints, counter) + else: + raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') + + elif n.op == 'call_method': + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) + else: + raise RuntimeError(f'No inference rule registered for target {n.target}!') + + elif n.op == 'get_attr': + t = self.traced_params.get(n.target, None) + + if isinstance(t, torch.Tensor): + if len(t.shape) > 0: + res = list(t.shape) + attr_type = TensorType(res) + output, counter = gen_tvar(counter) + self.symbol_dict[n] = output + return [BinConstraintT(output, attr_type, op_eq)], counter + else: + # scalar? + return [], counter + else: + return [], counter + + elif n.op == 'output': + return [], counter + + else: + raise NotImplementedError(f"Method {n.op} not yet implemented") diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..439e3d6195e654147f5f583b6b13fa9611757372 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py @@ -0,0 +1,1040 @@ +# mypy: ignore-errors +import copy +import itertools +from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK +from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \ + Transpose +from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound +from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound +from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool +from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape +from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect +from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching +from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq +from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod +from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar +from torch.fx.tensor_type import TensorType, Dyn +from typing import Callable, Dict, List + +_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {} + + +def register_transformation_rule(call_target): + def register(fn): + if call_target in _TRANSFORMATION_RULES: + raise RuntimeError(f'Transformation rule already registered for {call_target}!') + _TRANSFORMATION_RULES[call_target] = fn + return fn + return register + + +def valid_index(index, dims): + """ + Given a list of dimensions, checks if an index is valid in the list + """ + try: + dims[index] + return T() + except IndexError: + return F() + + +@register_transformation_rule(Transpose) +def transform_transpose(constraint, counter): + """ + Similar to a sequence of two index-selects + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + is_valid_index1 = valid_index(constraint.index1, dims) + is_valid_index2 = valid_index(constraint.index2, dims) + new_dims = copy.deepcopy(dims) + nat_constraints = gen_nat_constraints(dims) + + if is_valid_index1 == T() and is_valid_index2 == T(): + new_dims[constraint.index1] = dims[constraint.index2] + new_dims[constraint.index2] = dims[constraint.index1] + + transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index1, is_valid_index2, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) + return transformed_constraint, counter + + +@register_transformation_rule(IndexSelect) +def transform_index_select(constraint, counter): + """ + The constraints consider the given tensor size, checks if the index is valid + and if so, generates a constraint for replacing the input dimension + with the required dimension + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + is_valid_index = valid_index(constraint.index, dims) + nat_constraints = gen_nat_constraints(dims) + + # if the index is valid then replace the input dimension with the new dimension + # otherwise the dimension will not be replaced and the clause will contain False + if is_valid_index == T(): + new_dims = copy.deepcopy(dims) + new_dims[constraint.index] = constraint.dim_replace + + transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) + + # print(constraints) + return transformed_constraint, counter + + +@register_transformation_rule(GetItem) +def transform_get_item(constraint, counter): + """ + generate an equality of the form: + t = [a1, ..., an] + then generate constraints that check if the given index is valid + given this particular tensor size. + If the index is valid, generate a constraint to get the item + Note that we already handled the Dyn input case in the previous + step. + Args: + constraint: GetItem which assumes we are getting an item from a tensor (not Dyn) + counter: variable tracking + Returns: simplified constraints for GetItem + + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + nat_constraints = gen_nat_constraints(dims) + + + is_valid_index = valid_index(constraint.index, dims) + + all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index] + + # if the index is valid, we generate a constraint for getting an item + # otherwise this clause will have been UNSAT due to the wrong index + if is_valid_index == T(): + all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq)) + + return Conj(all_constraints), counter + +def valid_index_tensor(index, dims): + """ + if the slice instances exceed the length of the dimensions + then this is a type error so we return False + """ + slice_count = 0 + for s in index: + if isinstance(s, slice): + slice_count += 1 + if slice_count > len(dims): + return F() + else: + return T() + +@register_transformation_rule(GetItemTensor) +def transform_get_item_tensor(constraint, counter): + """ + When the index is a tuple, then the output will be a tensor + TODO: we have to check if this is the case for all HF models + + The cases we are covering here are a tuple with one of: + - slice with default argument + - None + + None appends 1 to the input tensor dimensions + so each occurrence of 'None' increases the rank by 1 + + slice with default arguments does not change the rank + """ + assert isinstance(constraint.index_tuple, tuple) + + + # generate a result tensor of the expected size + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + nat_constraints = gen_nat_constraints(dims) + + # generate a place-holder list of the right rank + # where "slice" does not contribute to the rank and "None" does + none_c = constraint.index_tuple.count(None) + resulting_tensor_dims = (none_c + len(dims)) * [None] + + dim_index = 0 + for i in range(len(constraint.index_tuple)): + + # append 1 to the right location of the resulting tensor + if constraint.index_tuple[i] is None: + resulting_tensor_dims[i] = 1 + + elif constraint.index_tuple[i] == slice(None, None, None): + pass + + else: + raise NotImplementedError('Method not yet implemented') + + # append the remaining dimensions to the right location + dim_index = 0 + for i in range(len(resulting_tensor_dims)): + if resulting_tensor_dims[i] is None: + resulting_tensor_dims[i] = dims[dim_index] + dim_index += 1 + + # check if the index is valid + is_valid_index = valid_index_tensor(constraint.index_tuple, dims) + + # check if the resulting tensor is within bounds + if len(resulting_tensor_dims) > 4: + return F(), counter + + else: + constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq), + *nat_constraints, + is_valid_index] + return Conj(constraints), counter + + +@register_transformation_rule(BinConstraintT) +def generate_binconstraint_t(constraint, counter): + """ + Transform binary constraints for tensors + """ + + # precision constraints + if constraint.op == op_precision: + if constraint.lhs == Dyn: + return T(), counter + elif isinstance(constraint.lhs, TensorType): + is_fully_static = all(d != Dyn for d in constraint.lhs.__args__) + if is_fully_static: + return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter + else: + new_dims = [] + + for _ in range(len(constraint.lhs.__args__)): + dim, counter = gen_dvar(counter) + new_dims.append(dim) + + new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for + new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \ + [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \ + [BinConstraintD(1, new_dim, op_leq) for + new_dim in new_dims] + return Conj(new_dim_constraints), counter + + # matching + elif constraint.op == op_matching: + assert isinstance(constraint.rhs, TensorType) + d1 = constraint.rhs.__args__[0] + d2 = constraint.rhs.__args__[1] + d3 = constraint.rhs.__args__[2] + d4 = constraint.rhs.__args__[3] + + conj = [BinConstraintT(constraint.lhs, Dyn, op_eq), + BinConstraintD(d1, Dyn, op_eq), + BinConstraintD(d2, Dyn, op_eq), + BinConstraintD(d3, Dyn, op_eq), + BinConstraintD(d4, Dyn, op_eq)] + return Disj([Conj(conj), + BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter + + elif constraint.op == op_consistency: + c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)]) + [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter) + + return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter + + elif constraint.op == op_leq: + assert isinstance(constraint.rhs, int) + disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)] + for i in range(1, constraint.rhs + 1): + dims = [] + for j in range(1, i + 1): + dim_var, counter = gen_dvar(counter) + dims.append(dim_var) + disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq)) + return Disj(disj), counter + else: + return constraint, counter + + +@register_transformation_rule(BinConstraintD) +def generate_binconstraint_d(constraint, counter): + """ + Transform binary constraints for dimensions + """ + if constraint.op == op_precision: + if isinstance(constraint.lhs, int): + return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter + elif constraint.lhs == Dyn: + return T(), counter + + elif constraint.op == op_consistency: + return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq), + BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter + + else: + return constraint, counter + + +@register_transformation_rule(Conj) +def generate_conj(constraint, counter): + """ + Transform conjunctions + """ + new = [] + for c in constraint.conjucts: + new_c, counter = transform_constraint(c, counter) + new.append(new_c) + return Conj(new), counter + + +@register_transformation_rule(Disj) +def generate_disj(constraint, counter): + """ + Transform disjunctions + """ + new = [] + for c in constraint.disjuncts: + new_c, counter = transform_constraint(c, counter) + new.append(new_c) + return Disj(new), counter + + +@register_transformation_rule(TGreatestUpperBound) +def generate_gub(constraint, counter): + """ + Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound + on dimensions + """ + c1 = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq), + BinConstraintT(constraint.rhs2, Dyn, op_eq)]), BinConstraintT(constraint.res, Dyn, op_eq)]) + + [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter) + + return Disj([c1, c2, c3, c4, c5]), counter + + +@register_transformation_rule(DGreatestUpperBound) +def generate_d_gub(constraint, counter): + """ + Transform greatest upper bound for dimensions into equality constraints + """ + c1 = Conj([BinConstraintD(constraint.rhs1, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs2, op_eq)]) + c2 = Conj([BinConstraintD(constraint.rhs2, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) + c3 = Conj([BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) + return Disj([c1, c2, c3]), counter + + +@register_transformation_rule(CalcConv) +def generate_calc_conv(constraint, counter): + d, counter = gen_tensor_dims(4, counter) + conv_result = TensorType([d[0], d[1], d[2], d[3]]) + + # the convolution result is a tensor of size 4 + c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq) + + # the second dimension of the output is equal to the output channels + c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)]) + + # the input corresponds to the output in the first dimension of the convolution + c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) + + c4, c5 = calc_last_two_dims(constraint, d) + + leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq)]) + + return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter + + +@register_transformation_rule(CalcMaxPool) +def generate_calc_maxpool(constraint, counter): + """ + Transform maxpool constraints + """ + d, counter = gen_tensor_dims(4, counter) + maxpool_result = TensorType([d[0], d[1], d[2], d[3]]) + + # the maxpool result is a tensor of size 4 + c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq) + + # the input corresponds to the output in the first and second dimension of maxpool + c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq) + c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) + c4, c5 = calc_last_two_dims(constraint, d) + + leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq)]) + + return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter + + +@register_transformation_rule(CalcProduct) +def generate_calc_product(constraint, counter): + """ + Transform flatten constraints + """ + start = constraint.start + end = constraint.end + dims = constraint.dims_to_flatten + flattened = constraint.flattened + n = len(constraint.dims_to_flatten) + + # this will be evaluated right here + boundary_check = (0 <= start and start < end and end <= n) + + c_boundary = T() if boundary_check else F() + + lhs = dims[0:start] + rhs = dims[end:] + mid = dims[start:end] + + all_possibilities = generate_all_int_dyn_dim_possibilities(mid) + + all_constraints = [] + + for p in all_possibilities: + p = list(p) + # this tells us there is a dynamic variable + contains_dyn = not all(constraint.op == op_neq for constraint in p) + if contains_dyn: + mid_var = [Dyn] + total_constraints = lhs + mid_var + rhs + if len(total_constraints) > 4: + all_constraints.append(F()) + else: + all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq)] + p)) + else: + new_var, counter = gen_dvar(counter) + mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)]) + mid_var = [new_var] + total_constraints = lhs + mid_var + rhs + if len(total_constraints) > 4: + all_constraints.append(F()) + else: + all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p)) + + return Conj([Disj(all_constraints), c_boundary]), counter + + +@register_transformation_rule(CanReshape) +def generate_reshape(constraint, counter): + """ + Transform reshape constraints + """ + d, counter = gen_tensor_dims(4, counter) + + d1 = d[0] + d2 = d[1] + d3 = d[2] + d4 = d[3] + + target = constraint.target.__args__ + + is_fully_static = all(d != Dyn for d in target) + + # dynamic tensor + c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq) + c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq) + c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq) + c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq) + c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq) + + d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq) + d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq) + + d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq) + d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq) + + d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq) + d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq) + + d4_eq_dyn = BinConstraintD(d3, Dyn, op_eq) + d4_neq_dyn = BinConstraintD(d3, Dyn, op_neq) + + nat_d1 = BinConstraintD(0, d1, op_leq) + nat_d2 = BinConstraintD(0, d2, op_leq) + nat_d3 = BinConstraintD(0, d3, op_leq) + nat_d4 = BinConstraintD(0, d4, op_leq) + + if is_fully_static: + # size 1 tensor + c3_tensor1 = Disj([d1_eq_dyn, + (Conj([d1_neq_dyn, + BinConstraintD(d1, Prod(target), op_eq)]))]) + all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) + + # size 2 tensor + all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)]) + + # size 3 tensor + all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)]) + + # size 4 tensor + all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)]) + + return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), + nat_d1, nat_d2, nat_d3, nat_d4]), counter + + # then there must be exactly one occurrence of dyn + else: + new_target = [] + + for n in target: + if n != Dyn: + new_target.append(n) + + # tensor 1 + c3_tensor1 = Disj([d1_eq_dyn, + (Conj([d1_neq_dyn, + is_dim_div_by_target(new_target, d1)]))]) + all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) + + # tensor 2 + c21 = Disj([d1_eq_dyn, d2_eq_dyn]) + c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))]) + all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])]) + + # tensor 3 + c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn]) + c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))]) + all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])]) + + # tensor 4 + c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn]) + c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))]) + all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])]) + + return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), + nat_d1, nat_d2, nat_d3, nat_d4]), counter + + +@register_transformation_rule(ApplyBroadcasting) +def generate_broadcasting(constraint, counter): + """ + Transform broadcasting constraints + """ + e11, e12 = constraint.res1, constraint.res2 + e1, e2 = constraint.input1, constraint.input2 + + e1_dyn = BinConstraintT(e1, Dyn, op_eq) + e2_dyn = BinConstraintT(e2, Dyn, op_eq) + + # Introduce dimensions + e1_equal_e11 = BinConstraintT(e1, e11, op_eq) + e2_equal_e12 = BinConstraintT(e2, e12, op_eq) + + # dyn possibility + e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12]) + e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12]) + + # tensor possibility + # generate dimensions to create tensors of size 1 + final_tensor_1_constraint, _, _, nat_dims_1, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter) + + # generate dimensions to create tensors of size 2 + final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \ + final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter) + + # generate dimensions to create tensors of size 3 + final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \ + final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter) + + # generate dimensions to create tensors of size 4 + final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \ + final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \ + gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter) + + final_result = Disj([ + e1_dyn_constraint, + e2_dyn_constraint, + final_tensor_1_constraint, + final_tensor_2_constraint_no_padding, + final_tensor_2_constraint_padding_arg1, + final_tensor_2_constraint_padding_arg2, + final_tensor_3_constraint_no_padding, + final_tensor_3_constraint_padding_arg1, + final_tensor_3_constraint_padding_arg2, + final_tensor_4_constraint_no_padding, + final_tensor_4_constraint_padding_arg1, + final_tensor_4_constraint_padding_arg2 + ]) + + return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter + + +def transform_constraint(constraint: Constraint, counter: int): + """ + Transforms a constraint into a simpler constraint. + Ex: precision and consistency are transformed to equality + Args: + constraint: constraint to be transformed + counter: for variable tracking + + Returns: Constraint + + """ + if type(constraint) in _TRANSFORMATION_RULES: + return _TRANSFORMATION_RULES[type(constraint)](constraint, counter) + + else: + return constraint, counter + + + + +def calc_last_two_dims(constraint, d: List[DVar]): + """ + Generates constraints for the last two dimensions of a convolution or a maxpool output + Args: + constraint: CalcConv or CalcMaxPool + d: The list of output dimensions + + Returns: Constraints for calculating the last two dimensions of the output + + """ + + assert isinstance(constraint, (CalcConv, CalcMaxPool)) + + b3 = constraint.matching_constraint[2] + b4 = constraint.matching_constraint[3] + + b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)]) + b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)]) + + d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)]) + d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)]) + + # transform parameters into tuples incase they are not already + padding = (constraint.padding, constraint.padding) \ + if isinstance(constraint.padding, int) else constraint.padding + kernel = (constraint.kernel, constraint.kernel) \ + if isinstance(constraint.kernel, int) else constraint.kernel + stride = (constraint.stride, constraint.stride) \ + if isinstance(constraint.stride, int) else constraint.stride + dilation = (constraint.dilation, constraint.dilation) \ + if isinstance(constraint.dilation, int) else constraint.dilation + + f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add) + f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul) + f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div) + f4 = BinConstraintD(f3, 1, op_add) + + c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])]) + + f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add) + f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul) + f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div) + f44 = BinConstraintD(f33, 1, op_add) + + c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])]) + + return c4, c5 + + +def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]): + """ + Generate all possibilities of being equal or not equal to dyn for my_list + Args: + my_list: List of tensor dimensions + + Returns: A list of a list of constraints. Each list of constraints corresponds to + one possibility about the values of the dimension variables + """ + # generate all possibilities of being equal or not equal to dyn for my_list + eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))] + neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))] + d_possibilities = [] + + for i in zip(eq_possibilities, neq_possibilities): + d_possibilities.append(list(i)) + all_possibilities = list(itertools.product(*d_possibilities)) + return all_possibilities + + +def is_target_div_by_dim(target: List[int], dim: List[DVar]): + """ + Generate constraints to check if the target dimensions are divisible by the input dimensions + Args: + target: Target dimensions + dim: Input dimensions + + Returns: Constraints to check divisibility + + """ + return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq) + + +def is_dim_div_by_target(target: List[int], dim: List[DVar]): + """ + Generate constraints to check if the input dimensions is divisible by the target dimensions + Args: + target: Target dimensions + dim: Input dimensions + + Returns: Constraints to check divisibility + + """ + return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq) + + +def gen_all_reshape_possibilities(list_of_dims, target): + """ + Consider all possibilities what the input dimensions could be (number or dynamic) + Then generate the appropriate constraints using multiplication or mod depending on the possibility + The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn + for the input. Target is fixed because at most one dimension could be dyn. + We have different cases for this. + + Args: + list_of_dims: The input list of dimensions + target: The tensor we want to reshape to + + Returns: A disjunction of transformed reshape constraints + + """ + all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims) + + all_constraints = [] + + for p in all_possibilities: + to_multiply = [] + + p = list(p) + + for constraint in p: + assert isinstance(constraint, BinConstraintD) + if constraint.op == op_neq: + to_multiply.append(constraint.lhs) + + if not to_multiply: + all_constraints.append(Conj(p)) + + elif len(to_multiply) < len(list_of_dims): + all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))])) + else: + all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims), + Prod(target), op_eq)])) + + return Disj(all_constraints) + + +def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False): + """ + Apply broadcasting to the 'index' dimension of tensor_input1. + Args: + tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1 + tensor_input2: represents the second input + res1: broadcasted result 1 + res2: broadcasted result 2 + index: the index to broadcast + padding: If padding was used, then tensor_input1[index] does not exist + + Returns: + + """ + if tensor_input1[index] is None: + assert padding + + + if not padding: + # then the inputs are the same length so they all have dimensions at "index" + return Conj([BinConstraintD(tensor_input1[index], 1, op_eq), + BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq)]) + + else: + # we don't set the input dimension to 1, since it doesn't exist. + return Conj([BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq)]) + + +def apply_padding(e1_var: TVar, + e11: BinConstraintT, + e2: BinConstraintT, + e12: BinConstraintT, + d2: List[DVar], + d11: List[DVar], + d12: List[DVar], + counter: int): + """ + We are considering the possibility where one input has less dimensions than + another input, so we apply padding to the broadcasted results + + Args: + e1_var: Variable representing the first input where padding will be + e11: constraint of the form e11 = Tensortype[d1, ..., dn] + e2: constraint of the form e2 = Tensortype[d1, ..., dn] + e12: constraint of the form e11 = Tensortype[d1, ..., dn] + d2: Tensor variables for the second input + d11: Tensor variables for the broadcasted first input + d12: Tensor variables for the broadcasted second input + counter: variable tracking + + Returns: A new constraint whose goal is to apply padding to the broadcasted result + + """ + + res = [] + + # pad the shorter input with None so we can pass it to the broadcasting helper function + for i in range(1, len(d2)): + + d1, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12) + + e1 = BinConstraintT(e1_var, TensorType(d1), op_eq) + + simulate_padding = [None] * (len(d2) - i) + + assert len(simulate_padding + d1) == len(d2) + + broadcast_padding = [] + + # for every padding size, we also consider broadcasting + for j in range(len(d2) - i): + broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True)) + + # we consider the possibilities for broadcasting for every dimension. Since we already + # padded d1, we do not consider it while broadcasting + all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1, + d2[(len(d2) - i):], + d11[(len(d2) - i):], + d12[(len(d2) - i):]) + # combine all constraints into a conjunction + c = Conj([e1, e11, e2, e12, + *broadcast_padding, + all_broadcasting_possibilities, + *nat_constraints + ]) + res.append(c) + + return Disj(res), counter + + +def no_broadcast_dim_with_index(d1: List[DVar], + d2: List[DVar], + d3: List[DVar], + d4: List[DVar], + i: int): + """ + Args: + d1: input 1 + d2: input 2 + d3: simulated broadcasting for input 1 + d4: simulated broadcasting for input 2 + i: the rank of the resulting tensor addition + + Returns: Constraints for when no broadcasting occurs + """ + return Conj([ + Disj([ + Conj([BinConstraintD(d1[i], 1, op_eq), + BinConstraintD(d2[i], 1, op_eq)]), + + Conj([BinConstraintD(d1[i], 1, op_neq), + BinConstraintD(d2[i], 1, op_neq)])]), + + BinConstraintD(d1[i], d3[i], op_eq), + BinConstraintD(d2[i], d4[i], op_eq)]) + + + +def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): + """ + Generate lists of DVar to represent tensor dimensions + Args: + num_tensors: the required number of tensors + dim_size: the number of dimensions for each tensor + counter: variable tracking + + Returns: A list of a list of tensor dimensions + + """ + res = [] + + for _ in range(num_tensors): + dims, counter = gen_tensor_dims(dim_size, counter) + res.append(dims) + + return res, counter + + +def create_equality_constraints_for_broadcasting(e1: TVar, + e2: TVar, + e11: TVar, + e12: TVar, + d1: List[DVar], + d2: List[DVar], + d11: List[DVar], + d12: List[DVar]): + """ + Create equality constraints for when no broadcasting occurs + Args: + e1: Input 1 + e2: Input 2 + e11: Broadcasted input 1 + e12: Broadcasted input 2 + d1: Variables that store dimensions for e1 + d2: Variables that store dimensions for e2 + d11: Variables that store dimensions for e11 + d12: Variables that store dimensions for e22 + + Returns: Four equality constraints + + """ + + e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq) + e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq) + e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq) + e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq) + return [e1_tensor, e11_tensor, e2_tensor, e12_tensor] + + +def gen_consistency_constraints(constraint: Constraint, counter: int): + """ + Args: + constraint: Consistency constraint on tensors + counter: for variable tracking + + Returns: Equality and consistency constraints on dimensions + + """ + + all_constraints = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + new_dims_rhs_2, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) + + c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] + + [BinConstraintD(d1, d2, op_consistency) for + d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints) + + all_constraints.append(c_tensor_i) + + return all_constraints, counter + + +def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): + """ + Args: + constraint: Greatest upper bound on tensors + counter: variable tracking + + Returns: A set of equality constraints and DGreatestUpperBound constraints + + """ + + all_constraints = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + c = [] + dims1, counter = gen_tensor_dims(i, counter) + c1tensor = TensorType(dims1) + + dims2, counter = gen_tensor_dims(i, counter) + c2tensor = TensorType(dims2) + + dims3, counter = gen_tensor_dims(i, counter) + c3tensor = TensorType(dims3) + + c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq), + BinConstraintT(constraint.rhs2, c2tensor, op_eq), + BinConstraintT(constraint.res, c3tensor, op_eq)] + \ + gen_nat_constraints(dims1 + dims2 + dims3) + + assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__) + for i in range(len(c3tensor.__args__)): + c.append(DGreatestUpperBound(c3tensor.__args__[i], + c1tensor.__args__[i], + c2tensor.__args__[i])) + + all_constraints.append(Conj(c)) + return all_constraints, counter + + +def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]): + """ + Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension. + We look at all combinations for all dimensions in d1 and d2 + Args: + d1: input1 dimensions + d2: input2 dimensions + d11: broadcasted input1 dimensions + d12: broadcasted input2 dimensions + + Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions + + """ + + size = len(d1) + + res2 = [] + + for i in range(size): + t1 = broadcast_dim(d1, d2, d11, d12, i) + t2 = broadcast_dim(d2, d1, d12, d11, i) + t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i) + + res2.append(Disj([t1, t2, t3])) + + return Conj(res2) + + +def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int): + """ + Simulates broadcasting on e1 and e2 and returns the results + respectively in e11 and e12. Because of gradual types, + e1 and e2 may not be equal. Similarly, e11 and e12 may not + be equal. e11 and e12 should be guaranteed to be consistent + as they represent the shapes of the tensors to be added after + broadcasting. + Args: + e1: TVar representing the type of input 1 + e2: TVar representing the type of input 2 + e11: TVar representing the representing broadcasted input 1 + e12: TVar representing the representing broadcasted input 2 + i: The rank of the resulting type of addition + counter: for variable tracking + + Returns: Simplified broadcasting constraints + + """ + dims, counter = gen_lists_of_dims(4, i, counter) + [d1, d2, d3, d4] = dims + nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims))) + + initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12, + d1, d2, d3, d4) + + [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints + + # without padding, broadcast all possibilities for tensors of size i + final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints, + generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)]) + + # with padding, broadcast all possibilities for tensors of size i + final_tensor_constraint_padding_arg1, counter = \ + apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter) + + final_tensor_constraint_padding_arg2, counter = \ + apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter) + + return final_tensor_constraint_no_padding, \ + final_tensor_constraint_padding_arg1, \ + final_tensor_constraint_padding_arg2, nat_dims_i, counter diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py new file mode 100644 index 0000000000000000000000000000000000000000..432cd570bebbfc37e2c5b5a167cb124a8907f39d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py @@ -0,0 +1,14 @@ +op_add = '+' +op_sub = '-' +op_mul = '*' +op_div = '/' +op_eq = '=' +op_neq = '!=' +op_imp = '=>' +op_matching = '\u22b3' # (contains) +op_consistency = '~' +op_precision = '\u2291' # (square image of or equal to) +op_leq = '\u2264' # less-than or equal to +op_lt = '<' +op_gt = '>' +op_mod = '%' diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py new file mode 100644 index 0000000000000000000000000000000000000000..c8cf70006cd84c662f2f2ffd36e208b54bc1bbea --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py @@ -0,0 +1,349 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr +from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar +from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator +from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint +from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt +from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod +from torch.fx.tensor_type import TensorType, Dyn + +try: + import z3 # type: ignore[import] + from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D + HAS_Z3 = True + + def transform_to_z3(constraint, counter, dimension_dict): + if isinstance(constraint, Conj): + conjuncts = [] + for c in constraint.conjucts: + new_c, counter = transform_to_z3(c, counter, dimension_dict) + conjuncts.append(new_c) + return z3.And(conjuncts), counter + + elif isinstance(constraint, Disj): + disjuncts = [] + for c in constraint.disjuncts: + new_c, counter = transform_to_z3(c, counter, dimension_dict) + disjuncts.append(new_c) + return z3.Or(disjuncts), counter + + elif isinstance(constraint, T): + return True, counter + + elif isinstance(constraint, F): + return False, counter + + elif isinstance(constraint, BinConstraintT): + if constraint.op == op_eq: + lhs, counter = transform_var(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_var(constraint.rhs, counter, dimension_dict) + return (lhs == rhs), counter + + else: + raise NotImplementedError('Method not yet implemented') + + elif isinstance(constraint, BinConstraintD): + if constraint.op == op_eq: + + if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs): + transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict) + transformed_lhs = z3.Bool(constraint.lhs.c) + return transformed_lhs == transformed_rhs, counter + + elif is_dim(constraint.lhs) and is_dim(constraint.rhs): + # with dimension transformations we consider the encoding + lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) + return lhs == rhs, counter + + else: + # then we have an algebraic expression which means that we disregard the + # first element of the encoding + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs == rhs, counter + + # The assumption here is that the LHS and RHS must be dimensions + elif constraint.op == op_neq: + assert is_dim(constraint.lhs) + assert is_dim(constraint.rhs) + lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) + if constraint.rhs == Dyn or constraint.lhs == Dyn: + if constraint.rhs == Dyn: + return lhs.arg(0) == 1, counter + elif constraint.lhs == Dyn: + return rhs.arg(0) == 1, counter + + # if one of the instances is a number + elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int): + if isinstance(constraint.lhs, int): + return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter + + elif isinstance(constraint.rhs, int): + return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter + + else: + return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]), + z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]), + z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter + + + elif constraint.op == op_leq: + # if the dimensions are not dyn, this will come into effect + # there would have been another constraint specifying if a given dimension + # is dyn or not + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs <= rhs, counter + + elif constraint.op == op_gt: + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs > rhs, counter + + elif constraint.op == op_lt: + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + return lhs < rhs, counter + + else: + raise NotImplementedError('operation not yet implemented') + + else: + raise NotImplementedError('Operation not yet implemented') + + + def transform_var(tensor, counter, dimension_dict): + """ + Transforms tensor variables to a format understood by z3 + Args: + tensor: Tensor variable or a tensor type potentially with variable dimensions + Returns: Transformed variable to a z3 format + + """ + if isinstance(tensor, TensorType): + res = [] + for t in tensor.__args__: + transformed, counter = transform_dimension(t, counter, dimension_dict) + res.append(transformed) + + assert len(res) <= 4 + if len(tensor.__args__) == 1: + return tensor_type.tensor1(res[0]), counter + elif len(tensor.__args__) == 2: + return tensor_type.tensor2(res[0], res[1]), counter + elif len(tensor.__args__) == 3: + return tensor_type.tensor3(res[0], res[1], res[2]), counter + elif len(tensor.__args__) == 4: + return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter + + elif tensor == Dyn: + return z3_dyn, counter + + elif isinstance(tensor, TVar): + return z3.Const(tensor.tvar, tensor_type), counter + + def transform_dimension(dimension, counter, dimension_dict): + """ + Takes a dimension variable or a number and transforms it to a tuple + according to our scheme + Args: + dimension: The dimension to be transformed + counter: variable tracking + + Returns: tuple and the current counter + + """ + if dimension == Dyn: + counter += 1 + return D(0, z3.Int(counter)), counter + elif isinstance(dimension, int): + return D(1, dimension), counter + elif isinstance(dimension, DVar): + if dimension.c in dimension_dict: + return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter + else: + counter += 1 + dimension_dict[dimension.c] = counter + return D(z3.Int(counter), z3.Int(dimension.c)), counter + + + def transform_algebraic_expression(expr, counter, dimension_dict): + """ + Transforms an algebraic expression to z3 format + Args: + expr: An expression is either a dimension variable or an algebraic-expression + + + Returns: the transformed expression + + """ + assert is_algebraic_expression(expr) or is_dim(expr) + + if is_dim(expr): + transformed, counter = transform_dimension(expr, counter, dimension_dict) + return transformed.arg(1), counter + + elif isinstance(expr, Prod): + + dims = [] + for dim in expr.products: + assert is_dim(dim) + d, counter = transform_dimension(dim, counter, dimension_dict) + dims.append(d.arg(1)) + return z3.Product(dims), counter + + elif is_algebraic_expression(expr): + + lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict) + rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict) + + if expr.op == op_sub: + c = lhs - rhs + + elif expr.op == op_add: + c = lhs + rhs + + elif expr.op == op_div: + c = lhs / rhs + + elif expr.op == op_mul: + c = lhs * rhs + + elif expr.op == op_mod: + c = lhs % rhs + + else: + raise NotImplementedError('operation not yet implemented') + + return c, counter + + else: + raise RuntimeError + + + def transform_all_constraints(traced, counter=0): + """ + Given a trace, generates constraints and transforms them to z3 format + + """ + dimension_dict = {} # type: ignore[var-annotated] + + generator = ConstraintGenerator(traced) + new_constraints, counter = generator.generate_constraints(counter) + + # print(new_constraints.conjucts[0]) + # print(*new_constraints.conjucts, sep='\n') + + # transform precision, matching, consistency till obtaining a fixed point + new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) + # print(new_constraints) + # print(new_constraints.conjucts) + # new_constraints.conjucts = new_constraints.conjucts[:-1] + # print(*new_constraints.conjucts, sep='\n') + + transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) + # print(transformed) + return transformed + + def iterate_till_fixed_point(constraints, counter): + """ + Transform constraints till reaching a fixed point + """ + old_c = None + while old_c != constraints: + old_c = constraints + constraints, counter = transform_constraint(constraints, counter) + return constraints, counter + + def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): + """ + Takes a node and a graph and generates two sets of constraints. + One set constraints the node's constraints and another set + constraints the negation of the node's constraints + Args: + tracer_root: the root for getting the module instances + graph: the graph so far in the tracing process + node: node that represents a conditional + counter: variable tracking + + Returns: Two sets of constraints. One with a conjunction with the + the conditional constraint and the other with a conjunction with + its negation. + + """ + dimension_dict = {} # type: ignore[var-annotated] + + generator = ConstraintGenerator(tracer_root, graph) + new_constraints, counter = generator.generate_constraints(counter) + + condition_constraint = new_constraints.conjucts[-1] + + # we know the constraint is a conjunction where the last constraint is about the conditional + # so remove the last constraint + new_constraints.conjucts = new_constraints.conjucts[:-1] + + # transform precision, matching, consistency till obtaining a fixed point + new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) + + + # since the function returns a list of one element, we get the first element + # we are only interested in the RHS in this case because the LHS just stores + # the result + + # we make sure the constraint is of the form: + # c = b where b is a boolean expression + # and we consider b (constraint.rhs) for transformation + assert isinstance(condition_constraint.lhs, BVar) + assert is_bool_expr(condition_constraint.rhs) + condition_constraint_rhs = condition_constraint.rhs + + # transform the condition constraint + condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter) + + transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) + + transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict) + + negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint) + + return z3.And([transformed, transformed_condition_constraint]), \ + z3.And([transformed, negation_transformed_condition_constraint]) + + + def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None): + """ + Given an IR and a node representing a conditional, evaluate the conditional + and its negation + Args: + tracer_root: Tracer root for module instances + node: The node to be evaluated + + Returns: the results of evaluating the condition and the negation with + the rest of the constraints + + """ + + transformed_positive, transformed_negative = \ + transform_all_constraints_trace_time(tracer_root, graph, node, counter) + + s = z3.Solver() + s.add(transformed_positive) + if user_constraints is not None: + s.add(user_constraints) + condition = s.check() + + s = z3.Solver() + s.add(transformed_negative) + if user_constraints is not None: + s.add(user_constraints) + negation = s.check() + return condition, negation + +except ImportError: + HAS_Z3 = False diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py new file mode 100644 index 0000000000000000000000000000000000000000..99f94609f2650b6642bdce586907f757b032409b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py @@ -0,0 +1,53 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \ + BVar +from torch.fx.experimental.migrate_gradual_types.operation import op_leq + + +def gen_tvar(curr): + """ + Generate a tensor variable + :param curr: The current counter + :return: a tensor variable and the updated counter + """ + curr += 1 + return TVar(curr), curr + + +def gen_dvar(curr): + """ + Generate a dimension variable + :param curr: the current counter + :return: a dimension variable and an updated counter + """ + curr += 1 + return DVar(curr), curr + +def gen_bvar(curr): + """ + Generate a boolean variable + :param curr: the current counter + :return: a boolean variable and an updated counter + """ + curr += 1 + return BVar(curr), curr + +def gen_tensor_dims(n, curr): + """ + Generate a list of tensor dimensions + :param n: the number of dimensions + :param curr: the current counter + :return: a list of dimension variables and an updated counter + """ + dims = [] + for _ in range(n): + dvar, curr = gen_dvar(curr) + dims.append(dvar) + return dims, curr + + +def gen_nat_constraints(list_of_dims): + """ + Generate natural number constraints for dimensions + """ + return [BinConstraintD(0, d, op_leq) for d in list_of_dims] diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py new file mode 100644 index 0000000000000000000000000000000000000000..897a79d5697573a51f5886d5e9965a98e2c4cf6a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py @@ -0,0 +1,29 @@ +try: + import z3 # type: ignore[import] + HAS_Z3 = True + # dynamic type + dyn = z3.DeclareSort('Dyn') + dyn_type = z3.Const('dyn', dyn) + + # dimension + dim = z3.Datatype('dim') + dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort())) + dim = dim.create() + + # tensors + tensor_type = z3.Datatype('TensorType') + tensor_type.declare('Dyn', ('dyn', dyn)) + tensor_type.declare('tensor1', ('0', dim)) + tensor_type.declare('tensor2', ('0', dim), ('1', dim)) + tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim)) + tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim)) + tensor_type = tensor_type.create() + + # create dimension + D = dim.dim + + z3_dyn = tensor_type.Dyn(dyn_type) + + +except ImportError: + HAS_Z3 = False diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31446d0e61253d7f722a3235e6e4c5788b4b01ba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py @@ -0,0 +1,4 @@ +# mypy: disable-error-code=attr-defined +from .core import unify, reify # noqa: F403 +from .more import unifiable # noqa: F403 +from .variable import var, isvar, vars, variables, Var # noqa: F403 diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f06ec23eeed2ca428a66efded65a59ef3ea34ec1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4fd897e5aff24d11f9c8d6d4e014718e5439592 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70ac3324329f1922176e49e08d1ab9b12811b02a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..088c17118f409d54659f9dfc51d4f8981d55e0c1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4892d3036ffe8e12b5d7f7ee0fadcbc73a41129 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a2a5e2e224c9885da2b4d01389375782e274209 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48ff6197c215daab76246779fa7d98ecb7de1fda Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aac6efc513d532b3dd2bd93b0a2f4f0a2e74a931 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/core.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/core.py new file mode 100644 index 0000000000000000000000000000000000000000..0893c385bbc9ae5506069913d387689853f79d55 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/core.py @@ -0,0 +1,119 @@ +# mypy: allow-untyped-defs +from collections.abc import Iterator # type: ignore[import] +from functools import partial + +from .unification_tools import assoc # type: ignore[import] +from .utils import transitive_get as walk +from .variable import isvar +from .dispatch import dispatch + +__all__ = ["reify", "unify"] + +############### +# Reification # +############### + +@dispatch(Iterator, dict) +def _reify(t, s): + return map(partial(reify, s=s), t) + # return (reify(arg, s) for arg in t) +_reify + +@dispatch(tuple, dict) # type: ignore[no-redef] +def _reify(t, s): + return tuple(reify(iter(t), s)) +_reify + +@dispatch(list, dict) # type: ignore[no-redef] +def _reify(t, s): + return list(reify(iter(t), s)) +_reify + +@dispatch(dict, dict) # type: ignore[no-redef] +def _reify(d, s): + return {k: reify(v, s) for k, v in d.items()} +_reify + +@dispatch(object, dict) # type: ignore[no-redef] +def _reify(o, s): + return o # catch all, just return the object + +def reify(e, s): + """ Replace variables of expression with substitution + >>> # xdoctest: +SKIP + >>> x, y = var(), var() + >>> e = (1, x, (3, y)) + >>> s = {x: 2, y: 4} + >>> reify(e, s) + (1, 2, (3, 4)) + >>> e = {1: x, 3: (y, 5)} + >>> reify(e, s) + {1: 2, 3: (4, 5)} + """ + if isvar(e): + return reify(s[e], s) if e in s else e + return _reify(e, s) + +############### +# Unification # +############### + +seq = tuple, list, Iterator + +@dispatch(seq, seq, dict) +def _unify(u, v, s): + if len(u) != len(v): + return False + for uu, vv in zip(u, v): # avoiding recursion + s = unify(uu, vv, s) + if s is False: + return False + return s +# +# @dispatch((set, frozenset), (set, frozenset), dict) +# def _unify(u, v, s): +# i = u & v +# u = u - i +# v = v - i +# return _unify(sorted(u), sorted(v), s) +# +# +# @dispatch(dict, dict, dict) +# def _unify(u, v, s): +# if len(u) != len(v): +# return False +# for key, uval in iteritems(u): +# if key not in v: +# return False +# s = unify(uval, v[key], s) +# if s is False: +# return False +# return s +# +# +# @dispatch(object, object, dict) +# def _unify(u, v, s): +# return False # catch all + + +@dispatch(object, object, dict) +def unify(u, v, s): # no check at the moment + """ Find substitution so that u == v while satisfying s + >>> x = var('x') + >>> unify((1, x), (1, 2), {}) + {~x: 2} + """ + u = walk(u, s) + v = walk(v, s) + if u == v: + return s + if isvar(u): + return assoc(s, u, v) + if isvar(v): + return assoc(s, v, u) + return _unify(u, v, s) +unify + +@dispatch(object, object) # type: ignore[no-redef] +def unify(u, v): + return unify(u, v, {}) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/dispatch.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..93039ce75070fec8da52d03067d5c0b851a79b50 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/dispatch.py @@ -0,0 +1,6 @@ +from functools import partial +from .multipledispatch import dispatch # type: ignore[import] + +namespace = {} # type: ignore[var-annotated] + +dispatch = partial(dispatch, namespace=namespace) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/match.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/match.py new file mode 100644 index 0000000000000000000000000000000000000000..96583ef324ded18eb93e6cea52fcc12d4fc16f3f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/match.py @@ -0,0 +1,122 @@ +# mypy: allow-untyped-defs +from .core import unify, reify # type: ignore[attr-defined] +from .variable import isvar +from .utils import _toposort, freeze +from .unification_tools import groupby, first # type: ignore[import] + + +class Dispatcher: + def __init__(self, name): + self.name = name + self.funcs = {} + self.ordering = [] + + def add(self, signature, func): + self.funcs[freeze(signature)] = func + self.ordering = ordering(self.funcs) + + def __call__(self, *args, **kwargs): + func, s = self.resolve(args) + return func(*args, **kwargs) + + def resolve(self, args): + n = len(args) + for signature in self.ordering: + if len(signature) != n: + continue + s = unify(freeze(args), signature) + if s is not False: + result = self.funcs[signature] + return result, s + raise NotImplementedError("No match found. \nKnown matches: " + + str(self.ordering) + "\nInput: " + str(args)) + + def register(self, *signature): + def _(func): + self.add(signature, func) + return self + return _ + + +class VarDispatcher(Dispatcher): + """ A dispatcher that calls functions with variable names + >>> # xdoctest: +SKIP + >>> d = VarDispatcher('d') + >>> x = var('x') + >>> @d.register('inc', x) + ... def f(x): + ... return x + 1 + >>> @d.register('double', x) + ... def f(x): + ... return x * 2 + >>> d('inc', 10) + 11 + >>> d('double', 10) + 20 + """ + def __call__(self, *args, **kwargs): + func, s = self.resolve(args) + d = {k.token: v for k, v in s.items()} + return func(**d) + + +global_namespace = {} # type: ignore[var-annotated] + + +def match(*signature, **kwargs): + namespace = kwargs.get('namespace', global_namespace) + dispatcher = kwargs.get('Dispatcher', Dispatcher) + + def _(func): + name = func.__name__ + + if name not in namespace: + namespace[name] = dispatcher(name) + d = namespace[name] + + d.add(signature, func) + + return d + return _ + + +def supercedes(a, b): + """ ``a`` is a more specific match than ``b`` """ + if isvar(b) and not isvar(a): + return True + s = unify(a, b) + if s is False: + return False + s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)} + if reify(a, s) == a: + return True + if reify(b, s) == b: + return False + + +# Taken from multipledispatch +def edge(a, b, tie_breaker=hash): + """ A should be checked before B + Tie broken by tie_breaker, defaults to ``hash`` + """ + if supercedes(a, b): + if supercedes(b, a): + return tie_breaker(a) > tie_breaker(b) + else: + return True + return False + + +# Taken from multipledispatch +def ordering(signatures): + """ A sane ordering of signatures to check, first to last + Topological sort of edges as given by ``edge`` and ``supercedes`` + """ + signatures = list(map(tuple, signatures)) + edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] + edges = groupby(first, edges) + for s in signatures: + if s not in edges: + edges[s] = [] + edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment] + return _toposort(edges) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/more.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/more.py new file mode 100644 index 0000000000000000000000000000000000000000..2228448a71a1fd19647a972be0a4f44d8c1c9f54 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/more.py @@ -0,0 +1,118 @@ +# mypy: allow-untyped-defs +from .core import unify, reify # type: ignore[attr-defined] +from .dispatch import dispatch + + +def unifiable(cls): + """ Register standard unify and reify operations on class + This uses the type and __dict__ or __slots__ attributes to define the + nature of the term + See Also: + >>> # xdoctest: +SKIP + >>> class A(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + >>> unifiable(A) + + >>> x = var('x') + >>> a = A(1, 2) + >>> b = A(1, x) + >>> unify(a, b, {}) + {~x: 2} + """ + _unify.add((cls, cls, dict), unify_object) + _reify.add((cls, dict), reify_object) + + return cls + + +######### +# Reify # +######### + + +def reify_object(o, s): + """ Reify a Python object with a substitution + >>> # xdoctest: +SKIP + >>> class Foo(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + ... def __str__(self): + ... return "Foo(%s, %s)"%(str(self.a), str(self.b)) + >>> x = var('x') + >>> f = Foo(1, x) + >>> print(f) + Foo(1, ~x) + >>> print(reify_object(f, {x: 2})) + Foo(1, 2) + """ + if hasattr(o, '__slots__'): + return _reify_object_slots(o, s) + else: + return _reify_object_dict(o, s) + + +def _reify_object_dict(o, s): + obj = object.__new__(type(o)) + d = reify(o.__dict__, s) + if d == o.__dict__: + return o + obj.__dict__.update(d) + return obj + + +def _reify_object_slots(o, s): + attrs = [getattr(o, attr) for attr in o.__slots__] + new_attrs = reify(attrs, s) + if attrs == new_attrs: + return o + else: + newobj = object.__new__(type(o)) + for slot, attr in zip(o.__slots__, new_attrs): + setattr(newobj, slot, attr) + return newobj + + +@dispatch(slice, dict) +def _reify(o, s): + """ Reify a Python ``slice`` object """ + return slice(*reify((o.start, o.stop, o.step), s)) + + +######### +# Unify # +######### + + +def unify_object(u, v, s): + """ Unify two Python objects + Unifies their type and ``__dict__`` attributes + >>> # xdoctest: +SKIP + >>> class Foo(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + ... def __str__(self): + ... return "Foo(%s, %s)"%(str(self.a), str(self.b)) + >>> x = var('x') + >>> f = Foo(1, x) + >>> g = Foo(1, 2) + >>> unify_object(f, g, {}) + {~x: 2} + """ + if type(u) != type(v): + return False + if hasattr(u, '__slots__'): + return unify([getattr(u, slot) for slot in u.__slots__], + [getattr(v, slot) for slot in v.__slots__], + s) + else: + return unify(u.__dict__, v.__dict__, s) + + +@dispatch(slice, slice, dict) +def _unify(u, v, s): + """ Unify a Python ``slice`` object """ + return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0295af0ea6b6b92836e034c1d28cfdf69b1d3ba --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py @@ -0,0 +1,3 @@ +from .core import dispatch +from .dispatcher import (Dispatcher, halt_ordering, restart_ordering, + MDNotImplementedError) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..653bb2b1176e49d9dbdc80b25fbd58e0fd601a2b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef357accdfff0a1e1bdcd12e1d61c35350300092 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f12d09c824b00b9d2301599f162caf7aec80dca9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fd2318f59b12d6e287ec1a8318520d346aa4b26 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1def5e43f535075835a3e0ad053ed64a6e1e46e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1284dfb8a398a4cc08a818471ec3d526deb9cae Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py new file mode 100644 index 0000000000000000000000000000000000000000..7187330ead257a30d3d18504ec4787d005a41cca --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +from .utils import _toposort, groupby +from .variadic import isvariadic +import operator + +__all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature", + "edge", "ordering"] + +class AmbiguityWarning(Warning): + pass + + +def supercedes(a, b): + """ A is consistent and strictly more specific than B """ + if len(a) < len(b): + # only case is if a is empty and b is variadic + return not a and len(b) == 1 and isvariadic(b[-1]) + elif len(a) == len(b): + return all(map(issubclass, a, b)) + else: + # len(a) > len(b) + p1 = 0 + p2 = 0 + while p1 < len(a) and p2 < len(b): + cur_a = a[p1] + cur_b = b[p2] + if not (isvariadic(cur_a) or isvariadic(cur_b)): + if not issubclass(cur_a, cur_b): + return False + p1 += 1 + p2 += 1 + elif isvariadic(cur_a): + assert p1 == len(a) - 1 + return p2 == len(b) - 1 and issubclass(cur_a, cur_b) + elif isvariadic(cur_b): + assert p2 == len(b) - 1 + if not issubclass(cur_a, cur_b): + return False + p1 += 1 + return p2 == len(b) - 1 and p1 == len(a) + + +def consistent(a, b): + """ It is possible for an argument list to satisfy both A and B """ + + # Need to check for empty args + if not a: + return not b or isvariadic(b[0]) + if not b: + return not a or isvariadic(a[0]) + + # Non-empty args check for mutual subclasses + if len(a) == len(b): + return all(issubclass(aa, bb) or issubclass(bb, aa) + for aa, bb in zip(a, b)) + else: + p1 = 0 + p2 = 0 + while p1 < len(a) and p2 < len(b): + cur_a = a[p1] + cur_b = b[p2] + if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b): + return False + if not (isvariadic(cur_a) or isvariadic(cur_b)): + p1 += 1 + p2 += 1 + elif isvariadic(cur_a): + p2 += 1 + elif isvariadic(cur_b): + p1 += 1 + # We only need to check for variadic ends + # Variadic types are guaranteed to be the last element + return (isvariadic(cur_a) and p2 == len(b) or # type: ignore[possibly-undefined] + isvariadic(cur_b) and p1 == len(a)) # type: ignore[possibly-undefined] + + +def ambiguous(a, b): + """ A is consistent with B but neither is strictly more specific """ + return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) + + +def ambiguities(signatures): + """ All signature pairs such that A is ambiguous with B """ + signatures = list(map(tuple, signatures)) + return {(a, b) for a in signatures for b in signatures + if hash(a) < hash(b) + and ambiguous(a, b) + and not any(supercedes(c, a) and supercedes(c, b) + for c in signatures)} + + +def super_signature(signatures): + """ A signature that would break ambiguities """ + n = len(signatures[0]) + assert all(len(s) == n for s in signatures) + + return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] + for i in range(n)] + + +def edge(a, b, tie_breaker=hash): + """ A should be checked before B + Tie broken by tie_breaker, defaults to ``hash`` + """ + # A either supercedes B and B does not supercede A or if B does then call + # tie_breaker + return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)) + + +def ordering(signatures): + """ A sane ordering of signatures to check, first to last + Topological sort of edges as given by ``edge`` and ``supercedes`` + """ + signatures = list(map(tuple, signatures)) + edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] + edges = groupby(operator.itemgetter(0), edges) + for s in signatures: + if s not in edges: + edges[s] = [] + edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[assignment, attr-defined] + return _toposort(edges) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/core.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/core.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5bdbc963014e8f76b1f1ac3624eab56bf3b274 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/core.py @@ -0,0 +1,84 @@ +# mypy: allow-untyped-defs +import inspect +import sys + +from .dispatcher import Dispatcher, MethodDispatcher + +global_namespace = {} # type: ignore[var-annotated] + +__all__ = ["dispatch", "ismethod"] + +def dispatch(*types, **kwargs): + """ Dispatch function on the types of the inputs + Supports dispatch on all non-keyword arguments. + Collects implementations based on the function name. Ignores namespaces. + If ambiguous type signatures occur a warning is raised when the function is + defined suggesting the additional method to break the ambiguity. + + Example: + >>> # xdoctest: +SKIP + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + >>> @dispatch(float) + ... def f(x): + ... return x - 1 + >>> # xdoctest: +SKIP + >>> f(3) + 4 + >>> f(3.0) + 2.0 + >>> # Specify an isolated namespace with the namespace keyword argument + >>> my_namespace = {} + >>> @dispatch(int, namespace=my_namespace) + ... def foo(x): + ... return x + 1 + >>> # Dispatch on instance methods within classes + >>> class MyClass(object): + ... @dispatch(list) + ... def __init__(self, data): + ... self.data = data + ... @dispatch(int) + ... def __init__(self, datum): + ... self.data = [datum] + >>> MyClass([1, 2, 3]).data + [1, 2, 3] + >>> MyClass(3).data + [3] + """ + namespace = kwargs.get('namespace', global_namespace) + + types = tuple(types) + + def _df(func): + name = func.__name__ + + if ismethod(func): + dispatcher = inspect.currentframe().f_back.f_locals.get( # type: ignore[union-attr] + name, # type: ignore[union-attr] + MethodDispatcher(name), + ) + else: + if name not in namespace: + namespace[name] = Dispatcher(name) + dispatcher = namespace[name] + + dispatcher.add(types, func) + return dispatcher + return _df + + +def ismethod(func): + """ Is func a method? + Note that this has to work as the method is defined but before the class is + defined. At this stage methods look like functions. + """ + if hasattr(inspect, "signature"): + signature = inspect.signature(func) + return signature.parameters.get('self', None) is not None + else: + if sys.version_info.major < 3: + spec = inspect.getargspec(func) # type: ignore[attr-defined] + else: + spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment] + return spec and spec.args and spec.args[0] == 'self' diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d28201d04191b8ed269595e73ee05a41044a24 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -0,0 +1,427 @@ +# mypy: allow-untyped-defs +from warnings import warn +import inspect +from typing_extensions import deprecated +from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning +from .utils import expand_tuples +from .variadic import Variadic, isvariadic +import itertools as itl + +__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter", + "variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"] + +class MDNotImplementedError(NotImplementedError): + """ A NotImplementedError for multiple dispatch """ + + +def ambiguity_warn(dispatcher, ambiguities): + """ Raise warning when ambiguity is detected + Parameters + ---------- + dispatcher : Dispatcher + The dispatcher on which the ambiguity was detected + ambiguities : set + Set of type signature pairs that are ambiguous within this dispatcher + See Also: + Dispatcher.add + warning_text + """ + warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) + + +@deprecated( + "`halt_ordering` is deprecated, you can safely remove this call.", + category=FutureWarning, +) +def halt_ordering(): + """Deprecated interface to temporarily disable ordering.""" + + +@deprecated( + "`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, " + "you should call the `reorder()` method on each dispatcher.", + category=FutureWarning, +) +def restart_ordering(on_ambiguity=ambiguity_warn): + """Deprecated interface to temporarily resume ordering.""" + + +def variadic_signature_matches_iter(types, full_signature): + """Check if a set of input types matches a variadic signature. + Notes + ----- + The algorithm is as follows: + Initialize the current signature to the first in the sequence + For each type in `types`: + If the current signature is variadic + If the type matches the signature + yield True + Else + Try to get the next signature + If no signatures are left we can't possibly have a match + so yield False + Else + yield True if the type matches the current signature + Get the next signature + """ + sigiter = iter(full_signature) + sig = next(sigiter) + for typ in types: + matches = issubclass(typ, sig) + yield matches + if not isvariadic(sig): + # we're not matching a variadic argument, so move to the next + # element in the signature + sig = next(sigiter) + else: + try: + sig = next(sigiter) + except StopIteration: + assert isvariadic(sig) + yield True + else: + # We have signature items left over, so all of our arguments + # haven't matched + yield False + + +def variadic_signature_matches(types, full_signature): + # No arguments always matches a variadic signature + assert full_signature + return all(variadic_signature_matches_iter(types, full_signature)) + + +class Dispatcher: + """ Dispatch methods based on type signature + Use ``dispatch`` to add implementations + Examples + -------- + >>> # xdoctest: +SKIP("bad import name") + >>> from multipledispatch import dispatch + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + >>> @dispatch(float) + ... def f(x): + ... return x - 1 + >>> f(3) + 4 + >>> f(3.0) + 2.0 + """ + __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc' + + def __init__(self, name, doc=None): + self.name = self.__name__ = name + self.funcs = {} + self.doc = doc + + self._cache = {} + + def register(self, *types, **kwargs): + """ register dispatcher with new implementation + >>> # xdoctest: +SKIP + >>> f = Dispatcher('f') + >>> @f.register(int) + ... def inc(x): + ... return x + 1 + >>> @f.register(float) + ... def dec(x): + ... return x - 1 + >>> @f.register(list) + ... @f.register(tuple) + ... def reverse(x): + ... return x[::-1] + >>> f(1) + 2 + >>> f(1.0) + 0.0 + >>> f([1, 2, 3]) + [3, 2, 1] + """ + def _df(func): + self.add(types, func, **kwargs) # type: ignore[call-arg] + return func + return _df + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return sig.parameters.values() + + @classmethod + def get_func_annotations(cls, func): + """ get annotations of function positional parameters + """ + params = cls.get_func_params(func) + if params: + Parameter = inspect.Parameter + + params = (param for param in params + if param.kind in + (Parameter.POSITIONAL_ONLY, + Parameter.POSITIONAL_OR_KEYWORD)) + + annotations = tuple( + param.annotation + for param in params) + + if all(ann is not Parameter.empty for ann in annotations): + return annotations + + def add(self, signature, func): + """ Add new types/method pair to dispatcher + >>> # xdoctest: +SKIP + >>> D = Dispatcher('add') + >>> D.add((int, int), lambda x, y: x + y) + >>> D.add((float, float), lambda x, y: x + y) + >>> D(1, 2) + 3 + >>> D(1, 2.0) + Traceback (most recent call last): + ... + NotImplementedError: Could not find signature for add: + >>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback + >>> # with a dispatcher/itself, and a set of ambiguous type signature pairs + >>> # as inputs. See ``ambiguity_warn`` for an example. + """ + # Handle annotations + if not signature: + annotations = self.get_func_annotations(func) + if annotations: + signature = annotations + + # Handle union types + if any(isinstance(typ, tuple) for typ in signature): + for typs in expand_tuples(signature): + self.add(typs, func) + return + + new_signature = [] + + for index, typ in enumerate(signature, start=1): + if not isinstance(typ, (type, list)): + str_sig = ', '.join(c.__name__ if isinstance(c, type) + else str(c) for c in signature) + raise TypeError(f"Tried to dispatch on non-type: {typ}\n" + f"In signature: <{str_sig}>\n" + f"In function: {self.name}") + + # handle variadic signatures + if isinstance(typ, list): + if index != len(signature): + raise TypeError( + 'Variadic signature must be the last element' + ) + + if len(typ) != 1: + raise TypeError( + 'Variadic signature must contain exactly one element. ' + 'To use a variadic union type place the desired types ' + 'inside of a tuple, e.g., [(int, str)]' + ) + new_signature.append(Variadic[typ[0]]) + else: + new_signature.append(typ) + + self.funcs[tuple(new_signature)] = func + self._cache.clear() + + try: + del self._ordering + except AttributeError: + pass + + @property + def ordering(self): + try: + return self._ordering + except AttributeError: + return self.reorder() + + def reorder(self, on_ambiguity=ambiguity_warn): + self._ordering = od = ordering(self.funcs) + amb = ambiguities(self.funcs) + if amb: + on_ambiguity(self, amb) + return od + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + try: + func = self._cache[types] + except KeyError as e: + func = self.dispatch(*types) + if not func: + raise NotImplementedError( + f'Could not find signature for {self.name}: <{str_signature(types)}>') from e + self._cache[types] = func + try: + return func(*args, **kwargs) + + except MDNotImplementedError as e: + funcs = self.dispatch_iter(*types) + next(funcs) # burn first + for func in funcs: + try: + return func(*args, **kwargs) + except MDNotImplementedError: + pass + + raise NotImplementedError( + "Matching functions for " + f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e + + def __str__(self): + return f"" + __repr__ = __str__ + + def dispatch(self, *types): + """Determine appropriate implementation for this type signature + This method is internal. Users should call this object as a function. + Implementation resolution occurs within the ``__call__`` method. + >>> # xdoctest: +SKIP + >>> from multipledispatch import dispatch + >>> @dispatch(int) + ... def inc(x): + ... return x + 1 + >>> implementation = inc.dispatch(int) + >>> implementation(3) + 4 + >>> print(inc.dispatch(float)) + None + See Also: + ``multipledispatch.conflict`` - module to determine resolution order + """ + + if types in self.funcs: + return self.funcs[types] + + try: + return next(self.dispatch_iter(*types)) + except StopIteration: + return None + + def dispatch_iter(self, *types): + + n = len(types) + for signature in self.ordering: + if len(signature) == n and all(map(issubclass, types, signature)): + result = self.funcs[signature] + yield result + elif len(signature) and isvariadic(signature[-1]): + if variadic_signature_matches(types, signature): + result = self.funcs[signature] + yield result + + @deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning) + def resolve(self, types): + """ Determine appropriate implementation for this type signature + .. deprecated:: 0.4.4 + Use ``dispatch(*types)`` instead + """ + return self.dispatch(*types) + + def __getstate__(self): + return {'name': self.name, + 'funcs': self.funcs} + + def __setstate__(self, d): + self.name = d['name'] + self.funcs = d['funcs'] + self._ordering = ordering(self.funcs) + self._cache = {} + + @property + def __doc__(self): + docs = [f"Multiply dispatched method: {self.name}"] + + if self.doc: + docs.append(self.doc) + + other = [] + for sig in self.ordering[::-1]: + func = self.funcs[sig] + if func.__doc__: + s = f'Inputs: <{str_signature(sig)}>\n' + s += '-' * len(s) + '\n' + s += func.__doc__.strip() + docs.append(s) + else: + other.append(str_signature(sig)) + + if other: + docs.append('Other signatures:\n ' + '\n '.join(other)) + + return '\n\n'.join(docs) + + def _help(self, *args): + return self.dispatch(*map(type, args)).__doc__ + + def help(self, *args, **kwargs): + """ Print docstring for the function corresponding to inputs """ + print(self._help(*args)) + + def _source(self, *args): + func = self.dispatch(*map(type, args)) + if not func: + raise TypeError("No function found") + return source(func) + + def source(self, *args, **kwargs): + """ Print source code for the function corresponding to inputs """ + print(self._source(*args)) + + +def source(func): + s = f'File: {inspect.getsourcefile(func)}\n\n' + s = s + inspect.getsource(func) + return s + + +class MethodDispatcher(Dispatcher): + """ Dispatch methods based on type signature + See Also: + Dispatcher + """ + __slots__ = ('obj', 'cls') + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return itl.islice(sig.parameters.values(), 1, None) + + def __get__(self, instance, owner): + self.obj = instance + self.cls = owner + return self + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + func = self.dispatch(*types) + if not func: + raise NotImplementedError(f'Could not find signature for {self.name}: <{str_signature(types)}>') + return func(self.obj, *args, **kwargs) + + +def str_signature(sig): + """ String representation of type signature + >>> str_signature((int, float)) + 'int, float' + """ + return ', '.join(cls.__name__ for cls in sig) + + +def warning_text(name, amb): + """ The text for ambiguity warnings """ + text = f"\nAmbiguities exist in dispatched function {name}\n\n" + text += "The following signatures may result in ambiguous behavior:\n" + for pair in amb: + text += "\t" + \ + ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n" + text += "\n\nConsider making the following additions:\n\n" + text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s)) + + f')\ndef {name}(...)' for s in amb]) + return text diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..77702e8ccb7f4ff8e36d6d03bc29a2d16b23d80a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py @@ -0,0 +1,126 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict + +__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] + +def raises(err, lamda): + try: + lamda() + return False + except err: + return True + + +def expand_tuples(L): + """ + >>> expand_tuples([1, (2, 3)]) + [(1, 2), (1, 3)] + >>> expand_tuples([1, 2]) + [(1, 2)] + """ + if not L: + return [()] + elif not isinstance(L[0], tuple): + rest = expand_tuples(L[1:]) + return [(L[0],) + t for t in rest] + else: + rest = expand_tuples(L[1:]) + return [(item,) + t for t in rest for item in L[0]] + + +# Taken from theano/theano/gof/sched.py +# Avoids licensing issues because this was written by Matthew Rocklin +def _toposort(edges): + """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) + inputs: + edges - a dict of the form {a: {b, c}} where b and c depend on a + outputs: + L - an ordered list of nodes that satisfy the dependencies of edges + >>> _toposort({1: (2, 3), 2: (3, )}) + [1, 2, 3] + >>> # Closely follows the wikipedia page [2] + >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", + >>> # Communications of the ACM + >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms + """ + incoming_edges = reverse_dict(edges) + incoming_edges = OrderedDict((k, set(val)) + for k, val in incoming_edges.items()) + S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) + L = [] + + while S: + n, _ = S.popitem() + L.append(n) + for m in edges.get(n, ()): + assert n in incoming_edges[m] + incoming_edges[m].remove(n) + if not incoming_edges[m]: + S[m] = None + if any(incoming_edges.get(v, None) for v in edges): + raise ValueError("Input has cycles") + return L + + +def reverse_dict(d): + """Reverses direction of dependence dict + >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} + >>> reverse_dict(d) # doctest: +SKIP + {1: ('a',), 2: ('a', 'b'), 3: ('b',)} + :note: dict order are not deterministic. As we iterate on the + input dict, it make the output of this function depend on the + dict order. So this function output order should be considered + as undeterministic. + """ + result = OrderedDict() # type: ignore[var-annotated] + for key in d: + for val in d[key]: + result[val] = result.get(val, ()) + (key,) + return result + + +# Taken from toolz +# Avoids licensing issues because this version was authored by Matthew Rocklin +def groupby(func, seq): + """ Group a collection by a key function + >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] + >>> groupby(len, names) # doctest: +SKIP + {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} + >>> iseven = lambda x: x % 2 == 0 + >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP + {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} + See Also: + ``countby`` + """ + + d = OrderedDict() # type: ignore[var-annotated] + for item in seq: + key = func(item) + if key not in d: + d[key] = [] + d[key].append(item) + return d + + +def typename(type): + """Get the name of `type`. + Parameters + ---------- + type : Union[Type, Tuple[Type]] + Returns + ------- + str + The name of `type` or a tuple of the names of the types in `type`. + Examples + -------- + >>> typename(int) + 'int' + >>> typename((int, float)) + '(int, float)' + """ + try: + return type.__name__ + except AttributeError: + if len(type) == 1: + return typename(*type) + return f"({', '.join(map(typename, type))})" diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py new file mode 100644 index 0000000000000000000000000000000000000000..49e546e1ea2672c9c7225e4c3fbc8301a53f299f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py @@ -0,0 +1,92 @@ +# mypy: allow-untyped-defs +from .utils import typename + +__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"] + +class VariadicSignatureType(type): + # checking if subclass is a subclass of self + def __subclasscheck__(cls, subclass): + other_type = (subclass.variadic_type if isvariadic(subclass) + else (subclass,)) + return subclass is cls or all( + issubclass(other, cls.variadic_type) for other in other_type # type: ignore[attr-defined] + ) + + def __eq__(cls, other): + """ + Return True if other has the same variadic type + Parameters + ---------- + other : object (type) + The object (type) to check + Returns + ------- + bool + Whether or not `other` is equal to `self` + """ + return (isvariadic(other) and + set(cls.variadic_type) == set(other.variadic_type)) # type: ignore[attr-defined] + + def __hash__(cls): + return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined] + + +def isvariadic(obj): + """Check whether the type `obj` is variadic. + Parameters + ---------- + obj : type + The type to check + Returns + ------- + bool + Whether or not `obj` is variadic + Examples + -------- + >>> # xdoctest: +SKIP + >>> isvariadic(int) + False + >>> isvariadic(Variadic[int]) + True + """ + return isinstance(obj, VariadicSignatureType) + + +class VariadicSignatureMeta(type): + """A metaclass that overrides ``__getitem__`` on the class. This is used to + generate a new type for Variadic signatures. See the Variadic class for + examples of how this behaves. + """ + def __getitem__(cls, variadic_type): + if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): + raise ValueError("Variadic types must be type or tuple of types" + " (Variadic[int] or Variadic[(int, float)]") + + if not isinstance(variadic_type, tuple): + variadic_type = variadic_type, + return VariadicSignatureType( + f'Variadic[{typename(variadic_type)}]', + (), + dict(variadic_type=variadic_type, __slots__=()) + ) + + +class Variadic(metaclass=VariadicSignatureMeta): + """A class whose getitem method can be used to generate a new type + representing a specific variadic signature. + Examples + -------- + >>> # xdoctest: +SKIP + >>> Variadic[int] # any number of int arguments + + >>> Variadic[(int, str)] # any number of one of int or str arguments + + >>> issubclass(int, Variadic[int]) + True + >>> issubclass(int, Variadic[(int, str)]) + True + >>> issubclass(str, Variadic[(int, str)]) + True + >>> issubclass(float, Variadic[(int, str)]) + False + """ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/unification_tools.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/unification_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..d06d9bef771c4c3871edc869522fa3e2d26f5be8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/unification_tools.py @@ -0,0 +1,396 @@ +# mypy: allow-untyped-defs +import collections +import operator +from functools import reduce +from collections.abc import Mapping + +__all__ = ['merge', 'merge_with', 'valmap', 'keymap', 'itemmap', + 'valfilter', 'keyfilter', 'itemfilter', + 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in'] + + +def _get_factory(f, kwargs): + factory = kwargs.pop('factory', dict) + if kwargs: + raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'") + return factory + + +def merge(*dicts, **kwargs): + """ Merge a collection of dictionaries + + >>> merge({1: 'one'}, {2: 'two'}) + {1: 'one', 2: 'two'} + + Later dictionaries have precedence + + >>> merge({1: 2, 3: 4}, {3: 3, 4: 4}) + {1: 2, 3: 3, 4: 4} + + See Also: + merge_with + """ + if len(dicts) == 1 and not isinstance(dicts[0], Mapping): + dicts = dicts[0] + factory = _get_factory(merge, kwargs) + + rv = factory() + for d in dicts: + rv.update(d) + return rv + + +def merge_with(func, *dicts, **kwargs): + """ Merge dictionaries and apply function to combined values + + A key may occur in more than one dict, and all values mapped from the key + will be passed to the function as a list, such as func([val1, val2, ...]). + + >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20}) + {1: 11, 2: 22} + + >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP + {1: 1, 2: 2, 3: 30} + + See Also: + merge + """ + if len(dicts) == 1 and not isinstance(dicts[0], Mapping): + dicts = dicts[0] + factory = _get_factory(merge_with, kwargs) + + result = factory() + for d in dicts: + for k, v in d.items(): + if k not in result: + result[k] = [v] + else: + result[k].append(v) + return valmap(func, result, factory) + + +def valmap(func, d, factory=dict): + """ Apply function to values of dictionary + + >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} + >>> valmap(sum, bills) # doctest: +SKIP + {'Alice': 65, 'Bob': 45} + + See Also: + keymap + itemmap + """ + rv = factory() + rv.update(zip(d.keys(), map(func, d.values()))) + return rv + + +def keymap(func, d, factory=dict): + """ Apply function to keys of dictionary + + >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} + >>> keymap(str.lower, bills) # doctest: +SKIP + {'alice': [20, 15, 30], 'bob': [10, 35]} + + See Also: + valmap + itemmap + """ + rv = factory() + rv.update(zip(map(func, d.keys()), d.values())) + return rv + + +def itemmap(func, d, factory=dict): + """ Apply function to items of dictionary + + >>> accountids = {"Alice": 10, "Bob": 20} + >>> itemmap(reversed, accountids) # doctest: +SKIP + {10: "Alice", 20: "Bob"} + + See Also: + keymap + valmap + """ + rv = factory() + rv.update(map(func, d.items())) + return rv + + +def valfilter(predicate, d, factory=dict): + """ Filter items in dictionary by value + + >>> iseven = lambda x: x % 2 == 0 + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> valfilter(iseven, d) + {1: 2, 3: 4} + + See Also: + keyfilter + itemfilter + valmap + """ + rv = factory() + for k, v in d.items(): + if predicate(v): + rv[k] = v + return rv + + +def keyfilter(predicate, d, factory=dict): + """ Filter items in dictionary by key + + >>> iseven = lambda x: x % 2 == 0 + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> keyfilter(iseven, d) + {2: 3, 4: 5} + + See Also: + valfilter + itemfilter + keymap + """ + rv = factory() + for k, v in d.items(): + if predicate(k): + rv[k] = v + return rv + + +def itemfilter(predicate, d, factory=dict): + """ Filter items in dictionary by item + + >>> def isvalid(item): + ... k, v = item + ... return k % 2 == 0 and v < 4 + + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> itemfilter(isvalid, d) + {2: 3} + + See Also: + keyfilter + valfilter + itemmap + """ + rv = factory() + for item in d.items(): + if predicate(item): + k, v = item + rv[k] = v + return rv + + +def assoc(d, key, value, factory=dict): + """ Return a new dict with new key value pair + + New dict has d[key] set to value. Does not modify the initial dictionary. + + >>> assoc({'x': 1}, 'x', 2) + {'x': 2} + >>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP + {'x': 1, 'y': 3} + """ + d2 = factory() + d2.update(d) + d2[key] = value + return d2 + + +def dissoc(d, *keys, **kwargs): + """ Return a new dict with the given key(s) removed. + + New dict has d[key] deleted for each supplied key. + Does not modify the initial dictionary. + + >>> dissoc({'x': 1, 'y': 2}, 'y') + {'x': 1} + >>> dissoc({'x': 1, 'y': 2}, 'y', 'x') + {} + >>> dissoc({'x': 1}, 'y') # Ignores missing keys + {'x': 1} + """ + factory = _get_factory(dissoc, kwargs) + d2 = factory() + + if len(keys) < len(d) * .6: + d2.update(d) + for key in keys: + if key in d2: + del d2[key] + else: + remaining = set(d) + remaining.difference_update(keys) + for k in remaining: + d2[k] = d[k] + return d2 + + +def assoc_in(d, keys, value, factory=dict): + """ Return a new dict with new, potentially nested, key value pair + + >>> purchase = {'name': 'Alice', + ... 'order': {'items': ['Apple', 'Orange'], + ... 'costs': [0.50, 1.25]}, + ... 'credit card': '5555-1234-1234-1234'} + >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP + {'credit card': '5555-1234-1234-1234', + 'name': 'Alice', + 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}} + """ + return update_in(d, keys, lambda x: value, value, factory) + + +def update_in(d, keys, func, default=None, factory=dict): + """ Update value in a (potentially) nested dictionary + + inputs: + d - dictionary on which to operate + keys - list or tuple giving the location of the value to be changed in d + func - function to operate on that value + + If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the + original dictionary with v replaced by func(v), but does not mutate the + original dictionary. + + If k0 is not a key in d, update_in creates nested dictionaries to the depth + specified by the keys, with the innermost value set to func(default). + + >>> inc = lambda x: x + 1 + >>> update_in({'a': 0}, ['a'], inc) + {'a': 1} + + >>> transaction = {'name': 'Alice', + ... 'purchase': {'items': ['Apple', 'Orange'], + ... 'costs': [0.50, 1.25]}, + ... 'credit card': '5555-1234-1234-1234'} + >>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP + {'credit card': '5555-1234-1234-1234', + 'name': 'Alice', + 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}} + + >>> # updating a value when k0 is not in d + >>> update_in({}, [1, 2, 3], str, default="bar") + {1: {2: {3: 'bar'}}} + >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0) + {1: 'foo', 2: {3: {4: 1}}} + """ + ks = iter(keys) + k = next(ks) + + rv = inner = factory() + rv.update(d) + + for key in ks: + if k in d: + d = d[k] + dtemp = factory() + dtemp.update(d) + else: + d = dtemp = factory() + + inner[k] = inner = dtemp + k = key + + if k in d: + inner[k] = func(d[k]) + else: + inner[k] = func(default) + return rv + + +def get_in(keys, coll, default=None, no_default=False): + """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. + + If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless + ``no_default`` is specified, then it raises KeyError or IndexError. + + ``get_in`` is a generalization of ``operator.getitem`` for nested data + structures such as dictionaries and lists. + + >>> transaction = {'name': 'Alice', + ... 'purchase': {'items': ['Apple', 'Orange'], + ... 'costs': [0.50, 1.25]}, + ... 'credit card': '5555-1234-1234-1234'} + >>> get_in(['purchase', 'items', 0], transaction) + 'Apple' + >>> get_in(['name'], transaction) + 'Alice' + >>> get_in(['purchase', 'total'], transaction) + >>> get_in(['purchase', 'items', 'apple'], transaction) + >>> get_in(['purchase', 'items', 10], transaction) + >>> get_in(['purchase', 'total'], transaction, 0) + 0 + >>> get_in(['y'], {}, no_default=True) + Traceback (most recent call last): + ... + KeyError: 'y' + + See Also: + itertoolz.get + operator.getitem + """ + try: + return reduce(operator.getitem, keys, coll) + except (KeyError, IndexError, TypeError): + if no_default: + raise + return default + + +def getter(index): + if isinstance(index, list): + if len(index) == 1: + index = index[0] + return lambda x: (x[index],) + elif index: + return operator.itemgetter(*index) + else: + return lambda x: () + else: + return operator.itemgetter(index) + + +def groupby(key, seq): + """ Group a collection by a key function + + >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] + >>> groupby(len, names) # doctest: +SKIP + {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} + + >>> iseven = lambda x: x % 2 == 0 + >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP + {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} + + Non-callable keys imply grouping on a member. + + >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'}, + ... {'name': 'Bob', 'gender': 'M'}, + ... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP + {'F': [{'gender': 'F', 'name': 'Alice'}], + 'M': [{'gender': 'M', 'name': 'Bob'}, + {'gender': 'M', 'name': 'Charlie'}]} + + Not to be confused with ``itertools.groupby`` + + See Also: + countby + """ + if not callable(key): + key = getter(key) + d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated] + for item in seq: + d[key(item)](item) + rv = {} + for k, v in d.items(): + rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined] + return rv + + +def first(seq): + """ The first element in a sequence + + >>> first('ABC') + 'A' + """ + return next(iter(seq)) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/utils.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..609fe59d43f45863e65d42ba9732a8bb095ebb11 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/utils.py @@ -0,0 +1,106 @@ +# mypy: allow-untyped-defs +__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"] +def hashable(x): + try: + hash(x) + return True + except TypeError: + return False + + +def transitive_get(key, d): + """ Transitive dict.get + >>> d = {1: 2, 2: 3, 3: 4} + >>> d.get(1) + 2 + >>> transitive_get(1, d) + 4 + """ + while hashable(key) and key in d: + key = d[key] + return key + + +def raises(err, lamda): + try: + lamda() + return False + except err: + return True + + +# Taken from theano/theano/gof/sched.py +# Avoids licensing issues because this was written by Matthew Rocklin +def _toposort(edges): + """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) + inputs: + edges - a dict of the form {a: {b, c}} where b and c depend on a + outputs: + L - an ordered list of nodes that satisfy the dependencies of edges + >>> # xdoctest: +SKIP + >>> _toposort({1: (2, 3), 2: (3, )}) + [1, 2, 3] + Closely follows the wikipedia page [2] + [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", + Communications of the ACM + [2] http://en.wikipedia.org/wiki/Toposort#Algorithms + """ + incoming_edges = reverse_dict(edges) + incoming_edges = {k: set(val) for k, val in incoming_edges.items()} + S = ({v for v in edges if v not in incoming_edges}) + L = [] + + while S: + n = S.pop() + L.append(n) + for m in edges.get(n, ()): + assert n in incoming_edges[m] + incoming_edges[m].remove(n) + if not incoming_edges[m]: + S.add(m) + if any(incoming_edges.get(v, None) for v in edges): + raise ValueError("Input has cycles") + return L + + +def reverse_dict(d): + """Reverses direction of dependence dict + >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} + >>> reverse_dict(d) # doctest: +SKIP + {1: ('a',), 2: ('a', 'b'), 3: ('b',)} + :note: dict order are not deterministic. As we iterate on the + input dict, it make the output of this function depend on the + dict order. So this function output order should be considered + as undeterministic. + """ + result = {} # type: ignore[var-annotated] + for key in d: + for val in d[key]: + result[val] = result.get(val, ()) + (key,) + return result + + +def xfail(func): + try: + func() + raise Exception("XFailed test passed") # pragma:nocover # noqa: TRY002 + except Exception: + pass + + +def freeze(d): + """ Freeze container to hashable form + >>> freeze(1) + 1 + >>> freeze([1, 2]) + (1, 2) + >>> freeze({1: 2}) # doctest: +SKIP + frozenset([(1, 2)]) + """ + if isinstance(d, dict): + return frozenset(map(freeze, d.items())) + if isinstance(d, set): + return frozenset(map(freeze, d)) + if isinstance(d, (tuple, list)): + return tuple(map(freeze, d)) + return d diff --git a/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/variable.py b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/variable.py new file mode 100644 index 0000000000000000000000000000000000000000..66e97a3a766361cd10e586857104cd71118dcc96 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/variable.py @@ -0,0 +1,86 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager +from .utils import hashable +from .dispatch import dispatch + +_global_logic_variables = set() # type: ignore[var-annotated] +_glv = _global_logic_variables + + +class Var: + """ Logic Variable """ + + _id = 1 + + def __new__(cls, *token): + if len(token) == 0: + token = f"_{Var._id}" # type: ignore[assignment] + Var._id += 1 + elif len(token) == 1: + token = token[0] + + obj = object.__new__(cls) + obj.token = token # type: ignore[attr-defined] + return obj + + def __str__(self): + return "~" + str(self.token) # type: ignore[attr-defined] + __repr__ = __str__ + + def __eq__(self, other): + return type(self) == type(other) and self.token == other.token # type: ignore[attr-defined] + + def __hash__(self): + return hash((type(self), self.token)) # type: ignore[attr-defined] + + +def var(): + return lambda *args: Var(*args) + + +def vars(): + return lambda n: [var() for i in range(n)] + + +@dispatch(Var) +def isvar(v): + return True + +isvar + + +@dispatch(object) # type: ignore[no-redef] +def isvar(o): + return not not _glv and hashable(o) and o in _glv + + +@contextmanager +def variables(*variables): + """ + Context manager for logic variables + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> from __future__ import with_statement + >>> with variables(1): + ... print(isvar(1)) + True + >>> print(isvar(1)) + False + >>> # Normal approach + >>> from unification import unify + >>> x = var('x') + >>> unify(x, 1) + {~x: 1} + >>> # Context Manager approach + >>> with variables('x'): + ... print(unify('x', 1)) + {'x': 1} + """ + old_global_logic_variables = _global_logic_variables.copy() + _global_logic_variables.update(set(variables)) + try: + yield + finally: + _global_logic_variables.clear() + _global_logic_variables.update(old_global_logic_variables) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f83a2f248fcde4fd3947ee0ea82505a8bb7af976 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py @@ -0,0 +1,12 @@ +from . import graph_drawer +from . import graph_manipulation +from . import net_min_base +from . import operator_support +from . import param_fetch +from . import reinplace +from . import runtime_assert +from . import shape_prop +from . import split_module +from . import split_utils +from . import splitter_base +from . import tools_common diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..0399cef526205f8f82a0c53555bc16fdab67a550 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py @@ -0,0 +1,44 @@ +import operator + +import torch + + +def annotate_getitem_nodes(graph: torch.fx.Graph) -> None: + """ + Annotate the type of getitem nodes, inferred from the type of sequence node. + If sequence node is not annotated with a type, do nothing. + Currently support getitem nodes from Tuple, List, and NamedTuple sequence node. + + This is helpful since annotations on local names within function are lost during FX transforms. + Adding back known type annotation for getitem nodes to improve jit scriptability. + + Args: + graph (Graph): The graph to be annotated + """ + for node in graph.nodes: + if node.target == operator.getitem: + sequence_node, index_node = node.args + if not sequence_node.type: + continue + # container types + if hasattr(sequence_node.type, "_name"): + parameterized_types = sequence_node.type.__args__ + if sequence_node.type._name == "Tuple": + if len(parameterized_types) == 2 and isinstance( + parameterized_types[1], type(...) + ): + node.type = parameterized_types[0] + else: + assert len(parameterized_types) > index_node + node_type = parameterized_types[index_node] + node.type = node_type + elif sequence_node.type._name == "List": + assert len(parameterized_types) == 1 + node.type = parameterized_types[0] + # NamedTuple type + elif hasattr(sequence_node.type, "__annotations__"): + if sequence_node.type == torch.Tensor: + continue + sequence_node_field_types = sequence_node.type.__annotations__ + field_name = sequence_node.type._fields[index_node] + node.type = sequence_node_field_types[field_name] diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..437809ad1fbf1dfef498a9ead81e71f9723efd4a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c309de1c7e773b949892cd52bdb977250dcf611a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dec4dda60e4b88b4bf6d72afa34eefe5eb6c7d7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/cse_pass.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/cse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..577f445e7b316c61657625705f11f38daa1511a4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/cse_pass.py @@ -0,0 +1,113 @@ +# mypy: allow-untyped-defs +from typing import Dict, Tuple, Any + +import torch +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.utils._pytree import tree_flatten + +from torch.fx import GraphModule, Graph +from torch.fx import Node + +aten = torch.ops.aten + + +# stateful ops are banned from CSE +rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501,B950 + +inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501 + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def get_CSE_banned_ops(): + return rand_ops.union(inplace_ops) + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +class CSEPass(PassBase): + + def __init__(self, banned_ops=None): + """ + This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node. + + For functional dialects, user would only need to specify the random ops in ban list. + + Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects. + If your dialect contains stateful operators, please customized the banned_ops. + + """ + if banned_ops is None: + banned_ops = set() + self.banned_ops = banned_ops + super().__init__() + + def call(self, graph_module: GraphModule) -> PassResult: + """ + Return a new copy of torch.fx.GraphModule with CSE applied to the input graph + + Example usage: + + from torch.fx.experimental.proxy_tensor import make_fx + def f(a): + b = a * a + c = a * a + return b+c + + p = CSEPass() + traced_graph = make_fx(f)(torch.tensor(1)) + print(traced_graph) + result = p(traced_graph) + print(result.graph_module) + """ + def get_aten_target(node): + if hasattr(node.target, 'overloadpacket'): + return node.target.overloadpacket + return node.target + + modified = False + new_graph = Graph() + env: Dict[Node, Node] = {} # map from node in the old graph to node in the new graph + hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {} # map from hash to a node in the new graph + token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {} # map from hash to token + for n in graph_module.graph.nodes: + # The placeholder, output, and get_attr nodes are copied to the new graph without change + # do not CSE away random operations + if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops: + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' + # substitute args and kwargs members to their mapping in env if exists + # specs can be used to reconstruct nested list/dictionaries + def substitute(arg_list): + arg_list, spec = tree_flatten(arg_list) + for i in range(len(arg_list)): + v = arg_list[i] + if isinstance(v, Node) and v in env: + arg_list[i] = env[v] + return tuple(arg_list), spec + args, args_spec = substitute(n.args) + kwargs, kwargs_spec = substitute(n.kwargs) + + # each token corresponds to a unique node + # nodes with the same token can be substituted + token = {"target": n.target, "args": args, "args_spec": args_spec, + "kwargs": kwargs, "kwargs_spec": kwargs_spec} + + # hash substituted args to a number, do not hash specs because specs are not hashable + hash_arg = hash((args, kwargs)) + hash_val = (n.target, hash_arg) + + # check if a node has a substitute and can be eliminated + hash_val_in_hash_env = hash_val in hash_env + if hash_val_in_hash_env and token_map[hash_val] == token: + modified = True # substitution happens and the graph is modified + env[n] = hash_env[hash_val] + continue + + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + if not hash_val_in_hash_env: + hash_env[hash_val] = new_node + token_map[hash_val] = token + + csed_gm = GraphModule(graph_module, new_graph) + return PassResult(csed_gm, modified) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..2b40207e0f8048cb98cc99c147b979d9791e8d01 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/fake_tensor_prop.py @@ -0,0 +1,70 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch.fx +from torch.fx import Node +from torch.fx.node import map_aggregate +from torch.fx._compatibility import compatibility +from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor +from torch.fx.experimental.proxy_tensor import snapshot_fake, py_sym_types + +__all__ = ['FakeTensorProp'] + +@compatibility(is_backward_compatible=False) +class FakeTensorProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node and record a fake tensor representing + the metadata for the node. Unlike ShapeProp, (1) this propagation + is cheap--it does the propagation with meta tensors which do not actually + store data, and (2) the fake tensors have much more fine grained information, + e.g., they have accurate alias information that can be consulted by looking + at the storages. + + Args: + module (GraphModule): The module to be executed + mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node. + """ + def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None): + super().__init__(module) + if mode is None: + mode = FakeTensorMode() + self._mode = mode + mode.epoch += 1 + mode.reset_nt_tensor_id_counter() + + def run_node(self, n: Node): + from torch.fx.experimental.symbolic_shapes import rebind_unbacked, compute_unbacked_bindings + + result = super().run_node(n) + rebind_unbacked(self._mode.shape_env, n, result) + + def extract_val(obj): + if isinstance(obj, FakeTensor): + return snapshot_fake(obj) + elif isinstance(obj, torch.Tensor): + # TODO: How is it possible that we get a non fake tensor? We + # should be running under the mode... + return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True)) + elif isinstance(obj, py_sym_types): + return obj + else: + return None + + meta = map_aggregate(result, extract_val) + if meta is not None: + n.meta['val'] = meta + if (shape_env := self._mode.shape_env) and (symbol_to_path := compute_unbacked_bindings(shape_env, result)): + n.meta["unbacked_bindings"] = symbol_to_path + + return result + + def propagate(self, *args): + fake_args = [ + self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a + for a in args + ] + return self.propagate_dont_convert_inputs(*fake_args) + + def propagate_dont_convert_inputs(self, *args): + with self._mode: + return super().run(*args) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_drawer.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_drawer.py new file mode 100644 index 0000000000000000000000000000000000000000..975b2b6171780e9705603196ef7596dc6ea24f55 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_drawer.py @@ -0,0 +1,443 @@ +# mypy: allow-untyped-defs + +import hashlib +from itertools import chain +from typing import Any, Dict, Optional, TYPE_CHECKING + +import torch +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.graph import _parse_stack_trace +from torch.fx.node import _format_arg, _get_qualified_name +from torch.fx.operator_schemas import normalize_function +from torch.fx.passes.shape_prop import TensorMetadata + + +try: + import pydot + + HAS_PYDOT = True +except ModuleNotFoundError: + HAS_PYDOT = False + pydot = None + + +__all__ = ["FxGraphDrawer"] + +_COLOR_MAP = { + "placeholder": '"AliceBlue"', + "call_module": "LemonChiffon1", + "get_param": "Yellow2", + "get_attr": "LightGrey", + "output": "PowderBlue", +} + +_HASH_COLOR_MAP = [ + "CadetBlue1", + "Coral", + "DarkOliveGreen1", + "DarkSeaGreen1", + "GhostWhite", + "Khaki1", + "LavenderBlush1", + "LightSkyBlue", + "MistyRose1", + "MistyRose2", + "PaleTurquoise2", + "PeachPuff1", + "Salmon", + "Thistle1", + "Thistle3", + "Wheat1", +] + +_WEIGHT_TEMPLATE = { + "fillcolor": "Salmon", + "style": '"filled,rounded"', + "fontcolor": "#000000", +} + +if HAS_PYDOT: + @compatibility(is_backward_compatible=False) + class FxGraphDrawer: + """ + Visualize a torch.fx.Graph with graphviz + Basic usage: + g = FxGraphDrawer(symbolic_traced, "resnet18") + g.get_dot_graph().write_svg("a.svg") + """ + + def __init__( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool = False, + ignore_parameters_and_buffers: bool = False, + skip_node_names_in_args: bool = True, + parse_stack_trace: bool = False, + dot_graph_shape: Optional[str] = None, + normalize_args: bool = False, + ): + self._name = name + self.dot_graph_shape = ( + dot_graph_shape if dot_graph_shape is not None else "record" + ) + self.normalize_args = normalize_args + _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape + + self._dot_graphs = { + name: self._to_dot( + graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace + ) + } + + for node in graph_module.graph.nodes: + if node.op != "call_module": + continue + + leaf_node = self._get_leaf_node(graph_module, node) + + if not isinstance(leaf_node, torch.fx.GraphModule): + continue + + self._dot_graphs[f"{name}_{node.target}"] = self._to_dot( + leaf_node, + f"{name}_{node.target}", + ignore_getattr, + ignore_parameters_and_buffers, + skip_node_names_in_args, + parse_stack_trace, + ) + + def get_dot_graph(self, submod_name=None) -> pydot.Dot: + """ + Visualize a torch.fx.Graph with graphviz + Example: + >>> # xdoctest: +REQUIRES(module:pydot) + >>> # xdoctest: +REQUIRES(module:ubelt) + >>> # define module + >>> class MyModule(torch.nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.linear = torch.nn.Linear(4, 5) + >>> def forward(self, x): + >>> return self.linear(x).clamp(min=0.0, max=1.0) + >>> module = MyModule() + >>> # trace the module + >>> symbolic_traced = torch.fx.symbolic_trace(module) + >>> # setup output file + >>> import ubelt as ub + >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir() + >>> fpath = dpath / 'linear.svg' + >>> # draw the graph + >>> g = FxGraphDrawer(symbolic_traced, "linear") + >>> g.get_dot_graph().write_svg(fpath) + """ + if submod_name is None: + return self.get_main_dot_graph() + else: + return self.get_submod_dot_graph(submod_name) + + def get_main_dot_graph(self) -> pydot.Dot: + return self._dot_graphs[self._name] + + def get_submod_dot_graph(self, submod_name) -> pydot.Dot: + return self._dot_graphs[f"{self._name}_{submod_name}"] + + def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]: + return self._dot_graphs + + def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]: + + template = { + "shape": self.dot_graph_shape, + "fillcolor": "#CAFFE3", + "style": '"filled,rounded"', + "fontcolor": "#000000", + } + if node.op in _COLOR_MAP: + template["fillcolor"] = _COLOR_MAP[node.op] + else: + # Use a random color for each node; based on its name so it's stable. + target_name = node._pretty_print_target(node.target) + target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16) + template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)] + return template + + def _get_leaf_node( + self, module: torch.nn.Module, node: torch.fx.Node + ) -> torch.nn.Module: + py_obj = module + assert isinstance(node.target, str) + atoms = node.target.split(".") + for atom in atoms: + if not hasattr(py_obj, atom): + raise RuntimeError( + str(py_obj) + " does not have attribute " + atom + "!" + ) + py_obj = getattr(py_obj, atom) + return py_obj + + def _typename(self, target: Any) -> str: + if isinstance(target, torch.nn.Module): + ret = torch.typename(target) + elif isinstance(target, str): + ret = target + else: + ret = _get_qualified_name(target) + + # Escape "{" and "}" to prevent dot files like: + # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc + # which triggers `Error: bad label format (...)` from dot + return ret.replace("{", r"\{").replace("}", r"\}") + + # shorten path to avoid drawing long boxes + # for full path = '/home/weif/pytorch/test.py' + # return short path = 'pytorch/test.py' + def _shorten_file_name( + self, + full_file_name: str, + truncate_to_last_n: int = 2, + ): + splits = full_file_name.split('/') + if len(splits) >= truncate_to_last_n: + return '/'.join(splits[-truncate_to_last_n:]) + return full_file_name + + + def _get_node_label( + self, + module: torch.fx.GraphModule, + node: torch.fx.Node, + skip_node_names_in_args: bool, + parse_stack_trace: bool, + ) -> str: + def _get_str_for_args_kwargs(arg): + if isinstance(arg, tuple): + prefix, suffix = r"|args=(\l", r",\n)\l" + arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg] + elif isinstance(arg, dict): + prefix, suffix = r"|kwargs={\l", r",\n}\l" + arg_strs_list = [ + f"{k}: {_format_arg(v, max_list_len=8)}" + for k, v in arg.items() + ] + else: # Fall back to nothing in unexpected case. + return "" + + # Strip out node names if requested. + if skip_node_names_in_args: + arg_strs_list = [a for a in arg_strs_list if "%" not in a] + if len(arg_strs_list) == 0: + return "" + arg_strs = prefix + r",\n".join(arg_strs_list) + suffix + if len(arg_strs_list) == 1: + arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "") + return arg_strs.replace("{", r"\{").replace("}", r"\}") + + + label = "{" + f"name=%{node.name}|op_code={node.op}\n" + + if node.op == "call_module": + leaf_module = self._get_leaf_node(module, node) + label += r"\n" + self._typename(leaf_module) + r"\n|" + extra = "" + if hasattr(leaf_module, "__constants__"): + extra = r"\n".join( + [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr] + ) + label += extra + r"\n" + else: + label += f"|target={self._typename(node.target)}" + r"\n" + if self.normalize_args: + try: + args, kwargs = normalize_function( # type: ignore[misc] + node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type] + ) + except Exception: + # Fallback to not normalizing if there's an exception. + # Some functions need overloads specified to normalize. + args, kwargs = node.args, node.kwargs + else: + args, kwargs = node.args, node.kwargs + if len(args) > 0: + label += _get_str_for_args_kwargs(args) + if len(kwargs) > 0: + label += _get_str_for_args_kwargs(kwargs) + label += f"|num_users={len(node.users)}" + r"\n" + + tensor_meta = node.meta.get('tensor_meta') + label += self._tensor_meta_to_label(tensor_meta) + + # for original fx graph + # print buf=buf0, n_origin=6 + buf_meta = node.meta.get('buf_meta', None) + if buf_meta is not None: + label += f"|buf={buf_meta.name}" + r"\n" + label += f"|n_origin={buf_meta.n_origin}" + r"\n" + + # for original fx graph + # print file:lineno code + if parse_stack_trace and node.stack_trace is not None: + parsed_stack_trace = _parse_stack_trace(node.stack_trace) + fname = self._shorten_file_name(parsed_stack_trace.file) + label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n" + + + return label + "}" + + def _tensor_meta_to_label(self, tm) -> str: + if tm is None: + return "" + elif isinstance(tm, TensorMetadata): + return self._stringify_tensor_meta(tm) + elif isinstance(tm, list): + result = "" + for item in tm: + result += self._tensor_meta_to_label(item) + return result + elif isinstance(tm, dict): + result = "" + for v in tm.values(): + result += self._tensor_meta_to_label(v) + return result + elif isinstance(tm, tuple): + result = "" + for item in tm: + result += self._tensor_meta_to_label(item) + return result + else: + raise RuntimeError(f"Unsupported tensor meta type {type(tm)}") + + def _stringify_tensor_meta(self, tm: TensorMetadata) -> str: + result = "" + if not hasattr(tm, "dtype"): + print("tm", tm) + result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n" + result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n" + result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n" + result += "|" + "stride" + "=" + str(tm.stride) + r"\n" + if tm.is_quantized: + assert tm.qparams is not None + assert "qscheme" in tm.qparams + qscheme = tm.qparams["qscheme"] + if qscheme in { + torch.per_tensor_affine, + torch.per_tensor_symmetric, + }: + result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" + result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" + elif qscheme in { + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, + }: + result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n" + result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" + result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n" + else: + raise RuntimeError(f"Unsupported qscheme: {qscheme}") + result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" + return result + + def _get_tensor_label(self, t: torch.Tensor) -> str: + return str(t.dtype) + str(list(t.shape)) + r"\n" + + # when parse_stack_trace=True + # print file:lineno code + def _to_dot( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool, + ignore_parameters_and_buffers: bool, + skip_node_names_in_args: bool, + parse_stack_trace: bool, + ) -> pydot.Dot: + """ + Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph. + If ignore_parameters_and_buffers is True, the parameters and buffers + created with the module will not be added as nodes and edges. + """ + + # "TB" means top-to-bottom rank direction in layout + dot_graph = pydot.Dot(name, rankdir="TB") + + + buf_name_to_subgraph = {} + + for node in graph_module.graph.nodes: + if ignore_getattr and node.op == "get_attr": + continue + + style = self._get_node_style(node) + dot_node = pydot.Node( + node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style + ) + + current_graph = dot_graph + + buf_meta = node.meta.get('buf_meta', None) + if buf_meta is not None and buf_meta.n_origin > 1: + buf_name = buf_meta.name + if buf_name not in buf_name_to_subgraph: + buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name) + current_graph = buf_name_to_subgraph.get(buf_name) + + current_graph.add_node(dot_node) + + def get_module_params_or_buffers(): + for pname, ptensor in chain( + leaf_module.named_parameters(), leaf_module.named_buffers() + ): + pname1 = node.name + "." + pname + label1 = ( + pname1 + "|op_code=get_" + "parameter" + if isinstance(ptensor, torch.nn.Parameter) + else "buffer" + r"\l" + ) + dot_w_node = pydot.Node( + pname1, + label="{" + label1 + self._get_tensor_label(ptensor) + "}", + **_WEIGHT_TEMPLATE, + ) + dot_graph.add_node(dot_w_node) + dot_graph.add_edge(pydot.Edge(pname1, node.name)) + + if node.op == "call_module": + leaf_module = self._get_leaf_node(graph_module, node) + + if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule): + get_module_params_or_buffers() + + for subgraph in buf_name_to_subgraph.values(): + subgraph.set('color', 'royalblue') + subgraph.set('penwidth', '2') + dot_graph.add_subgraph(subgraph) + + for node in graph_module.graph.nodes: + if ignore_getattr and node.op == "get_attr": + continue + + for user in node.users: + dot_graph.add_edge(pydot.Edge(node.name, user.name)) + + return dot_graph + +else: + if not TYPE_CHECKING: + @compatibility(is_backward_compatible=False) + class FxGraphDrawer: + def __init__( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool = False, + ignore_parameters_and_buffers: bool = False, + skip_node_names_in_args: bool = True, + parse_stack_trace: bool = False, + dot_graph_shape: Optional[str] = None, + normalize_args: bool = False, + ): + raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install ' + 'pydot through your favorite Python package manager.') diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_manipulation.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_manipulation.py new file mode 100644 index 0000000000000000000000000000000000000000..36c59cb31af05c689826ff338d710605a68d9587 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_manipulation.py @@ -0,0 +1,111 @@ +# mypy: allow-untyped-defs +from typing import Any, Dict, List, NamedTuple, Optional + +import torch +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import ( + map_arg, + Node, + Target, +) +from torch.fx.passes.shape_prop import ShapeProp + +__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta', + 'get_size_of_node'] + +@compatibility(is_backward_compatible=False) +def replace_target_nodes_with( + fx_module: GraphModule, + old_op: str, + old_target: Target, + new_op: str, + new_target: Target, +): + """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target, + and updates them to match the new op code and target""" + new_graph = Graph() + val_map: Dict[Node, Node] = {} + for node in fx_module.graph.nodes: + if node.op == old_op and node.target == old_target: + args = map_arg(node.args, lambda n: val_map[n]) + kwargs = map_arg(node.kwargs, lambda n: val_map[n]) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + val_map[node] = new_graph.create_node( + new_op, new_target, args, kwargs, node.name + ) + else: + val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) + fx_module.graph = new_graph + + +@compatibility(is_backward_compatible=False) +class size_bytes(NamedTuple): + output_size: int + total_size: int + + +@compatibility(is_backward_compatible=False) +def get_size_of_all_nodes( + fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None +) -> None: + """Given a fx graph module, update each node with its total size (weights + bias + output) + and its output_size(output). For a non-module node, the total size is the output size. + return total size""" + if args is not None: + # Mark shape and dtype for each node (node.shape and node.dtype) + ShapeProp(fx_module).propagate(*args) + # Calculate the total size of the whole fx graph + total_size_of_graph = 0.0 + for node in fx_module.graph.nodes: + if node.op == "output": + break + node.size_bytes = get_size_of_node(fx_module, node) + return + + +@compatibility(is_backward_compatible=False) +def get_tensor_meta(node: Node) -> Any: + tensor_meta = node.meta.get("tensor_meta") + + if not tensor_meta: + raise RuntimeError( + f"Node {node} has no tensor metadata associated with it! " + f"Check that shape propagation has run." + ) + + return tensor_meta + + +@compatibility(is_backward_compatible=False) +def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: + """Given a node with node.dtype and node.shape, return its total size and its output size. + total_size = weights + bias + output_size + """ + # Total num of elements + total_num_of_elems = 0 + # For a module, conside all parameters + if node.op == "call_module": + submodule_dict = dict(fx_module.named_modules()) + submodule = submodule_dict[node.target] + parameters = submodule.named_parameters() + # Parameters are named tuples + for name, p in parameters: + total_num_of_elems += p.numel() + # Don't forget the output size + # node.shape is the shape of this node's output + tensor_meta = get_tensor_meta(node) + output_elem = tensor_meta.shape.numel() + total_num_of_elems += output_elem + # Assume for now if it's quantized then it's qint8 or quint8 + if tensor_meta.is_quantized: + size_per_elem_bytes = torch._empty_affine_quantized( + [], dtype=tensor_meta.dtype + ).element_size() + else: + size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size() + total_size = size_per_elem_bytes * total_num_of_elems + output_size = size_per_elem_bytes * output_elem + return size_bytes(output_size, total_size) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_transform_observer.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_transform_observer.py new file mode 100644 index 0000000000000000000000000000000000000000..6390b7cee4954c1b1c8bf9fbc0040df77e146d12 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/graph_transform_observer.py @@ -0,0 +1,91 @@ +# mypy: allow-untyped-defs +import os +from typing import Optional + +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule + +from .graph_drawer import FxGraphDrawer + + +__all__ = ["GraphTransformObserver"] + + +@compatibility(is_backward_compatible=False) +class GraphTransformObserver: + __pass_count = 0 + + def __init__(self, gm: GraphModule, passname: str, log_url: Optional[str] = None): + # If log_url is None, we don't log anything + self.log_url = log_url + if self.log_url is None: + return + GraphTransformObserver.__pass_count += 1 + self.gm = gm + self.passname = passname + + self.input_dot_graph = FxGraphDrawer( + self.gm, + self.passname, + ignore_getattr=True, + ignore_parameters_and_buffers=True, + ).get_dot_graph() + + @classmethod + def get_current_pass_count(cls): + return cls.__pass_count + + def __enter__(self): + if self.log_url is None or self.gm is None: + return self + + self.erased_nodes = set() + self.created_nodes = set() + self.gm._register_create_node_hook(self.on_node_creation) + self.gm._register_erase_node_hook(self.on_node_erase) + + return self + + def __exit__(self, type, value, tb): + if self.log_url is None or self.gm is None: + return + + self.gm._unregister_create_node_hook(self.on_node_creation) + self.gm._unregister_erase_node_hook(self.on_node_erase) + + if len(self.created_nodes) > 0 or len(self.erased_nodes) > 0: + for e in self.input_dot_graph.get_node_list(): + if e.get_name() in self.erased_nodes: + e.obj_dict["attributes"]["fillcolor"] = "yellow" + else: + e.obj_dict["attributes"]["fillcolor"] = "grey" + self.input_dot_graph.write( + os.path.join( + self.log_url, + f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_input_graph.dot", + ) + ) + + output_dot_graph = FxGraphDrawer( + self.gm, + self.passname, + ignore_getattr=True, + ignore_parameters_and_buffers=True, + ).get_dot_graph() + for e in output_dot_graph.get_node_list(): + if e.get_name() in self.created_nodes: + e.obj_dict["attributes"]["fillcolor"] = "yellow" + else: + e.obj_dict["attributes"]["fillcolor"] = "grey" + output_dot_graph.write( + os.path.join( + self.log_url, + f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_output_graph.dot", + ) + ) + + def on_node_creation(self, node): + self.created_nodes.add(node.name) + + def on_node_erase(self, node): + self.erased_nodes.add(node.name) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__init__.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..657b6a93014f428eece18ec896136c81bc3949f3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__init__.py @@ -0,0 +1,2 @@ + +from . import pass_manager diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4661b6fc03d7fc3c9af5d6cb1b038e336d1413fa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..140ac4ea5348abec0429e07e26275a0bb64f77d0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-311.pyc b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e33c830408cadc2a97766efc26b41ac3fa79a34 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..271f90a7b75e8b63b9541d344a00476bf2d1e4ab --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py @@ -0,0 +1,335 @@ +# mypy: allow-untyped-defs +from torch.fx.passes.utils.fuser_utils import fuse_by_partitions +import collections +import itertools +import logging + +from copy import copy +from typing import Dict, Iterable, List, Optional, Sequence, Set + +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node, _get_qualified_name +from torch.fx.passes.operator_support import OperatorSupportBase + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +class Partition: + def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None): + self.id = id + self.nodes = dict.fromkeys(nodes) if nodes is not None else {} + + def __repr__(self) -> str: + return str(self.nodes) + + def add_node(self, node: Node): + self.nodes.update({node: None}) + + def remove_node(self, node: Node): + del self.nodes[node] + + def size(self): + return len(self.nodes) + +class _DependencyViewer: + def __init__(self, graph_module: GraphModule): + self.upstreams = collections.defaultdict(set) + self.downstreams = collections.defaultdict(set) + + for node in graph_module.graph.nodes: + for input_node in node.all_input_nodes: + # add input_node and input_node's upstream dependency + self.upstreams[node].add(input_node) + self.upstreams[node].update(self.upstreams[input_node]) + + for node in reversed(graph_module.graph.nodes): + for output_node in node.users: + # add output_node and output_node's downstream dependency + self.downstreams[node].add(output_node) + self.downstreams[node].update(self.downstreams[output_node]) + + def downstreams_of(self, node: Node) -> Set[Node]: + return self.downstreams[node] + + def upstreams_of(self, node: Node) -> Set[Node]: + return self.upstreams[node] + +class CapabilityBasedPartitioner: + + def __init__(self, + graph_module: GraphModule, + operator_support: OperatorSupportBase, + allows_single_node_partition: bool = False, + non_compute_ops: Optional[Sequence[str]] = None, + allowed_single_node_partition_ops: Optional[Sequence[str]] = None, + ) -> None: + self.graph_module = graph_module + self.operator_support = operator_support + self.allows_single_node_partition = allows_single_node_partition + self.non_compute_ops = non_compute_ops if non_compute_ops is not None else [] + self.allowed_single_node_partition_ops = ( + allowed_single_node_partition_ops + if allowed_single_node_partition_ops is not None + else [] + ) + self.dependency_viewer = _DependencyViewer(graph_module) + + def __is_node_supported(self, node: Node) -> bool: + return ( + self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node) + ) + + def propose_partitions(self) -> List[Partition]: + # partition_map is a mapping from partition id to a set of partition id's. + # The value set contains all the partition ids that can be reached by doing a + # DFS starting from the partition id in the key. + partition_map : Dict[int, Set] = collections.defaultdict(set) + + # assumptions: nodes in candidate list is sorted in topological order + assignment: Dict[Node, int] = {} # mapping from node to partition_id + partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition + new_partition_id = itertools.count() + + # try to merge partition other_id into partition self_id + # merge only happens if the end graph doesn't contain cyclic dependency + # returns `True` when merge happens, `False` otherwise. + def maybe_merge_partition(self_id: int, other_id: int): + # merged_nodes is the union of nodes in two partition to-be-merged + merged_nodes = copy(partitions_by_id[self_id].nodes) + merged_nodes.update(partitions_by_id[other_id].nodes) + + def dfs_iter_find_cycle(all_user_nodes: Set[Node]): + for user_node in all_user_nodes: + visited_partition_ids = set() + + for path_node in self.dependency_viewer.downstreams_of(user_node): + # If any of the nodes in the dfs path of this node are in the merged_nodes + # list then there is a cycle in the graph. + if path_node in merged_nodes: + return True + + # If any of the nodes in the dfs path of this node are in the assignment + # map then we have to make sure that the partitions that these nodes belong + # to do not form a cycle with the current partitions being merged. This means + # iterating through all the nodes in all the parititons that are traversed in + # the dfs path and checking if they are in the merged_nodes list. + if path_node in assignment: + partition_id = assignment[path_node] + # If the partition id has already been visited then we know that it doesn't + # form a cycle with the current partitions being merged. + if partition_id in visited_partition_ids: + continue + p_map = partition_map[partition_id] + if self_id in p_map or other_id in p_map: + return True + + visited_partition_ids.add(partition_id) + + return False + + # check if merge would create cyclic dependency. + all_user_nodes = set() + for node in merged_nodes: + for user_node in node.users: + if user_node not in merged_nodes: + all_user_nodes.add(user_node) + + if dfs_iter_find_cycle(all_user_nodes): + # return false indicating cyclic dependency found and + # merge is aborted + return False + + # no cyclic dependency found, move forward with the merge + # updating partition nodes + partitions_by_id[self_id].nodes = merged_nodes + # updating assignment map + for node in partitions_by_id[other_id].nodes: + assignment[node] = self_id + # delete other partition + del partitions_by_id[other_id] + + partition_map[self_id] = partition_map[self_id].union(partition_map[other_id]) + del partition_map[other_id] + + return True + + def merge_single_node(node: Node, id: Optional[int]): + def _update_partition_map(node: Node, id: int): + # Iterate through all the downstream nodes of this node and update the partition map + # to indicate that there is a path from the partition id of this node to the target + # partition id. + downstream_nodes = self.dependency_viewer.downstreams_of(node) + for curr_node in downstream_nodes: + target_id = assignment.get(curr_node, None) + if target_id is not None: + partition_map[id].add(target_id) + + # Iterate through all the upstream nodes of this node and update the partition map + # to indicate that there is a path from the partition id of the upstream node to the + # current node's partition id. + upstream_nodes = self.dependency_viewer.upstreams_of(node) + for curr_node in upstream_nodes: + source_id = assignment.get(curr_node, None) + if source_id is not None: + partition_map[source_id].add(id) + + if node in assignment: + partitions_by_id[assignment[node]].remove_node(node) + + if id is None: + assignment.pop(node) + elif id not in partitions_by_id: + assignment[node] = id + partitions_by_id[id] = Partition(id=id, nodes=[node]) + _update_partition_map(node, id) + else: + assignment[node] = id + partitions_by_id[id].add_node(node) + _update_partition_map(node, id) + + logger.debug("Proposing partitions...") + + for node in reversed(self.graph_module.graph.nodes): + # use Dict as an ordered set to ensure deterministic partitioning result, don't care value + merge_candidates: Dict[int, None] = {} + + # Note a limited horizontal fusion is enabled: + # when `node` is not supported, the code below attempts to fuse consumer of `node`. + # + # I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut + # the fusion by adding an `else` block here to skip horizontal fusion. + if self.__is_node_supported(node) and node not in assignment: + partition_id = next(new_partition_id) + merge_single_node(node, partition_id) + merge_candidates[partition_id] = None + + # merge all possible partitions + for node in assignment: + merge_candidates[assignment[node]] = None + + merge_candidates_list = list(merge_candidates.keys()) + if len(merge_candidates_list) > 1: + self_id = merge_candidates_list[0] + for other_id in merge_candidates_list[1:]: + # note: merge partition `other_id` into partition `self_id` if + # it doesn't create cyclic dependency in the graph, otherwise, + # this is a no-op + maybe_merge_partition(self_id, other_id) + + # post processing to re-assign "getitem" nodes into upstream partition + logger.debug("Reassigning getitem nodes to its producer node's partition...") + nodes_reassignment: Dict[Node, int] = {} + for node in self.graph_module.graph.nodes: + is_tuple_output = True + for user in node.users: + if user.op != "call_function" or \ + _get_qualified_name(user.target) != "_operator.getitem": # type: ignore[arg-type] + is_tuple_output = False + break + + # node has tuple outputs, re-assign all following getitem node into node's partition + if is_tuple_output: + id = assignment.get(node, None) # type: ignore[arg-type] + for user in node.users: + if assignment.get(user, None) != id: # type: ignore[arg-type] + nodes_reassignment[user] = id # type: ignore[assignment] + for node, id in nodes_reassignment.items(): + merge_single_node(node, id) + + # filter out single node partitions + if not self.allows_single_node_partition: + logger.debug("Filtering out single node partitions...") + default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"} + non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops)) + partitions_to_remove: List[int] = [] + for id, partition in partitions_by_id.items(): + compute_node_count = 0 + for node in partition.nodes: + if node.op == "call_function": + assert callable(node.target) + if _get_qualified_name(node.target) not in non_compute_ops: + compute_node_count += 1 + if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops: + compute_node_count += 1 + if compute_node_count <= 1: + partitions_to_remove.append(id) + for id in partitions_to_remove: + del partitions_by_id[id] + + logger.debug("Partitions proposed:") + for id, partition in partitions_by_id.items(): + logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes]) + + return [partition for partition in partitions_by_id.values() if partition.size() > 0] + + def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") -> GraphModule: + logger.debug("Fusing partitions...") + # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ] + return fuse_by_partitions( + self.graph_module, + [list(partition.nodes) for partition in partitions], + prefix=prefix, + ) + + # remove non-compute-ops that sits at the boundary of a partition. + def remove_bookend_non_compute_ops(self, partitions: List[Partition]): + non_compute_ops = set(self.non_compute_ops) + + def is_non_compute_node(node: Node): + return node.op == "call_function" and \ + _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type] + + # cache transparent nodes + transparent_input_nodes: Dict[Node, bool] = {} + transparent_output_nodes: Dict[Node, bool] = {} + + def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]): + if node.op == "placeholder" or (node not in partition) or (node in removed_nodes): + return True + if node in transparent_input_nodes: + return transparent_input_nodes[node] + if is_non_compute_node(node): + for input_n in node.all_input_nodes: + if not is_transparent_input_node(input_n, partition, removed_nodes): + transparent_input_nodes[node] = False + return False + transparent_input_nodes[node] = True + return True + transparent_input_nodes[node] = False + return False + + def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]): + if node.op == "placeholder" or (node not in partition) or (node in removed_nodes): + return True + if node in transparent_output_nodes: + return transparent_output_nodes[node] + if is_non_compute_node(node): + for output_n in node.users: + if not is_transparent_output_node(output_n, partition, removed_nodes): + transparent_output_nodes[node] = False + return False + transparent_output_nodes[node] = True + return True + transparent_output_nodes[node] = False + return False + + for partition in partitions: + # Note it's ok to use `set` here, since we are only query if a node + # has been removed. We are NEVER going to iterate on nodes inside + # the set. + remove_node: Set[Node] = set() + for node in partition.nodes: + if is_non_compute_node(node) and \ + (is_transparent_input_node(node, set(partition.nodes), remove_node) or + is_transparent_output_node(node, set(partition.nodes), remove_node)): + remove_node.add(node) + + if len(remove_node) != 0: + for node in remove_node: + partition.nodes.pop(node, None) + + def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule: + partitions = self.propose_partitions() + fused_gm = self.fuse_partitions(partitions, prefix=prefix) + return fused_gm diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_base.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_base.py new file mode 100644 index 0000000000000000000000000000000000000000..3f5b64eafbb60ffca5e6539c7811f2f003ce9c1f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_base.py @@ -0,0 +1,73 @@ +# mypy: allow-untyped-defs +import abc +from collections import namedtuple +from typing import Optional + +from torch.fx.graph_module import GraphModule +from torch.fx._compatibility import compatibility + + +__all__ = ['PassResult', 'PassBase'] + +@compatibility(is_backward_compatible=False) +class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): + """ + Result of a pass: + graph_module: The modified graph module + modified: A flag for if the pass has modified the graph module + """ + def __new__(cls, graph_module, modified): + return super().__new__(cls, graph_module, modified) + +@compatibility(is_backward_compatible=False) +class PassBase(abc.ABC): + """ + Base interface for implementing passes. + + It is required to implement the `call` function so that we can directly + pass instances of the Pass directly to the PassManager and call them as a + function. + + We can directly pass an instance of a class implementing this interface into + the PassManager's `passes` attribute. + """ + + def __call__(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + Runs the precondition check, the pass itself, and the postcondition check. + """ + + self.requires(graph_module) + res = self.call(graph_module) + self.ensures(graph_module) + return res + + @abc.abstractmethod + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + The pass that is run through the given graph module. To implement a + pass, it is required to implement this function. + + Args: + graph_module: The graph module we will run a pass on + """ + + def requires(self, graph_module: GraphModule) -> None: # noqa: B027 + """ + This function will be called before the pass is run and will check that + the given graph module contains the preconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ + + def ensures(self, graph_module: GraphModule) -> None: # noqa: B027 + """ + This function will be called after the pass is run and will check that + the given graph module contains the postconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_manager.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..29540fa447eb193e4b52335178e30172b8e72728 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/infra/pass_manager.py @@ -0,0 +1,302 @@ +# mypy: allow-untyped-defs +import inspect +import logging +from queue import Queue +from functools import wraps +from typing import Callable, Dict, List + +import torch.nn as nn +from torch.fx.graph_module import GraphModule +from torch.fx._compatibility import compatibility +from torch.fx.passes.infra.pass_base import PassResult + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager'] + +@compatibility(is_backward_compatible=False) +def pass_result_wrapper(fn: Callable) -> Callable: + """ + Wrapper for passes which currently do not return a PassResult. + This wrapper makes them return a PassResult containing the modified object + and True for the "modified" flag. + + Args: + fn (Callable[Module, Any]) + + Returns: + wrapped_fn (Callable[Module, PassResult]) + """ + if fn is None: + return None + + @wraps(fn) + def wrapped_fn(gm): + res = fn(gm) + if res is None: + return PassResult(gm, True) + if isinstance(res, PassResult): + return res + elif isinstance(res, nn.Module): + return PassResult(res, True) + + if not inspect.isfunction(fn): + wrapped_fn.__name__ = type(fn).__name__ + + return wrapped_fn + +def _validate_pass_schedule_constraint( + constraint: Callable[[Callable, Callable], bool], passes: List[Callable] +) -> None: + for i, a in enumerate(passes): + for j, b in enumerate(passes[i + 1 :]): + if constraint(a, b): + continue + raise RuntimeError( + f"pass schedule constraint violated. Expected {a} before {b}" + f" but found {a} at index {i} and {b} at index{j} in pass" + f" list." + ) + +def _topological_sort_passes( + passes: List[Callable], constraints: List[Callable] +) -> List[Callable]: + """ + Args + passes: Passes that we are ordering + constraints: Constraints applied on these passes + + Returns + A sorted list of callables and a boolean of if a circular dependency + existed + """ + if len(constraints) == 0: + return passes + + # Contruct a graph mapping nodes to a list of their users + graph: Dict[Callable, List[Callable]] = {p : [] for p in passes} + indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0) + candidates: Queue = Queue() + for a in passes: + for b in passes: + if a == b: + continue + + for constraint in constraints: + if not constraint(a, b): + graph[b].append(a) + indegree_map[a] += 1 + + if indegree_map[a] == 0: + candidates.put(a) + + visited: Dict[Callable, bool] = dict.fromkeys(passes, False) + sorted_passes: List[Callable] = [] + + while not candidates.empty(): + p = candidates.get() + sorted_passes.append(p) + visited[p] = True + + for n in graph[p]: + if not visited[n]: + indegree_map[n] -= 1 + if indegree_map[n] == 0: + candidates.put(n) + + # Check if there are unvisited nodes (aka cycles in the graph) + cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys())) + if len(cycle_passes) != 0: + error = f"Circular dependency detected within the following passes: {cycle_passes}" + raise RuntimeError(error) + + return sorted_passes + +@compatibility(is_backward_compatible=False) +def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable: + """ + Defines a partial order ('depends on' function) where `this` must occur + before `that`. + + For example, the following pass list and constraint list would be invalid. + ``` + passes = [pass_b, pass_a] + + constraints = [ + this_before_that_pass_constraint(pass_a, pass_b) + ] + ``` + + Args: + this (Callable): pass which should occur first + that (Callable): pass which should occur later + + Returns: + depends_on (Callable[[Object, Object], bool] + """ + + def depends_on(a: Callable, b: Callable): + return a != that or b != this + + return depends_on + + +@compatibility(is_backward_compatible=False) +class PassManager: + """ + Construct a PassManager. + + Collects passes and constraints. This defines the pass schedule, manages + pass constraints and pass execution. + + Args: + passes (Optional[List[Callable]]): List of passes. A pass is a + callable which modifies an object and returns a PassResult + constraint (Optional[List[Callable]]): List of constraints. A + constraint is a callable which takes two passes (A, B) and returns + True if A depends on B and False otherwise. See implementation of + `this_before_that_pass_constraint` for example. + steps (int): Max number of times we run the passes (default = 1). + run_checks_after_each_pass (bool): Whether to run checks and linting + after each pass + suppress_check_failures (bool): Whether to raise errors when running + checks + """ + + passes: List[Callable[[nn.Module], PassResult]] + constraints: List[Callable[[Callable, Callable], bool]] + _validated: bool = False + steps: int = 1 + + def __init__( + self, + passes=None, + constraints=None, + steps=None, + run_checks_after_each_pass: bool = False, + suppress_check_failures: bool = False, + ): + self.passes = passes or [] + self.constraints = constraints or [] + if steps: + self.steps = steps + + self.run_checks_after_each_pass = run_checks_after_each_pass + self.suppress_check_failures = suppress_check_failures + + def add_pass(self, _pass: Callable): + """ + Adds a pass into the current list of passes. + """ + self.passes.append(_pass) + self._validated = False + + def add_constraint(self, constraint: Callable): + """ + Adds a constraint into the current list of constraints. + """ + self.constraints.append(constraint) + self._validated = False + + def validate_constraints(self): + """ + Validates that current pass schedule defined by `self.passes` is valid + according to all constraints in `self.constraints` + """ + if self._validated: + return + for constraint in self.constraints: + _validate_pass_schedule_constraint(constraint, self.passes) + self._validated = True + + def solve_constraints(self): + """ + Finds a valid traversal order based on the given constraints and orders + the passes based on this order. + + If a circular dependency exists between the constraints and steps = 1, + then we will raise an error because if steps != 1 this means that we + will re-run the passes, allowing for circular dependencies. + """ + self.passes = _topological_sort_passes(self.passes, self.constraints) + self._validated = True + + def add_checks(self, check: Callable) -> None: + """ + Adds a function which takes runs various checks on a given graph module. + This function is run before and after each pass if the + `run_checks_after_each_pass` flag is enabled. + """ + sig = inspect.signature(check) + + if len(list(sig.parameters.values())) != 1: + raise TypeError("PassManager check function should only take in one variable, a module") + + setattr(self, "check", check) # noqa: B010 + + def check(self, module: nn.Module) -> None: + pass + + def __call__(self, module: nn.Module) -> PassResult: + """ + Runs a list of passes in the order based on `self.passes` on the given + graph module. Each time a pass is run, checks and linting will be run on + the graph module if `run_checks_after_each_pass` is set. + + If the module is a graph module, we will run the list of passes until + the graph stops changing, or until `steps` number of times. + """ + # Order the passes based on the constraints + if not self._validated: + self.solve_constraints() + + # Check graph invariants + self.check(module) + + # Run the set of passes `steps` number of times or until the graph stops + # changing + overall_modified = False + for _ in range(self.steps): + modified = False + + # Run the set of passes on the graph module + for i, fn in enumerate(self.passes): + fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__ + logger.debug("Running pass '%s'", fn_name) + + try: + res = fn(module) + + if not isinstance(res, PassResult) and not hasattr( + res, "graph_module" + ): + raise TypeError( + f"The result of the pass {fn_name} should be type PassResult." + + "Please wrap it with pass_result_wrapper()" + ) + module = res.graph_module + modified = modified or res.modified + + if isinstance(module, GraphModule): + logger.debug("Graph after pass '%s': %s", fn_name, module.graph) + module.recompile() + + # Check graph invariants + if self.run_checks_after_each_pass: + self.check(module) + + except Exception as e: + prev_pass_names = [ + p.__name__ if inspect.isfunction(p) else type(p).__name__ + for p in self.passes[:i] + ] + msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}" + raise Exception(msg) from e # noqa: TRY002 + + # If the graph no longer changes, then we can stop running these passes + overall_modified = overall_modified or modified + if not modified: + break + + return PassResult(module, overall_modified) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/net_min_base.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/net_min_base.py new file mode 100644 index 0000000000000000000000000000000000000000..6182972e670eaf078c5b82c32626bab1ba98f11e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/net_min_base.py @@ -0,0 +1,924 @@ +# mypy: allow-untyped-defs +import logging +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.fx + +from torch.fx._compatibility import compatibility +from torch.fx.node import map_arg + +from .shape_prop import ShapeProp +from .split_utils import split_by_tags +from .tools_common import ( + CALLABLE_NODE_OPS, + FxNetAccFusionsFinder, + Names, + NodeList, + NodeSet, + TensorOrTensors, + Tensors, +) + +__all__ = [ + "FxNetMinimizerBadModuleError", + "FxNetMinimizerRunFuncError", + "FxNetMinimizerResultMismatchError", +] + +_LOGGER = logging.getLogger(__name__) + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerBadModuleError(Exception): + """ + Raised if failed to split out a minimize module + """ + + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerRunFuncError(Exception): + """ + Raised if error occurs during run_a or run_b functions + """ + + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerResultMismatchError(Exception): + """ + Raised if comparing function thinks the results are mismatching. + """ + + + +@dataclass +class _MinimizerSettingBase: + """ + Args: + `accumulate_error`: Instead of using a's input for both converted module to verify + , use the previous outputs of each converted module as input to accumulate the + errors. + + `traverse_method`: "sequential" or "binary" or "accumulate" + Determine the way of traverse the nodes in FX module. + + `find_all`: Minimizer will go through the entire model and return all problematic nodes. + + `return_intermediate`: If true, when using `run_nodes()` function to run the + model, intermediate results of all the ops will be returned as output. + """ + + accumulate_error: bool = False + traverse_method: str = "sequential" + find_all: bool = False + return_intermediate: bool = False + + def __str__(self): + settings_str = "FX Minimizer Settings:\n" + + for k, v in vars(self).items(): + settings_str += f"\t{k}: {v}\n" + + return settings_str + + +class _MinimizerBase: + """ + This class is used to automatically find problematic nodes in a model. It takes a FX + graphmodule and generate some submodules while traverse the graph. Then two functions + `run_a` and `run_b` will be used to run the same submodule and a function `compare_fn` + will be used to compare the results. + + Currently we provides two ways to traverse the graph and generate submodules. + 1. Sequential traversal: this will traverse the graph node by node and generate + one submodule with one sigle node. + 2. Binary searching: this will do a binary search style traversal on the graph. + + For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + sample_input: Tensors, + compare_fn: Callable[ + [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool] + ], + settings: _MinimizerSettingBase, + module_exporter: Optional[ + Callable[ + [Tensors, torch.fx.GraphModule, str], + None + ] + ] = None, + exclusion_fn: Optional[ + Callable[[NodeList, int, int], None] + ] = None, + ): + assert isinstance(module, torch.fx.GraphModule) + + self.module = module + self.sample_input = sample_input + self.compare_fn = compare_fn + self.module_exporter = module_exporter + self.settings = settings + self.exclusion_fn = exclusion_fn + + # Stores outputs of run_a function + self.a_outputs: Dict[str, Any] = {} + + # Stores outputs of run_b function + self.b_outputs: Dict[str, Any] = {} + + # Stores the results of compare_fn + self.results: Dict[Any, Any] = {} + + # Stores the report for the runs + self.reports: List[List[str]] = [] + + # Current iteration + self.iteration: int = 0 + + callable_nodes = { + node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS + } + ShapeProp(self.module).propagate(*self.sample_input) + self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)() + + # Check if number of input in sample_input matches the number of placeholders + placeholders = [ + node.name for node in self.module.graph.nodes if node.op == "placeholder" + ] + assert len(placeholders) == len(self.sample_input) + + # Store sample_input + for i, name in enumerate(placeholders): + self.a_outputs[name] = sample_input[i] + self.b_outputs[name] = sample_input[i] + + def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors: + """ + Run `mod` with `inputs` and generate output. The output will be compared with + output of run_b(). + """ + raise RuntimeError("run_a() is not implemented.") + + def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors: + """ + Run `mod` with `inputs` and generate output. The output will be compared with + output of run_a(). + """ + raise RuntimeError("run_b() is not implemented.") + + def _store_outputs( + self, + a_result: TensorOrTensors, + b_result: TensorOrTensors, + submodule: torch.fx.GraphModule, + ): + """ + Store the outputs of self.run_a() and self.run_b() into self.a_outputs and + self.b_outputs, so that we can use them when execute preceding nodes that + use those outputs as inputs. + + Args: + a_result: Output of self.run_a(). Could be a tensor or tensors. + b_result: Output of self.run_b(). Could be a tensor or tensors. + submodule: The module that generates a_result and b_result. + """ + output_node = next( + node for node in submodule.graph.nodes if node.op == "output" + ) + + # Only one output + if isinstance(output_node.args[0], torch.fx.Node): + self.a_outputs[output_node.args[0].name] = a_result + self.b_outputs[output_node.args[0].name] = b_result + # Multiple outputs + else: + for i, arg in enumerate(output_node.args[0]): + self.a_outputs[arg.name] = a_result[i] + self.b_outputs[arg.name] = b_result[i] + + def _get_submod_inputs( + self, main_module: torch.fx.GraphModule, submod_path: str + ) -> Tuple[Tensors, Tensors]: + """ + Try get submodule inputs from stored outputs. If not found then use + torch_glow.get_submod_inputs to get the inputs. + + If accumulate_error is False, use a_input for run_a() and run_b() + otherwise use a_input for run_a and b_input for run_b. + + Args: + main_module: Top-levlel fx module. + submod_path: Path to the submodule we want to run and compare results. + + Returns: + a_input: List of tensor(s) that will be used by run_a() as submodule inputs. + b_input: List of tensor(s) that will be used by run_b() as submodule inputs. + """ + a_input = [] + b_input = [] + submodule = getattr(main_module, submod_path) + placeholders = [ + node.name for node in submodule.graph.nodes if node.op == "placeholder" + ] + + # If all placeholder can be found in stored outputs, use stored + # outputs as inputs. Otherwise, use `torch_glow.get_submod_inputs` + # to get the inputs. + if set(placeholders) <= self.a_outputs.keys(): + for name in placeholders: + a_input.append(self.a_outputs[name]) + b_input.append(self.b_outputs[name]) + else: + if self.settings.accumulate_error: + print(f"Can't find previous stored outputs named {placeholders}!") + + def get_inputs(self: torch.nn.Module, inputs: Any): + nonlocal a_input + a_input = inputs + + # Use forward hook to get the inputs to the submodule + handle = submodule.register_forward_pre_hook(get_inputs) + main_module(*self.sample_input) + handle.remove() + + b_input = a_input + + if not self.settings.accumulate_error: + return a_input, a_input + + return a_input, b_input + + def _tag_nodes(self, selected_nodes: NodeSet): + """ + Tag selected nodes with tag "minimize". Nodes with the same tags will + be split to the same submodule afterwards. + + Args: + selected_nodes: Nodes that we want to minimize. We will tag those nodes + with "minimize", all preceding nodes with "main_0" and all following + nodes with "main_1". + """ + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + if node in selected_nodes: + node.tag = "minimize" + elif any( + n.tag in {"minimize", "main_1"} + for n in node.all_input_nodes + if n.op in CALLABLE_NODE_OPS + ): + node.tag = "main_1" + else: + node.tag = "main_0" + + def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]: + """ + Split self.module so that one submodule consists of `nodes` and only `nodes`. + + Args: + nodes: Nodes that we want to include in the minimize submodule. + + Returns: + split_module (torch.fx.GraphModule): the module after split. + submodule_name (str): the name of the submodule that consists of `nodes`. + """ + # Color provided nodes + self._tag_nodes(nodes) + + # Split module based on coloring + split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"]) + + # Find submodule containing colored nodes + submodule_name: str = "" + for child_name, _ in split_module.named_children(): # type: ignore[union-attr] + # Skip submodules we're not interested in at the moment + if "minimize" not in child_name: + continue + + if submodule_name == "": + submodule_name = child_name + else: + raise FxNetMinimizerBadModuleError( + f"Expected only one minimize submodule with nodes {nodes}" + ) + + if submodule_name == "": + raise FxNetMinimizerBadModuleError( + f"Minimize submodule was not found with nodes {nodes}" + ) + + return split_module, submodule_name # type: ignore[return-value] + + def _run_and_compare( + self, + split_module: torch.fx.GraphModule, + submod_name: str, + output_names: Names, + report_idx: int = -1 + ): + """ + Run the submodule in `split_module` that has name `submod_name` + using `self.run_a` and `self.run_b` and compare their results. + + Args: + split_module: Main module that contains the minimize submodule. + submod_name: Name of the minimize submodule. + output_names: Names of the node we want to output. If None, we + will use the original output. + """ + submodule = getattr(split_module, submod_name) + a_input, b_input = self._get_submod_inputs(split_module, submod_name) + + if len(self.reports) == 0: + self.reports.append([]) + self.iteration = 1 + + report = self.reports[report_idx if report_idx >= 0 else self.iteration - 1] + report.append("Run and compare ...") + + if output_names: + output_nodes: NodeList = [] + for node in submodule.graph.nodes: + if node.op == "output": + submodule.graph.erase_node(node) + + if node.name in output_names: + output_nodes.append(node) + + submodule.graph.output( + output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes) + ) + submodule.graph.lint() + submodule.recompile() + + # Use name of args in output node as key to store comparison result + for node in submodule.graph.nodes: + if node.op == "output": + result_key = map_arg(node.args, lambda x: x.name) + + try: + a_result = self.run_a(submodule, a_input, report_idx) + b_result = self.run_b(submodule, b_input, report_idx) + self._store_outputs(a_result, b_result, submodule) + except Exception as e: + report.append(f"Exception raised when running {submod_name}: {e}") + raise FxNetMinimizerRunFuncError( # noqa: B904 + f"Exception raised when running {submod_name}: {e}" + ) + + # Compare results + names: Names = output_names + if output_names is None: + names = [str(v) for v in result_key] # type: ignore[possibly-undefined] + + numeric_result, bool_result = self.compare_fn(a_result, b_result, names) + + self.results[result_key] = numeric_result # type: ignore[possibly-undefined] + report.append(f"Numerical accuracy = {numeric_result}") + if not bool_result: + report.append(f"Result mismatch for {result_key}") + if self.module_exporter: + self.module_exporter( + a_input, submodule, str(result_key[0]) + "_cpu", # type: ignore[index] + ) + self.module_exporter( + b_input, submodule, str(result_key[0]) + "_acc", # type: ignore[index] + ) + raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") + + def _binary_search_impl( + self, all_nodes: NodeList, start_idx: int, end_idx: int + ) -> NodeSet: + """ + Recursive binary search implementation. + """ + culprits: NodeSet = set() + nodes: NodeList = all_nodes[start_idx:end_idx] + + report: List[str] = [] + if self.exclusion_fn is not None: + self.exclusion_fn(nodes, start_idx, end_idx) + if len(nodes) == 0: + report = ["All nodes are excluded by user"] + self.reports.append(report) + return culprits + + first_node_name = nodes[0].name + output_node_name = nodes[-1].name + self.iteration += 1 + self.reports.append(report) + report.append(f"Binary search iteration {self.iteration}") + report.append( + f"From node index {start_idx}:{first_node_name} to {end_idx-1}:{output_node_name}. " + f"Size of the interested node list is {len(nodes)}" + ) + cur_nodes: NodeSet = set(nodes) + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [output_node_name]) + + except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError): + + if len(nodes) == 1: + report.append( + f"This is the last node in the sub-module. " + f"Search in the current branch is successful with culprit = {cur_nodes}." + ) + self.print_report(report) + return cur_nodes + + report.append( + "Proceed to split and lower the halves of the current " + "sub-module individually." + ) + self.print_report(report) + + mid = len(nodes) // 2 + culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid) + + if len(culprits) != 0 and not self.settings.find_all: + return culprits + + culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx) + + if len(culprits) == 0: + report.append( + f"Further split and lowering found no errors. " + f"Unable to minimize the submodule with list of nodes: {nodes}" + ) + self.print_report(report) + + return culprits + else: + report.append("No discrepancy found.") + self.print_report(report) + return set() + + def _binary_traverse(self, nodes: NodeList) -> NodeSet: + """ + Binary search on `nodes` for culprit. + """ + return self._binary_search_impl(nodes, 0, len(nodes)) + + def _sequential_traverse(self, nodes: NodeList) -> NodeSet: + """ + Traverse `nodes` one by one and determine if any of them is a culprit. + """ + culprits: NodeSet = set() + + for node in nodes: + report: List[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f"Sequential traverse iteration {self.iteration}.") + report.append(f"Visit node: {node.name}") + + _LOGGER.info("Visit node: %s", node.name) + node_list: NodeList = [node] + if self.exclusion_fn is not None: + self.exclusion_fn(node_list, -1, -1) + if len(node_list) == 0: + report.append(f"User exclusion : {node.name}") + self.print_report(report) + if not self.settings.find_all: + return culprits + else: + continue + + cur_nodes: NodeSet = {node} + + if node in self.fusions: + cur_nodes = self.fusions[node] + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [node.name]) + self.print_report(report) + except (FxNetMinimizerResultMismatchError): + culprits.add(node) + report.append(f"Found culprit from numeric error: {node}") + self.print_report(report) + if not self.settings.find_all: + return culprits + except (FxNetMinimizerRunFuncError): + culprits.update(cur_nodes) + report.append(f"Found culprit from run error: {node}") + self.print_report(report) + if not self.settings.find_all: + return culprits + + return culprits + + + def _block_traverse_impl(self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool) -> int: + """ + Recursive block search implementation. + find_last_node: If True, search for the last node which result in numerics difference + if False: find first node in sorted node list + """ + report: List[str] = [] + + mid = (start_idx + end_idx) // 2 + cur_nodes_list: NodeList = nodes[:mid + 1] if find_last_node else nodes[mid:] + + if self.exclusion_fn: + self.exclusion_fn(cur_nodes_list, -1, -1) + + cur_nodes = set(cur_nodes_list) + + first_node_name = cur_nodes_list[0].name + last_node_name = cur_nodes_list[-1].name + target_node_name = last_node_name if find_last_node else first_node_name + + self.iteration += 1 + self.reports.append(report) + report.extend( + [ + "=" * 30, + f"Block search iteration {self.iteration}", + ] + ) + report.extend( + [ + f"Search for {'last' if find_last_node else 'first'} node in culprits", + f"From node index {start_idx}:{nodes[start_idx].name} to {end_idx}:{nodes[end_idx].name}. ", + f"Subgraph constructed by {first_node_name} to {last_node_name}", + f"Targeting node: {target_node_name}", + f"Size of the interested node list is {end_idx - start_idx + 1}", + ] + ) + report_idx = len(self.reports) - 1 + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [last_node_name], report_idx) + except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): + report.append(f"Culprits found from node {first_node_name} to {last_node_name}.") + + if start_idx == mid: + report.extend( + [ + "This is the last node in the sub-module. ", + "Search in the current branch is successful with node :", + f"{start_idx}, node name: {nodes[start_idx].name}." + ] + ) + self.print_report(report) + return start_idx + + report.append( + "Proceed to split and lower the halves of the current " + "sub-module individually." + ) + self.print_report(report) + + if find_last_node: + return self._block_traverse_impl(nodes, start_idx, mid, find_last_node) + else: + return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node) + else: + report.append(f"Culprits not found from node start to {mid}:{nodes[mid].name}.") + + if start_idx == mid: + report.extend( + [ + "This is the last node in the sub-module. ", + "Search in the current branch is successful with node", + f"{start_idx}, node name: {nodes[start_idx].name}.", + ] + ) + self.print_report(report) + return start_idx + 1 if find_last_node else start_idx - 1 + + report.append( + "Proceed to split and lower the halves of the current " + "sub-module individually." + ) + self.print_report(report) + + if find_last_node: + return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node) + else: + return self._block_traverse_impl(nodes, start_idx, mid, find_last_node) + + + def _block_traverse(self, nodes: NodeList, find_last_node: Optional[bool]) -> NodeSet: + """ + Traverse topologically sorted node list + Find minimium block (start_idx, end_idx) which contains the culprit + 1st pass: search for end_idx by finding the last node in culprit block + where Numerical accuracy (0, end_idx) > threshold + 2nd pass: search for start_idx by finding the first node in culprit block + where Numerical accuracy (start_idx, end_idx) < threshold + Form minimum block by (start_idx - 1, end_idx) + """ + culprits: NodeSet = set() + first_node_name = nodes[0].name + last_node_name = nodes[-1].name + last_node_report = [f"Block search from {first_node_name} to {last_node_name}"] + last_node_report.append("*" * 50) + self.reports.append(last_node_report) + + start_idx = 0 + end_idx = len(nodes) - 1 + run_both = True if find_last_node is None else False + + # step 1: find (0, end_idx) of culprit block + if run_both or find_last_node: + last_node_report.append("Start searching for last node in culprit") + self.print_report(last_node_report) + end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True) + last_node_report.extend( + [ + "Finish Pass 1", + f"Find end_idx = {end_idx}:{nodes[end_idx].name}" + ] + ) + self.print_report(last_node_report) + + # step 2: reduce culprit block to (start_idx, end_idx) + if run_both or not find_last_node: + first_node_report = ["Start searching for first node in culprit"] + self.print_report(first_node_report) + start_idx = self._block_traverse_impl(nodes[0:end_idx + 1], start_idx, end_idx, False) + first_node_report.append("*" * 50) + self.reports.append(first_node_report) + first_node_report.extend( + [ + "Finish Pass 2", + f"Find start_idx = {start_idx}:{nodes[start_idx].name}" + ] + ) + self.print_report(first_node_report) + + # step 3: form module with minimum culprits + culprits.update(nodes[start_idx:end_idx + 1]) + result_report = [f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})"] + self.reports.append(result_report) + self.print_report(result_report) + return culprits + + + def _defined_traverse(self, nodes: NodeList) -> NodeSet: + """ + run user defined `nodes` and determine if it is a culprit. + """ + culprits: NodeSet = set() + if self.exclusion_fn is not None: + self.exclusion_fn(nodes, -1, -1) + if len(nodes) == 0: + report = ["All nodes are excluded by user"] + self.reports.append(report) + return culprits + + first_node_name = nodes[0].name + output_node_name = nodes[-1].name + report = [f"Defined graph from {first_node_name} to {output_node_name}"] + cur_nodes: NodeSet = set(nodes) + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [output_node_name]) + self.print_report(report) + except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): + report.append(f"Found culprit {cur_nodes}") + self.print_report(report) + return culprits + + return culprits + + def _accumulate_traverse(self, nodes: NodeList) -> NodeSet: + culprits: NodeSet = set() + nodes_to_run: NodeSet = set() + + # find_all is not supported for accumulate traversal because all the + # ops run on NNPI. So we return after the first op that raises error. + if self.settings.find_all: + print("'Find All' mode is not supported in accumulate traversal.") + return culprits + + for node in nodes: + report: List[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f"Accumulate traverse iteration {self.iteration}.") + + nodes_to_run.add(node) + + node_name = node.name + if node_name is not None and isinstance(node_name, tuple): + node_name = node_name[0] + assert node_name is not None and isinstance( + node_name, str + ), f"minimize: node_name: {node_name}" + + report.append(f"Add node: {node_name}") + + try: + split_module, submod_name = self._build_submodule(nodes_to_run) + self._run_and_compare(split_module, submod_name, [node_name]) + self.print_report(report) + except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): + culprits.add(node) + report.append(f"Found culprit {node}") + self.print_report(report) + return culprits + + return culprits + + def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet: + """ + Skip certain nodes in graph based on settings + """ + culprits: NodeSet = set() + nodes: NodeList = all_nodes[start_idx:end_idx] + cur_nodes: NodeSet = set(nodes) + if self.exclusion_fn is not None: + self.exclusion_fn(nodes, start_idx, end_idx) + cur_nodes = set(nodes) + else: + for node in nodes: + if node in self.fusions: + cur_nodes.update(self.fusions[node]) + report: List[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f" Nodes block {self.iteration}.") + report.append( + f"From node index {start_idx} to {end_idx-1}. " + f"Size of the interested node list is {len(nodes)}" + ) + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, []) + except (FxNetMinimizerResultMismatchError): + culprits.update(cur_nodes) + report.append(f"Found culprit from numeric error: {cur_nodes}") + self.print_report(report) + return culprits + except (FxNetMinimizerRunFuncError): + culprits.update(cur_nodes) + report.append(f"Found culprit from run error: {cur_nodes}") + self.print_report(report) + return culprits + else: + report.append("No discrepancy found.") + self.print_report(report) + return set() + + + def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet: + """ + Skip certain nodes in graph based on settings + """ + start_idx = 0 + num_nodes = len(all_nodes) + idx = 0 + culprits = set() + while idx < num_nodes: + node = all_nodes[idx] + if (node.name in skip_nodes): # skip the node + if idx > start_idx: + culprits = self._skip_traverse_impl(all_nodes, start_idx, idx) + start_idx = idx + 1 + elif idx == num_nodes - 1 and start_idx <= idx: # last node + culprits = self._skip_traverse_impl(all_nodes, start_idx, idx + 1) + idx += 1 + + return culprits + + + + def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList: + """ + Collect nodes in the model that between nodes with name of `start` and `end`. + These two nodes are also included. + """ + nodes: NodeList = [] + add_node = start is None + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + if node.name == start: + add_node = True + + if add_node: + nodes.append(node) + + if node.name == end: + break + + return nodes + + def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None): + """ + Run part of the model from `start` node to `end` node. If `start` is None + then we start from the beginning of the model. If `end` is None then we + stop at the end of the model. + + Args: + start: The name of the node which is the first node of the submodule + we want to run. If set to None, then we'll start with the first + node of the model. + end: The name of the node which is the last node of the submodule we + want to run. If set to None, we'll end with the last node of the + model. + """ + nodes = self._collect_nodes(start, end) + cur_nodes = set(nodes) + + for node in nodes: + if node in self.fusions: + cur_nodes.update(self.fusions[node]) + + output_names = [] + if self.settings.return_intermediate: + output_names = [node.name for node in nodes] + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, output_names) + except ( + FxNetMinimizerRunFuncError, + FxNetMinimizerResultMismatchError, + ) as e: + print(e) + + def print_report(self, report: List[str]): + for i in range(len(report)): + if i > 0: + print(" . " + report[i]) + else: + print(report[i]) + + def print_reports(self): + for report in self.reports: + self.print_report(report) + + def minimize( + self, + start: Optional[str] = None, + end: Optional[str] = None, + skip_nodes: Optional[List] = None, + find_last_node: Optional[bool] = None, + ) -> NodeSet: + """ + Minimizing the model from node with name `start` to node with name `end` base + on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or + FxNetMinimizerResultMismatchError errors. + + Args: + start: The name of the node where we want to start minimizing. If set + to None, then we'll start with the first node of the model. + end: The name of the node where we want to terminate minimizing. If + set to None, we'll end with the last node of the model. + skip_nodes: The names of nodes where we want to skip during minimizing. + It'll create subgraphs without these skip nodes under the hood. + Only applicable in mode "skip". + find_last_node: True if only last_node of a culprits is needed in mode "block". + False if only the first_node of a culprits is needed. + Only applicable in mode "block". + + Returns: + nodes: A list of nodes that causes FxNetMinimizerRunFuncError or + FxNetMinimizerResultMismatchError errors during minimizing. + """ + + print(self.settings) + print(self.module.graph) + + nodes = self._collect_nodes(start, end) + + if self.settings.traverse_method == "sequential": + return self._sequential_traverse(nodes) + + if self.settings.traverse_method == "binary": + return self._binary_traverse(nodes) + + if self.settings.traverse_method == "accumulate": + return self._accumulate_traverse(nodes) + + if self.settings.traverse_method == "skip": + if (skip_nodes is None): + raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.") + return self._skip_traverse(nodes, skip_nodes) + + if self.settings.traverse_method == "defined": + return self._defined_traverse(nodes) + + if self.settings.traverse_method == "block": + return self._block_traverse(nodes, find_last_node) + + raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!") diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/operator_support.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/operator_support.py new file mode 100644 index 0000000000000000000000000000000000000000..57edabc0a55ae6ddb0738ea83edb49b2d5823930 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/operator_support.py @@ -0,0 +1,215 @@ +# mypy: allow-untyped-defs +import abc +import typing as t + +import torch +import torch.fx +from torch.fx._compatibility import compatibility +from .shape_prop import TensorMetadata +from .tools_common import get_node_target, CALLABLE_NODE_OPS + + +__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain'] + +# fx.Node.target typename, as returned by `get_node_target()` +TargetTypeName = str + +# Arguments' dtypes for a given node, see `OperatorSupport` +SupportedArgumentDTypes = t.Optional[ + t.Tuple[ + t.Sequence[t.Sequence[torch.dtype]], + t.Dict[str, t.Sequence[torch.dtype]], + ] +] + +SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes] + + +@compatibility(is_backward_compatible=False) +class OperatorSupportBase(abc.ABC): + """Interface for determining if a fx.Node is supported by a backend""" + @abc.abstractmethod + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + raise NotImplementedError + + +@compatibility(is_backward_compatible=False) +class OperatorSupport(OperatorSupportBase): + """ + `_support_dict` maps node.target typename to supported inputs dtypes. + + node.target typename is retrieved using helper function `get_node_target()` + + If supported inputs dtypes is None, it means any dtype is supported, else + we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}). + + The first tuple ([dtypes], ...) indicates what dtypes are supported for + inputs in node.args and the second dict {"name": [dtypes], ...} indicates + what dtypes are supported for inputs in node.kwargs. + + For inputs in args, if we don't want to check it, we can put None there, + e.g. (None, [torch.float]) indicates that we don't care about the type of + the first input in args. And for inputs in kwargs, if not listed, will not + be checked. + """ + + _support_dict: SupportDict + + def __init__( + self, + support_dict: t.Optional[SupportDict] = None + ): + self._support_dict = support_dict or {} + + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + """ + Args: + `submodules`: mapping from module name to the module. This can be + retrieved by calling model.named_modules(). + + `node`: a Fx node that we want to determine whether it's supported. + + Returns: + `is_supported`: whether the arg `node` is supported. + """ + if node.op not in CALLABLE_NODE_OPS: + return True + + target = get_node_target(submodules, node) + + # Target not found in _support_dict meaning that we don't support this op at all + if target not in self._support_dict: + return False + + # The rule for target is None meaning that we accept any dtype + if self._support_dict[target] is None: + return True + + args_dtypes, kwargs_dtypes = self._support_dict[target] # type: ignore[misc] + + # Check args dtypes + for i, dtypes in enumerate(args_dtypes): + if len(node.args) <= i: + break + + # None indicates we don't care about the dtype of args[i] + if dtypes is None: + continue + + # If arg is not a node then we don't check it + if not isinstance(node.args[i], torch.fx.Node): + continue + + arg_dtype = _get_arg_dtype(node.args[i]) # type: ignore[arg-type] + if arg_dtype not in dtypes: + return False + + # Check kwargs dtypes + for k, dtypes in kwargs_dtypes.items(): + if k not in node.kwargs: + continue + + # If arg is not a node then we don't check it + if not isinstance(node.kwargs[k], torch.fx.Node): + continue + + kwarg_dtype = _get_arg_dtype(node.kwargs[k]) # type: ignore[arg-type] + if kwarg_dtype not in dtypes: + return False + + return True + + +# ====================================================================== +# Functional interfaces and utils for defining basic operator support logic +# and composing them into more complex ones +# ====================================================================== + +IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool] + + +@compatibility(is_backward_compatible=False) +def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase: + """Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance + + `IsNodeSupported` has the same call signature as + `OperatorSupportBase.is_node_supported` + """ + class FunctionalOperatorSupport(OperatorSupportBase): + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + return is_node_supported(submodules, node) + return FunctionalOperatorSupport() + + +@compatibility(is_backward_compatible=False) +def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: + """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase` + instance by evaluating each input `OperatorSupportBase` instance, and returns False if + any of it reports False. + """ + def _chain(submods, node) -> bool: + return all( + x.is_node_supported(submods, node) + for x in op_support + ) + return create_op_support(_chain) + + +@compatibility(is_backward_compatible=False) +def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: + """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase` + instance by evaluating each input `OperatorSupportBase` instance, and returns True if + any of it reports True. + """ + def _any_chain(submods, node) -> bool: + return any( + x.is_node_supported(submods, node) + for x in op_support + ) + return create_op_support(_any_chain) + + +@compatibility(is_backward_compatible=False) +class OpSupports: + """A set of atomic `OperatorSupportBase` instances that can be combined together + to form more complex operator support logic. + """ + @classmethod + def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase: + """Report a node as non-supported, if any of its arguments is of dtype""" + + def _decline_if_input_dtype( + submodules: t.Mapping[str, torch.nn.Module], + node: torch.fx.Node, + ) -> bool: + for arg in node.all_input_nodes: + arg_dtype = _get_arg_dtype(arg) + if arg_dtype == dtype: + return False + return True + return create_op_support(_decline_if_input_dtype) + + @classmethod + def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase: + """ + If a node has a name that is in the disallow set, reported it as non-supported. + """ + def _decline_if_node_in_names( + submodules: t.Mapping[str, torch.nn.Module], + node: torch.fx.Node, + ) -> bool: + return node.name not in disallow_set + return create_op_support(_decline_if_node_in_names) + + +def _get_arg_dtype(arg: torch.fx.Node) -> t.Any: + assert isinstance(arg, torch.fx.Node) + tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr] + dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"] + return dtype diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/param_fetch.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/param_fetch.py new file mode 100644 index 0000000000000000000000000000000000000000..5979e29fcc6b2650a1f73be4845e2ad3dcda0920 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/param_fetch.py @@ -0,0 +1,66 @@ +from torch.fx.graph_module import GraphModule +from typing import Any, Callable, Dict, List, Tuple, Type +import torch +import torch.nn as nn + +from torch.fx._compatibility import compatibility + +__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes'] + +# Matching method matches the attribute name of current version to the attribute name of `target_version` +@compatibility(is_backward_compatible=False) +def default_matching(name: str, target_version: int) -> str: + """Default matching method + """ + return name + +# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering. +# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list. +# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module. +module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = { + torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching), + torch.nn.modules.conv.Conv2d: ( + 1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching + ), + torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching), + torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching), + torch.nn.modules.pooling.MaxPool2d: ( + 1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching + ), + torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching), +} + +@compatibility(is_backward_compatible=False) +def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]: + """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book` + after checking module's version is compatible with the `module_fetch_book`. + """ + attrs_for_lowering: Dict[str, Any] = {} + attrs_for_lowering["name"] = torch.typename(mod) + + if type(mod) in module_fetch_book: + version, param_to_fetch, matching_method = module_fetch_book[type(mod)] + if version < mod._version: + raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " + "please upgrade the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly.") + for attr in param_to_fetch: + attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version)) + else: + raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, " + "please add it to the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly.") + return attrs_for_lowering + +@compatibility(is_backward_compatible=False) +def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: + """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module. + """ + submodules = dict(fx_module.named_modules()) + + for node in fx_module.graph.nodes: + if node.op == "call_module": + if isinstance(submodules[node.target], GraphModule): + lift_lowering_attrs_to_nodes(submodules[node.target]) + else: + node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target]) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/pass_manager.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc4ff5e07090588ec92edd49111012f4a69ba6b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/pass_manager.py @@ -0,0 +1,254 @@ +# mypy: allow-untyped-defs +from functools import wraps +from inspect import unwrap +from typing import Callable, List, Optional +import logging + +logger = logging.getLogger(__name__) + +__all__ = [ + "PassManager", + "inplace_wrapper", + "log_hook", + "loop_pass", + "this_before_that_pass_constraint", + "these_before_those_pass_constraint", +] + +# for callables which modify object inplace and return something other than +# the object on which they act +def inplace_wrapper(fn: Callable) -> Callable: + """ + Convenience wrapper for passes which modify an object inplace. This + wrapper makes them return the modified object instead. + + Args: + fn (Callable[Object, Any]) + + Returns: + wrapped_fn (Callable[Object, Object]) + """ + + @wraps(fn) + def wrapped_fn(gm): + val = fn(gm) + return gm + + return wrapped_fn + +def log_hook(fn: Callable, level=logging.INFO) -> Callable: + """ + Logs callable output. + + This is useful for logging output of passes. Note inplace_wrapper replaces + the pass output with the modified object. If we want to log the original + output, apply this wrapper before inplace_wrapper. + + + ``` + def my_pass(d: Dict) -> bool: + changed = False + if 'foo' in d: + d['foo'] = 'bar' + changed = True + return changed + + pm = PassManager( + passes=[ + inplace_wrapper(log_hook(my_pass)) + ] + ) + ``` + + Args: + fn (Callable[Type1, Type2]) + level: logging level (e.g. logging.INFO) + + Returns: + wrapped_fn (Callable[Type1, Type2]) + """ + @wraps(fn) + def wrapped_fn(gm): + val = fn(gm) + logger.log(level, "Ran pass %s\t Return value: %s", fn, val) + return val + + return wrapped_fn + + + +def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None): + """ + Convenience wrapper for passes which need to be applied multiple times. + + Exactly one of `n_iter`or `predicate` must be specified. + + Args: + base_pass (Callable[Object, Object]): pass to be applied in loop + n_iter (int, optional): number of times to loop pass + predicate (Callable[Object, bool], optional): + + """ + assert (n_iter is not None) ^ ( + predicate is not None + ), "Exactly one of `n_iter`or `predicate` must be specified." + + @wraps(base_pass) + def new_pass(source): + output = source + if n_iter is not None and n_iter > 0: + for _ in range(n_iter): + output = base_pass(output) + elif predicate is not None: + while predicate(output): + output = base_pass(output) + else: + raise RuntimeError( + f"loop_pass must be given positive int n_iter (given " + f"{n_iter}) xor predicate (given {predicate})" + ) + return output + + return new_pass + + +# Pass Schedule Constraints: +# +# Implemented as 'depends on' operators. A constraint is satisfied iff a list +# has a valid partial ordering according to this comparison operator. +def _validate_pass_schedule_constraint( + constraint: Callable[[Callable, Callable], bool], passes: List[Callable] +): + for i, a in enumerate(passes): + for j, b in enumerate(passes[i + 1 :]): + if constraint(a, b): + continue + raise RuntimeError( + f"pass schedule constraint violated. Expected {a} before {b}" + f" but found {a} at index {i} and {b} at index{j} in pass" + f" list." + ) + + +def this_before_that_pass_constraint(this: Callable, that: Callable): + """ + Defines a partial order ('depends on' function) where `this` must occur + before `that`. + """ + + def depends_on(a: Callable, b: Callable): + return a != that or b != this + + return depends_on + + +def these_before_those_pass_constraint(these: Callable, those: Callable): + """ + Defines a partial order ('depends on' function) where `these` must occur + before `those`. Where the inputs are 'unwrapped' before comparison. + + For example, the following pass list and constraint list would be invalid. + ``` + passes = [ + loop_pass(pass_b, 3), + loop_pass(pass_a, 5), + ] + + constraints = [ + these_before_those_pass_constraint(pass_a, pass_b) + ] + ``` + + Args: + these (Callable): pass which should occur first + those (Callable): pass which should occur later + + Returns: + depends_on (Callable[[Object, Object], bool] + """ + + def depends_on(a: Callable, b: Callable): + return unwrap(a) != those or unwrap(b) != these + + return depends_on + + +class PassManager: + """ + Construct a PassManager. + + Collects passes and constraints. This defines the pass schedule, manages + pass constraints and pass execution. + + Args: + passes (Optional[List[Callable]]): list of passes. A pass is a + callable which modifies an object and returns modified object + constraint (Optional[List[Callable]]): list of constraints. A + constraint is a callable which takes two passes (A, B) and returns + True if A depends on B and False otherwise. See implementation of + `this_before_that_pass_constraint` for example. + """ + + passes: List[Callable] + constraints: List[Callable] + _validated: bool = False + + def __init__( + self, + passes=None, + constraints=None, + ): + self.passes = passes or [] + self.constraints = constraints or [] + + @classmethod + def build_from_passlist(cls, passes): + pm = PassManager(passes) + # TODO(alexbeloi): add constraint management/validation + return pm + + def add_pass(self, _pass: Callable): + self.passes.append(_pass) + self._validated = False + + def add_constraint(self, constraint): + self.constraints.append(constraint) + self._validated = False + + def remove_pass(self, _passes: List[str]): + if _passes is None: + return + passes_left = [] + for ps in self.passes: + if ps.__name__ not in _passes: + passes_left.append(ps) + self.passes = passes_left + self._validated = False + + def replace_pass(self, _target, _replacement): + passes_left = [] + for ps in self.passes: + if ps.__name__ == _target.__name__: + passes_left.append(_replacement) + else: + passes_left.append(ps) + self.passes = passes_left + self._validated = False + + def validate(self): + """ + Validates that current pass schedule defined by `self.passes` is valid + according to all constraints in `self.constraints` + """ + if self._validated: + return + for constraint in self.constraints: + _validate_pass_schedule_constraint(constraint, self.passes) + self._validated = True + + def __call__(self, source): + self.validate() + out = source + for _pass in self.passes: + out = _pass(out) + return out diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/reinplace.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/reinplace.py new file mode 100644 index 0000000000000000000000000000000000000000..76435b9d318af1b5d62d6dce3b9e8c83fd615516 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/reinplace.py @@ -0,0 +1,675 @@ +# mypy: allow-untyped-defs +import torch +from torch.fx import Node +from torch.fx._compatibility import compatibility +from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor +from torch.utils._pytree import tree_map_only +from torch.utils import _pytree as pytree +from torch.multiprocessing.reductions import StorageWeakRef + +import _operator +from enum import Enum +import itertools +from typing import Set, Dict +from collections import defaultdict + +__all__ = ['reinplace'] + +class _ViewType(Enum): + NonView = 0 + SingleOutputView = 1 + MultiOutputView = 2 + +def _is_view_op(tgt): + if tgt is not None and isinstance(tgt, torch._ops.OpOverload): + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + return first_arg.alias_info is not None and not first_arg.alias_info.is_write + +def _get_view_type(tgt) -> _ViewType: + if tgt is not None and isinstance(tgt, torch._ops.OpOverload): + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + if first_arg.alias_info is not None and not first_arg.alias_info.is_write: + # check if op is a multi-output view + if '*' in first_arg.alias_info.after_set: + return _ViewType.MultiOutputView + else: + return _ViewType.SingleOutputView + return _ViewType.NonView + + +# Stores a bunch of metadata related to functionalization each node. +# Relevant metadata: +# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors) +# The fake tensor output from running the current node +# n.meta['view_of']: Node +# If the current node n is a view of some base tensor, the 'view_of' field tells us which +# view node was used to generate the current node (a view tensor). +# This information actually makes `fake_result` redundant, but we can use `fake_result` +# to sanity check that our aliasing information is correct. +@compatibility(is_backward_compatible=False) +class _FunctionalizationMetadataProp(torch.fx.Interpreter): + + def run_node(self, node: Node): + self.node_counter += 1 + result = super().run_node(node) + node.meta['fake_result'] = result + node.meta['node_idx'] = self.node_counter + + # (1) Update metadata with the list of nodes that are used by this node + # copy_() doesn't read from its first argument; it writes to it, overwriting previous data. + # We don't want to treat it as "being used as an input". + node_args = node.args + if node.target is torch.ops.aten.copy_.default: + node_args = node_args[1:] + + # (2) Update metadata to track aliasing information about view tensor nodes. + if node.op == 'call_function': + view_type = _get_view_type(node.target) + if view_type == _ViewType.SingleOutputView: + assert isinstance(node.args[0], Node) + node.meta['view_of'] = node.args[0] + elif view_type == _ViewType.MultiOutputView: + self.multi_output_view_nodes[node] = node.args[0] + + # Check if we returned a multi-output view, + # and we're now grabbing the individual views from the output. + # + # For multi-output views, we want to map each output view to the base, + # but this mapping involves two separate nodes in FX IR. + # e.g. "a, b = x_1.split(...)" becomes: + # %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {}) + # %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {}) + # %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {}) + # And we'd like to set: + # getitem1.meta['view_of'] = x_1 + elif node.target is _operator.getitem: + list_arg = node.args[0] + maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None) + if maybe_base_of_view is not None: + # Note: we could also track indexing info here for multi-output views. + # I don't think this metadata is strictly needed for de-functionalization. + assert isinstance(maybe_base_of_view, Node) + node.meta['view_of'] = maybe_base_of_view + + if 'view_of' in node.meta: + # We're linking the current node with its first argument as views. + # Assert here that this is actually the case, and their storages are the same. + assert isinstance(node.meta['fake_result'], FakeTensor) + assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor) + view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) + base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage()) + assert view_storage == base_storage + return result + + + + def propagate(self, *args): + self.multi_output_view_nodes = {} + self.node_counter = -1 + + with FakeTensorMode() as mode: + fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args] + return super().run(*fake_args) + +def _schemas_match(functional_schema, inplace_schema): + names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name + arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all( + a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)) + # for the inplace op, its first argument should be mutable + assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write + # and its remaining arguments shouldn't be. + assert all(a.alias_info is None for a in inplace_schema.arguments[1:]) + return names_match and arg_types_match + +# TODO: this should be beefed up to be able to properly re-inplace with: +# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper) +# - out= ops (e.g. angle -> angle.out) +# TODO: we should also figure this info out using torchgen. +def _maybe_get_inplace_op(op): + # __module__ seems broken; it returns torch._ops.aten which doesn't exist + if not isinstance(op, torch._ops.OpOverload): + return None + # Some view ops have inplace variants (as_strided_, etc), + # but we do NOT want the reinplacing pass to directly add these into the program. + # (they'll require extra special handling, aren't aren't really useful for perf anyway) + if _is_view_op(op): + return None + op_namespace = op.__module__.split(".")[-1] + op_base_name = op.overloadpacket.__name__ + maybe_namespace_module = getattr(torch.ops, op_namespace) + maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None) + if maybe_inplace_op is None: + return None + + inplace_overloads = [ + getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads() + ] + inplace_overloads_with_matching_schemas = [ + f + for f in inplace_overloads + if _schemas_match(op._schema, f._schema) + ] + # Just because foo() and foo_() are both existing operators, + # They aren't guaranteed to have compatible schemas. + # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant, + # Even though several overloads of pow_ exist. + if len(inplace_overloads_with_matching_schemas) == 0: + return None + assert len(inplace_overloads_with_matching_schemas) == 1 + inplace_op = inplace_overloads_with_matching_schemas[0] + return inplace_op + +_VIEW_INVERSE_MAP = { + torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, + torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, + torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, + torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, +} + +# This function, given a set of set of (aliased) tensor nodes, +# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index +# in the node ordering. +def _get_all_later_node_usages(tensor_aliases: Set[Node], op_index: int): + def _add_if_tensor(x, set_): + if isinstance(x, FakeTensor): + set_.add(StorageWeakRef(x._typed_storage())) + + nodes_used_after = set() + for t in tensor_aliases: + # get all nodes that use the current alias + usage_nodes = t.users + for n in usage_nodes: + # We only care about usages after the current node + if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index: + continue + # We also don't care about intermediate view ops. + # They only matter if their output is then used elsewhere + # (either in an out-of-place op, or as an output to the function). + if n in tensor_aliases: + if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem: + continue + nodes_used_after.add(n) + return nodes_used_after + +# Given an op that we're trying to re-inplace, "b = foo(a)", +# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)" +# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF: +# If there are any aliases in the alias_set(a) that satisfy: +# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base" +# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata +# as "alias" +def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]: + def matching_view_metadata(a, b): + return a.size() == b.size() and \ + a.stride() == b.stride() and \ + a.storage_offset() == b.storage_offset() + + view_inverse_nodes = set() + # Go through them in node order, so we can see chains of view_scatter ops. + for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']): + if n.target not in _VIEW_INVERSE_MAP: + continue + base = n.args[0] + mutated_view = n.args[1] + assert isinstance(base, Node) + assert isinstance(base.meta['fake_result'], FakeTensor) + assert isinstance(mutated_view, Node) + assert isinstance(mutated_view.meta['fake_result'], FakeTensor) + # Check that this view_inverse op actually corresponds to taking doing the inverse + # of one of our existing self_alias nodes. + original_view = _VIEW_INVERSE_MAP[n.target] + for self_alias in self_aliases: + # We're looking for some alias of the self arg, "alias", + # that was created from some op `alias = foo(base, args...)` + # such that the current _scatter op "inverts" that foo call. + # We can check that by running the original op again, and checking that the strides match. + if 'view_of' not in self_alias.meta: + continue + self_alias_base = self_alias.meta['view_of'] + try: + # The we're trying to re-use the args from the view_scatter call inside of the corresponding + # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse + # of the current alias we're looking at. + view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs) + expected_metadata = self_alias.meta['fake_result'] + # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace. + if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \ + matching_view_metadata(view_replay_metadata, expected_metadata): + view_inverse_nodes.add(n) + except Exception: + continue + + return view_inverse_nodes + + +@compatibility(is_backward_compatible=True) +def reinplace(gm, *sample_args): + """ + Given an fx.GraphModule, modifies it to perform "reinplacing", + mutating the nodes of the graph. + We look for out-of-place op call sites like `b = a.add(...)`, + and convert them to be inplace (`b = a.add_(...)`), + as long as the input to the current operator ("a") isn't re-used + anywhere later in the graph. + + This pass currently expects to operate on a **functional, ATen** graph. + This can be obtained by running `make_fx(functionalize(f))`. + + Sample inputs are needed to determine aliasing relationships of the inputs. + In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the + inputs to the program. + + Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows: + + (1) Perform some initial checks on the metadata of "a" and "args..." + that can disqualify them from being reinplaced. + + (1a) Check that the self argument we're attempting to reinplace + has acceptable dtype/size metadata to reinplace with. + + For example, if we have: + a = torch.ones(1) + b = torch.ones(10) + out = torch.add(a, b) + We can't turn that into + a.add_(b) + Because that would require resizing "a". + + Similarly, we can't convert torch.ge(a, b) into a.ge_(b), + because that would require changing a's dtype (from e.g. float32 to bool). + Note that in this specific example, we could technically do better.. + + If we see the pattern: + a_1 = a.ge(b) + a_2 = aten._to_copy(a_1, a.dtype) + Then we this should be valid to completely re-inplace + (this is exactly what functionalization will emit when it sees a.ge_(b)). + + This optimization is only really important for user programs + that directly use inplace comparison ops though. + + We also cannot re-inplace on tensors that have overlapping memory, + e.g. torch.ones(1).expand(4, 4).add_(1) + + (1b) Check if "a" is an alias of any of the program inputs. + + If it is, skip and move to the next node. + Inplace'ing an op that would cause it to mutate a program is not sound, + because that would be a side effect visible to the user. + + NOTE: there's a future optimization that we should make: + if "a" is a (alias of a) program input, but later in the program + there is a node that looks like "a.copy_(...)", + Then re-inplacing is ok to do - we are temporarily re-using a's buffer, + which will later be overwritten by the copy_() call. + + This will be an important optimization to have for programs that mutate + their inputs. It currently isn't implemented though. + + (1c) Check if "a" and "args..." alias + + For example, re-inplacing to create code like the below + isn't guaranteed to be sound: + + aten.mul_(a, a) + + (2) Check that "a" and all of its outstanding aliases are not used anywhere + later in the graph. If this is the case, then it's safe to re-inplace + to "b = foo_(a)". + + There are a few caveats to this, explained in more detail below: + (a) If "a" is used later as an argument to a view op, that is okay. + It's only a problem if "a" (or that view) is later passed + into a normal operator, or if it is returned as the program output. + (b) If "a" is a repeat argument in `foo()`, then don't reinplace. + Most ATen kernels don't make any guarantees that this is sound, + e.g. if you do aten.mul_(a, a). + So we'll just ban re-inplacing in this case. + It's only a problem if "a" (or that view) is later passed + (c) If "a" is used as an input into a view "inverse" / "scatter" + operator, it is potentially fine to re-inplace + (and remove that scatter operator from the graph). + See below for a more detailed example. + + NOTE: there is an optimization in this step that is crucial + to fully recovering performance from functionalization. + + Given this program: + def f(x): + a = torch.ops.aten.add(x, x) + b = torch.ops.aten.diagonal(a) + torch.ops.aten.fill_(b, 0) + return d + + Functionalization will emit the following: + def f(x): + a = torch.ops.aten.add(x, x) + b = torch.ops.aten.diagonal(a, 0, 1) + b_updated = torch.ops.aten.fill(b, 0) + a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1) + return a_updated + + Ordinarily, we would not be able to reinplace the fill, + because "b" aliases with "a" which is used by the diagonal_scatter call. + + "re-inplacing" is on the hook for figuring out that it is ok to + completely, the expensive diagonal_scatter call, if we re-inplace the add(). + + So, for every `alias in alias_set(a)`, instead of checking + that "alias" is not used anywhere later in the graph, + we check that + EITHER: + (a) alias is not used anywhere later in the graph + OR: + (b) alias is used exactly once later on in the graph, + in the following op: + + out = foo_scatter(alias, x, args...) + + where the following must hold: + (i) "foo_scatter" is the "inverse" operator for foo. + This only applies to "foo" ops that are view operators, + which view into a subset of the original tensor's memory. + In practice, there are ~4 operators where this applies: + diagonal -> diagonal_scatter + slice -> slice_scatter + select -> select_scatter + as_strided -> as_strided_scatter + (ii) "args..." are the same between the foo() and foo_scatter() calls. + + (3) Perform the actual re-inplacing on foo! + + (3b) is the common case, but special care is needed for {view}_scatter (3a) + + (3a) {view}_scatter ops. + + Consider this program: + a = torch.zeros(2, 2) + b = torch.ones(2) + a[0] = b + + Post functionalization, that will look like: + a = torch.zeros(2) + b = torch.ones(1) + a_updated = torch.select_scatter(a, b, 0, 0) + + In this case though, there is no "functional" op to re-inplace! + Instead, we'd like to directly remove toe select_scatter call. + We already know from (3) that this is valid, + because "a" has no later usages in the graph. + + We perform the re-inplacing on the {view}_scatter op like so + Before: + a_updated = torch.select_scatter(a, b, args...) + After: + a_slice = a.select(a, args...) + a_slice.copy_(b) + + (3b) Otherwise, replace the functional op with its inplace variant. + Before: + b = foo(a, args...) + After: + a.foo_(args...) + + (4) Finally, after converting either: + Before: + b = foo(a) + After: + foo_(a) + or + Before: + b = {slice}_scatter(a, mutated_slice, args...) + After: + slice = {slice}(a, args...) + slice.copy_(mutated_slice) + + We now need to find all later nodes that use "b" as an argument + and update them to take in "a" instead. + + Note that for the majority of inplace ops, this isn't actually necessary + (because most inplace ops return "self" as their output). + This isn't generally true for all mutable ops though, which is why + we need to actually replace all of the arguments. + + We also need to update our metadata of Dict[StorageWeakRef, Set[Node]], + That maps a given tensor storage to the set of all nodes that take in that storage + as an input. + Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused + together. + + (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them" + during step (3) get manually deleted from the graph. + Their outputs are no longer used, so technically standard DCE would be able + to do this, but we can no longer run FX's DCE pass now that we have mutable + ops in the graph. + """ + _FunctionalizationMetadataProp(gm).propagate(*sample_args) + + # Useful debug printing + # def _print(x): + # if isinstance(x, FakeTensor): + # print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}') + + # for n in gm.graph.nodes: + # print(n.format_node()) + # if hasattr(n, 'meta'): + # print(f'node_idx: {n.meta["node_idx"]}') + # if 'fake_result' in n.meta: + # tree_map(_print, n.meta['fake_result']) + # if 'view_of' in n.meta: + # print(f'view_of: {str(n.meta["view_of"])}') + # print() + + # We need to know which nodes correspond to inputs (or their aliases) + # so we know not to re-inplace them. + # NOTE: later, we'll need to add an optimization for fully recovering performance + # on programs that mutate inputs. + input_storages = { + StorageWeakRef( + node.meta['fake_result']._typed_storage() + ) for node in gm.graph.nodes if (node.op == 'placeholder' and isinstance(node.meta['fake_result'], torch.Tensor))} + + # We also need to know for a given node, what are all of its aliasing nodes. + storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set) + for n in gm.graph.nodes: + if 'fake_result' in n.meta: + # Tree-mapping because some ops can return lists of tensors. + def _add_to_map(x): + if isinstance(x, FakeTensor): + storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n) + pytree.tree_map_(_add_to_map, n.meta['fake_result']) + + # inplace-ify functional ops, subject to the constraints written below. + all_later_view_inverse_nodes_to_delete = set() + for idx, node in enumerate(gm.graph.nodes): + if node.op == 'call_function': + + # Today, the re-inplace pass on directly acts on: + # - functional ops with an inplace variant + # - {view}_scatter ops that can be potentially removed from the graph. + # Both of these ops take in tensor first args, so filtering on this condition + # makes the later code simpler. + # We should revisit this at some point though, particularly when we also want + # the reinplacer to be able to handle out= and mutable operators + # and tensorlist first args (like `_foreach_` ops). + if not isinstance(node.target, torch._ops.OpOverload): + continue + if len(node.target._schema.arguments) < 1: + continue + if type(node.target._schema.arguments[0].type) != torch.TensorType: + continue + + # Step 1a: Check that the self argument we're attempting to reinplace + # has the same size/stride as the output. + # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor) + # As it would require resizing scalar_tensor. + # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor), + # this is probably an optimization to revisit later). + self_arg = node.args[0] + self_flattened = pytree.tree_leaves(self_arg.meta['fake_result']) + node_flattened = pytree.tree_leaves(node.meta['fake_result']) + self_has_wrong_metadata = False + if len(self_flattened) == len(node_flattened): + for self_meta, node_meta in zip(self_flattened, node_flattened): + if self_meta.numel() != node_meta.numel(): + self_has_wrong_metadata = True + if self_meta.dtype != node_meta.dtype: + self_has_wrong_metadata = True + # We also cannot re-inplace on tensors that have internal memory overlap. + # e.g. torch.ones(1).expand(4, 4).add_(1) + if torch._debug_has_internal_overlap(self_meta) == 1: + self_has_wrong_metadata = True + # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace, + # Since users should never really be calling the functional "torch.ops.aten.resize" + # op directly in their programs. + if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default: + continue + + # Step 1b: ensure that the op we're trying to re-inplace isn't a program input + self_arg_name = self_arg.name + self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) + if self_arg_storage in input_storages: + # TODO: later, add the optimization for handling `copy_()` calls in the graph. + continue + if len([x for x in node.args if x is self_arg]) > 1: + # Step 1c: + # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound, + # so we prevent re-inplacing in this case. + continue + + self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) + self_aliases = storage_to_nodes[self_arg_storage] + + # First, we find all later usages of any of the aliases of self_arg. + later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx']) + # Then, we check if any of those later usages are actually view_scatter ops + # that are safe to fully remove. + later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases) + + # Step 2: Check to see if the input to the op is re-used later in the graph. + # If not (same goes for its aliases), then this op is safe to re-in place. + # This is a slightly roundabout way to check that there are no later usages of the current self argument. + # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete) + can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0 + if not can_reinplace: + continue + + # Step 3a: Special handling for when we see *_scatter operators. + # When we see an operator like `b = torch.slice_scatter(a, ...)`, + # instead of trying to "inplace" it into a.slice_scatter_(..._), + # we would prefer to remove it from the graph entirely, + # and instead copy_() the slice directly into the larger tensor. + # See the description of the algorithm for a full example. + if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete: + view_op = _VIEW_INVERSE_MAP[node.target] + # Before: + # base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...) + # After: + # slice = torch.ops.aten.slice.default(base, args...) + # slice.copy_(mutated_slice) + with gm.graph.inserting_before(node): + mutated_slice_node = node.args[1] + remaining_slice_args = node.args[2:] + slice_node = gm.graph.create_node( + 'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs) + copy_node = gm.graph.create_node( + 'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {}) + # Add the slice_scatter node to our "nodes to delete" list. + all_later_view_inverse_nodes_to_delete.add(node) + + + else: + # Step 3b: Check to see if this operator has an inplace variant. + maybe_inplace_op = _maybe_get_inplace_op(node.target) + if maybe_inplace_op is None: + continue + # And if so, replace it with its inplace variant. + node.target = maybe_inplace_op + + # At this point, 'storage_to_nodes' will be stale. + # Now that we're inplacing `b = foo(a)`, we need to effectively + # union together the dict values for b and a's storage. + # Hmm... morally I think we also want to keep the `fake_result` metadata + # up to date here, but I'm not sure how easy it is to do. + # Maybe it's fine to wait until the end of the pass to update it. + curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) + storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage]) + storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage]) + + # Need to remember the view_scatter view nodes we found so we can remove them alter. + all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages) + + # Step 4: + # Now that we've replaced b = a.foo() with a.foo_(), + # We need to replace any later usages of "b" with "a" + for old in itertools.chain([node], later_view_inverse_node_usages): + new = old.args[0] + nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']] + for node_to_update in nodes_to_update: + new_args = [] + args = node_to_update.args + + def replace_arg(a): + if a == old: + return new + return a + + # First, replace usages of "b" with "a" + node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args) + node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs) + + # Second, update our storage_to_nodes data structure. + old_flattened_res = pytree.tree_leaves(old.meta['fake_result']) + node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result']) + + old_res_storage = { + StorageWeakRef( + x._typed_storage() + ) for x in old_flattened_res if isinstance(x, FakeTensor)} + node_res_storage = { + StorageWeakRef( + x._typed_storage() + ) for x in node_flattened_res if isinstance(x, FakeTensor)} + + # This will happen if we're updating a view op, e.g. + # e.g. replacing + # x = view(old) + # x = view(new) + # When that happens, we need to make sure to keep our + # storage mapping up to date. + # + # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor, + # or multiple tensors that all share the same storage. + # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. + if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage: + new_flattened_res = pytree.tree_leaves(new.meta['fake_result']) + new_res_storage = { + StorageWeakRef( + x._typed_storage() + ) for x in new_flattened_res if isinstance(x, FakeTensor)} + assert len(new_res_storage) == 1 + (old_ref,) = old_res_storage + (new_ref,) = new_res_storage + (node_ref,) = node_res_storage + # Technically, "old_ref" and all its aliases will remain + # in our mapping. + # That should be fine though, since we deleted "old" + # from the graph at this point. + storage_to_nodes[node_ref].update(storage_to_nodes[new_ref]) + storage_to_nodes[new_ref].update(storage_to_nodes[node_ref]) + + # Step 4: delete any _scatter nodes that we de-functionalized + # Need to take care not to delete any of these nodes until after *all* modifications + # to the graph are finished. + for to_delete in all_later_view_inverse_nodes_to_delete: + gm.graph.erase_node(to_delete) + + + gm.recompile() + return gm diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/runtime_assert.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/runtime_assert.py new file mode 100644 index 0000000000000000000000000000000000000000..01803600d021ab59b42222f6163f23d82f5d29ed --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/runtime_assert.py @@ -0,0 +1,605 @@ +# mypy: allow-untyped-defs +import functools +import logging +import operator +import sys +from typing import Any, Dict, Optional, Set, TYPE_CHECKING + + +# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow +if TYPE_CHECKING: + import sympy + + from torch.fx.experimental.symbolic_shapes import ShapeEnv +else: + ShapeEnv = Any + +import torch +import torch.utils._pytree as pytree +from torch import fx +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx._compatibility import compatibility +from torch.fx._utils import lazy_format_graph_code +from torch.fx.experimental.proxy_tensor import py_sym_types +from torch.fx.experimental.sym_node import SymNode +from torch.fx.graph_module import GraphModule + + +__all__ = ["insert_deferred_runtime_asserts"] + +log = logging.getLogger(__name__) +graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code") + + +def _get_example_value(node: fx.Node) -> Optional[str]: + """ + Get the example value key for a node, since dynamo uses "example_value" + while non-strict export uses "val. + """ + if "example_value" in node.meta: + return node.meta["example_value"] + elif "val" in node.meta: + return node.meta["val"] + else: + return None + + +def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]: + val = _get_example_value(node) + if isinstance(val, py_sym_types): + return val.node.expr + return None + + +@compatibility(is_backward_compatible=True) +def insert_deferred_runtime_asserts( + gm: GraphModule, + shape_env: ShapeEnv, + name: str, + export: bool = False, +) -> None: + """ + During tracing, we may have discovered that some data-dependent values + had runtime assert on them; e.g., torch.empty(x.item()) induces a runtime + that x.item() >= 0. This asserts can happen unpredictably during fake + tensor propagation, so we cannot conveniently insert them into the FX graph + when they occur. Instead, we accumulate them in the ShapeEnv, and in this + pass insert them into the graph as proper tests. + + This pass also deduplicates size-related computation, CSE-ing ops that produce + symbolic values and/or are involved in runtime asserts. Additionally, shape calls + (size/stride/storage_offset) are turned into compute on input sizes if possible, + allowing intermediate tensors to be freed earlier. For example, here dynamo will + DCE the cat and repeat calls: + + z = torch.cat([x, x], dim=0) # 2*s0 + w = z.repeat(y.shape[0]) # 2*s0*s1 + _w = w.shape[0] + # something with _w, but not w ... + + # turns into -> + _w0 = 2 * s0 + _w = _w0 * s1 + + # where s0, s1 are either SymInt graph inputs, or the result of added size calls + + Redundant torch._check or torch.ops.aten._assert_scalar.default calls that assert + the same expression, and redundant constrain_range calls are also deduplicated. + Additionally, because single-symbol bound checks (e.g. u0 >= 0, u0 <= 5) accumulate + information in the ShapeEnv, the ShapeEnv contains min/max bounds for each symbol, + and we delete all previous calls, adding bound checks at the end of this pass. + """ + + # Import sympy locally + import sympy + + from torch._export.passes._node_metadata_hook import _set_node_metadata_hook + from torch.fx.experimental.symbolic_shapes import ( + _has_uninterpretable_sympy_function, + CallMethodKey, + cast_symbool_to_symint_guardless, + ConvertIntKey, + DivideByKey, + free_symbols, + InnerTensorKey, + resolve_unbacked_bindings, + ) + from torch.utils._sympy.numbers import int_oo + from torch.utils._sympy.reference import PythonReferenceAnalysis + from torch.utils._sympy.value_ranges import ValueRanges + + # TODO: Request simplification on runtime asserts before emitting them + ras_by_symbol = shape_env.deferred_runtime_asserts.copy() + graph = gm.graph + graph_code_log.debug( + "%s", + lazy_format_graph_code( + f"pre insert_deferred_runtime_asserts {name}", gm, colored=True + ), + ) + + # We are going to mutate the dict + expr_to_proxy: Dict[sympy.Expr, fx.Proxy] = {} + placeholders = set() + first_non_placeholder = None + for node in graph.nodes: + if node.op != "placeholder": + first_non_placeholder = node + break + else: + placeholders.add(node) + + def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: + """ + If a size/stride/storage offset call on an intermediate tensor, + we can try to compute the value from input shapes instead. + """ + return ( + (val := _get_sym_val(node)) is not None + and not isinstance(val, sympy.Number) + # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported + and not _has_uninterpretable_sympy_function(val) + and any( + isinstance(arg, fx.Node) + and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size)) + and arg.op != "placeholder" + for arg in node.args + ) + ) + + # Figure out what key to use, val or example_value + val_key = "val" + for node in graph.nodes: + if "example_value" in node.meta: + val_key = "example_value" + break + elif "val" in node.meta: + break + + def _node_metadata_hook( + node: torch.fx.Node, + stack_trace: Optional[str] = None, + nn_module_stack: Optional[Dict[str, Any]] = None, + ) -> None: + fake_args = [ + _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg + for arg in node.args + ] + try: + node.meta[val_key] = node.target(*fake_args) # type: ignore[operator] + except NotImplementedError: + # This can happen when attempting to reify a symbol with an unsupported call_function node, + # e.g. with NestedTensors + sym_size.int via match_symbol(). + # This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input. + pass + if stack_trace is not None: + node.meta["stack_trace"] = stack_trace + if nn_module_stack is not None: + node.meta["nn_module_stack"] = nn_module_stack + + # Track asserts/checks we've added + added_asserts: Set[sympy.Expr] = set() + constrained_unbacked_symbols: Set[sympy.Symbol] = set() + + def _sympy_interp(expr_to_proxy, expr): + # sympy_interp() with hash consing + from sympy import Integer, Number, Symbol + from sympy.logic.boolalg import BooleanAtom + + from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp + + # hash cons + if expr in expr_to_proxy: + return expr_to_proxy[expr] + # base cases, don't cache + if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)): + return sympy_interp(PythonReferenceAnalysis, expr_to_proxy, expr) + + # hash cons on arguments, run expr handler + expr_to_proxy[expr] = _run_sympy_handler( + PythonReferenceAnalysis, + [_sympy_interp(expr_to_proxy, arg) for arg in expr.args], + expr, + ) + return expr_to_proxy[expr] + + def _is_bound_expr_for_symbol(expr: "sympy.Expr") -> bool: + # This is probably unnecessary, but since torch._check() calls for single-symbol bounds + # like u0 >= 0, 10 >= u0 accumulate range info in the ShapeEnv, we designate these calls as redundant + # and instead add 2 runtime asserts at the end of this pass, if the min/max bounds are non-trivial. + if len(expr.args) != 2 or expr.func not in (sympy.LessThan, sympy.GreaterThan): + return False + lhs, rhs = expr.args + return (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Number)) or ( + isinstance(rhs, sympy.Symbol) and isinstance(lhs, sympy.Number) + ) + + def add_runtime_asserts(ras): + for ra in ras: + if ( + # redundant + ra.expr in added_asserts + # if we've already added a constrain_range call for this symbol, + # then single-symbol bound asserts like u0 >= 0, u0 <= 5 are redundant. + or ( + len(ra.expr.free_symbols) == 1 + and next(iter(ra.expr.free_symbols)) in constrained_unbacked_symbols + and _is_bound_expr_for_symbol(ra.expr) + ) + # don't try to reify sympy functions we can't turn into FX nodes + or _has_uninterpretable_sympy_function(ra.expr) + ): + continue + + log.debug("inserting runtime assert %s", ra.expr) + # Need to process ALL free symbols, not just unbacked ones + fvs = free_symbols(ra.expr) + missing = fvs - expr_to_proxy.keys() + if missing: + i1 = min(missing, key=str) + # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689 + # assert shape_env.is_unbacked_symint(i1), i1 + ras_by_symbol.setdefault(i1, []).append(ra) + else: + # Convert the sympy expression into a sequence of FX + # nodes + with _set_node_metadata_hook(gm, _node_metadata_hook): + res = _sympy_interp(expr_to_proxy, ra.expr).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + # TODO: use ra.msg here, but it's pretty + # useless right now + ( + res, + f"Runtime assertion failed for expression {ra.expr} on node '{res}'", + ), + ) + added_asserts.add(ra.expr) + + nodes = list(graph.nodes) + for i, node in enumerate(nodes[:-1]): + # Placeholders can match symbols, but when we destructure them + # with size we have to make sure we insert the nodes after all + # the placeholders + with graph.inserting_before( + nodes[i + 1] if node not in placeholders else first_non_placeholder + ): + # Unfortunately, this logic still must remain because manual + # make_fx calls may not explicitly bind all symbolic ints as + # arguments to the function, so we must infer it from the other + # arguments + if ( + node in placeholders + and (example_value := _get_example_value(node)) is not None + ): + + def match_symbol(symint, cb): + if ( + isinstance(symint, torch.SymInt) + and isinstance(symint.node, SymNode) + and isinstance(s := symint.node.expr, sympy.Symbol) + and s not in expr_to_proxy + ): + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy(cb()) + log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) + + match_symbol(example_value, lambda: node) + if isinstance(t := example_value, torch.Tensor): + for i, s in enumerate(t.size()): + match_symbol( + s, + lambda: graph.call_function( + torch.ops.aten.sym_size.int, (node, i) + ), + ) + if not is_sparse_any(t): + for i, s in enumerate(t.stride()): + match_symbol( + s, + lambda: graph.call_function( + torch.ops.aten.sym_stride.int, (node, i) + ), + ) + match_symbol( + t.storage_offset(), + lambda: graph.call_function( + torch.ops.aten.sym_storage_offset.default, (node,) + ), + ) + + # Handle asserts that aren't associated with any symbol. This + # doesn't really have to be in the loop as it will only run once, + # it just needs to happen right after the placeholders. + # insert this after placeholders & added sym nodes, and before non-placeholders. + if node == first_non_placeholder: + add_runtime_asserts(ras_by_symbol.pop(None, [])) # type: ignore[call-overload] + + # deduplicate asserts already present in graph + if node.target in ( + torch._check, + torch.ops.aten._assert_scalar.default, + ): + if ( + node.args[0] == True # noqa: E712 + or (assert_expr := _get_sym_val(node.args[0])) in expr_to_proxy + or ( + assert_expr is not None + and _is_bound_expr_for_symbol(assert_expr) + ) + ): + arg = node.args[0] + gm.graph.erase_node(node) + if isinstance(arg, fx.Node) and not arg.users: + gm.graph.erase_node(arg) + else: + added_asserts.add(assert_expr) # type: ignore[arg-type] + + # hash cons, replace function calls that return torch.SymInts with direct references to + # FX nodes built up to reify the sympy expression. + if ( + node.op != "placeholder" + and (sym_expr := _get_sym_val(node)) is not None + ): + # this guards against deleting calls like item() that produce new untracked symbols + new_untracked_symbols = sym_expr.free_symbols - expr_to_proxy.keys() + # this guards against deleting calls that produce unbacked bindings we haven't yet seen. + # in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint + # (is backed), but produces an unbacked symbol. In this case keep the node alive. + new_unbacked_bindings = ( + resolve_unbacked_bindings( + shape_env, node.meta.get("unbacked_bindings", {}) + ).keys() + - expr_to_proxy.keys() + ) + + # maybe re-reify expression, replace current node + if ( + sym_expr in expr_to_proxy + or ( # example value is redundant + _is_intermediate_tensor_sym_call(node) + # shape call on intermediate tensor, turn into computation on input shapes + and not new_untracked_symbols + ) + ) and not new_unbacked_bindings: + if _is_intermediate_tensor_sym_call( + node + ): # reify from input shapes + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] + # won't try DCE-ing tensor compute here + hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type] + node.replace_all_uses_with(hash_node) + gm.graph.erase_node(node) + log.debug( + "CSE node %s -> %s for expr %s", node, hash_node, sym_expr + ) + + # store node in hash cons, don't delete/replace + elif sym_expr not in expr_to_proxy and not isinstance( + sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom) + ): # don't hash cons primitives + expr_to_proxy[sym_expr] = fx.Proxy(node) # type: ignore[arg-type] + + # We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained, + # so calls before that are redundant. + if node.target in ( + torch.ops.aten.sym_constrain_range.default, + torch.ops.aten.sym_constrain_range_for_size.default, + ): + gm.graph.erase_node(node) + + defs = [] + + # AOTAutograd will create new symbols as the unbacked_bindings keys, which PropagateSymInts will set as + # equivalent, but the refinement calls we perform in this pass may struggle with associating the two. + # More concretely, when re-exporting/tracing, constraining only the new symbol may not communicate enough + # information about the old symbol when we re-export, raising errors on data-dependent guards. + # Call resolve_unbacked_bindings() to get the original symbol if present, otherwise we take it as is. + if unbacked_bindings := resolve_unbacked_bindings( + shape_env, node.meta.get("unbacked_bindings") + ): + for s, keypath in unbacked_bindings.items(): + defs.append(s) + + # TODO: some CSE when generating these nodes can probably + # help reduce graph size and improve compile time + def go(node, keypath): + if keypath == (): + return node + if ( + len(keypath) >= 2 + and isinstance(keypath[0], CallMethodKey) + and isinstance(keypath[1], pytree.SequenceKey) + ): + if keypath[0].name == "size": + return go( + graph.call_function( + torch.ops.aten.sym_size.int, + (node, keypath[1].idx), + ), + keypath[2:], + ) + if keypath[0].name == "stride": + return go( + graph.call_function( + torch.ops.aten.sym_stride.int, + (node, keypath[1].idx), + ), + keypath[2:], + ) + return go( + graph.call_method( + keypath[0].name, (node, keypath[1].idx) + ), + keypath[2:], + ) + elif isinstance(keypath[0], CallMethodKey): + return go( + graph.call_method(keypath[0].name, (node,)), keypath[1:] + ) + elif isinstance(keypath[0], pytree.SequenceKey): + return go( + graph.call_function( + operator.getitem, (node, keypath[0].idx) + ), + keypath[1:], + ) + elif isinstance(keypath[0], ConvertIntKey): + return go( + graph.call_function( + cast_symbool_to_symint_guardless, (node,) + ), + keypath[1:], + ) + elif isinstance(keypath[0], DivideByKey): + # TODO: need to assert divisibility + return go( + graph.call_function( + operator.floordiv, (node, keypath[0].divisor) + ), + keypath[1:], + ) + elif isinstance(keypath[0], InnerTensorKey): + return go( + graph.call_function( + getattr, (node, keypath[0].inner_name) + ), + keypath[1:], + ) + else: + raise AssertionError(f"unrecognized keypath {keypath}") + + if s not in expr_to_proxy: + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy(go(node, keypath)) + log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) + + for i0 in defs: + ras = ras_by_symbol.pop(i0, []) + # Before we perform any asserts, first apply range + # refinement. This is important, because if we are going + # to retrace the graph (and we typically are if we send + # the graph to AOTAutograd), we need to make sure we apply + # range refinement (ala _check_is_size) first, BEFORE we + # run any of the asserts. Otherwise, we may decide to + # perform substitutions based on the asserts which we then + # can't back out, because value ranges can only be applied + # to asserts.) + # + # A perhaps better long term plan is to avoid this order + # dependence by making it possible to refine ranges on + # arbitrary expressions, not just symbols. But it is not + # so easy to make use of this information, see + # https://twitter.com/ezyang/status/1745801370299482492 + # We actually made an attempt at this in + # https://github.com/pytorch/pytorch/pull/119043 + # which didn't work. + # + # Another ideas for how to do this: + # - Have bound_sympy be the source of truth of the ranges of any expression + # - Cache intermediate results for every subexpression of bound_sympy + # - This cache should be possible to edit to refine ranges + # + # One issue with this proposal is that if + # we have a bound on 2x, we are not going to be able to + # apply it for 4x. Similarly, we may have bounds for an + # equivalent expression that we are not applying because + # it's not a perfect match (e.g. x < y vs y > x)". + # + # The first issue we already have it and it's impossible + # to solve in general, so any implementation on a best + # effort basis should do. + # + # The second issue is a preexisting one. It can be mitigated + # with a normalisation algorithm. In general, it may also + # be on a best effort basis, but since our grammar is not + # terribly difficult, chances are we could even fully + # normalise SymPy expressions... who knows. + if i0 in constrained_unbacked_symbols: + continue # constrain symbol just once + + if i0 in shape_env.size_like: + if export: + graph.call_function( + torch.ops.aten.sym_constrain_range_for_size.default, + (expr_to_proxy[i0].node,), + ) + else: + graph.call_function( + torch._check_is_size, (expr_to_proxy[i0].node,) + ) + + vr = shape_env.var_to_range[i0] + if vr.is_int and vr.upper == sys.maxsize - 1: + # treat upper bound == sys.maxsize - 1 for int symbols as +oo + # to avoid redundant runtime assert + vr = ValueRanges(vr.lower, int_oo) + if not shape_env._default_unspecified_value_range().issubset(vr): + # The runtime range is constrained, so add a runtime + # assert and also explicitly refine the range + # (refinement should not be necessary once runtime + # asserts cause refinement, but that's NYI) + def convert(s): + if s in (int_oo, -int_oo): + return None + try: + return int(s) + except TypeError: + return None + + if ( + expr_to_proxy[i0].node.target + != cast_symbool_to_symint_guardless + ): + # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts + # raises AOTAutograd errors on cast_symbool_to_symint_guardless + + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + if (min_val := convert(vr.lower)) is not None: + ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + ge, + f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", + ), + ) + added_asserts.add(i0 >= min_val) + if (max_val := convert(vr.upper)) is not None: + le = _sympy_interp(expr_to_proxy, i0 <= max_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + le, + f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", + ), + ) + added_asserts.add(i0 <= max_val) + + constrained_unbacked_symbols.add(i0) + add_runtime_asserts(ras) + + # delete unused reified symbols + for expr, proxy in expr_to_proxy.items(): + if ( + isinstance(expr, sympy.Symbol) + and proxy.node.op != "placeholder" # keep placeholders intact + and not proxy.node.users + ): + log.debug("deleting unused reified symbol for %s", expr) + gm.graph.erase_node(proxy.node) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/shape_prop.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/shape_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..dcaee3f82113907132ea852e9f74dc73a5bb10d9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/shape_prop.py @@ -0,0 +1,196 @@ +# mypy: ignore-errors + +import torch +import torch.fx +import traceback + +from torch._dispatch.python import enable_python_dispatcher +from torch.fx.node import Node, map_aggregate +from typing import Any, Tuple, NamedTuple, Optional, Dict +from torch.fx._compatibility import compatibility +from torch._guards import detect_fake_mode +from torch._subclasses.meta_utils import is_sparse_any + +__all__ = ['TensorMetadata', 'ShapeProp'] + +@compatibility(is_backward_compatible=True) +class TensorMetadata(NamedTuple): + # TensorMetadata is a structure containing pertinent information + # about a tensor within a PyTorch program. + + # General Tensor metadata + shape : torch.Size + dtype : torch.dtype + requires_grad : bool + stride : Tuple[int, ...] + memory_format : Optional[torch.memory_format] + + # Quantization metadata + is_quantized : bool + qparams: Dict[str, Any] + +def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata: + """ + Extract a TensorMetadata NamedTuple describing `result`. + """ + shape = result.shape + dtype = result.dtype + requires_grad = result.requires_grad + stride = result.stride() if not is_sparse_any(result) else None + + memory_format = None + + if include_contiguity and not is_sparse_any(result): + memory_formats = { + torch.contiguous_format, + torch.channels_last, + torch.channels_last_3d, + } + for query_format in memory_formats: + if result.is_contiguous(memory_format=query_format): + memory_format = query_format + break + + is_quantized = result.is_quantized + qparams: Dict[str, Any] = {} + if is_quantized: + qscheme = result.qscheme() + qparams["qscheme"] = qscheme + if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + qparams["scale"] = result.q_scale() # type: ignore[assignment] + qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] + elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}: + # In this branch, scale and zero_point are expected to be tensors, + # we store the values as immutable_list in TensorMetadata for + # easier serialization downstream + qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment] + qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] + qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] + + return TensorMetadata( + shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) + +@compatibility(is_backward_compatible=True) +class ShapeProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node and + record the shape and type of the result + into the corresponding node. + + Example: + In this example, we record the shape + and data type of a module given + an example input ``torch.randn(50, D_in)``. + We print the name, shape and dtype of each node. + + class TwoLayerNet(torch.nn.Module): + def __init__(self, D_in, H, D_out): + super().__init__() + self.linear1 = torch.nn.Linear(D_in, H) + self.linear2 = torch.nn.Linear(H, D_out) + def forward(self, x): + h_relu = self.linear1(x).clamp(min=0) + y_pred = self.linear2(h_relu) + return y_pred + N, D_in, H, D_out = 64, 1000, 100, 10 + x = torch.randn(N, D_in) + y = torch.randn(N, D_out) + model = TwoLayerNet(D_in, H, D_out) + gm = torch.fx.symbolic_trace(model) + sample_input = torch.randn(50, D_in) + ShapeProp(gm).propagate(sample_input) + + for node in gm.graph.nodes: + print(node.name, node.meta['tensor_meta'].dtype, + node.meta['tensor_meta'].shape) + + The output of this code is: + + x torch.float32 torch.Size([50, 1000]) + linear1 torch.float32 torch.Size([50, 100]) + clamp_1 torch.float32 torch.Size([50, 100]) + linear2 torch.float32 torch.Size([50, 10]) + output torch.float32 torch.Size([50, 10]) + + Args: + module (GraphModule): The module to be executed + fake_mode (FakeTensorMode): A fake mode for copying the gm + + """ + def __init__(self, gm, fake_mode=None): + super().__init__(gm) + if fake_mode is None: + fake_mode = detect_fake_mode() + if fake_mode is not None: + from torch._dynamo.utils import deepcopy_to_fake_tensor + # Note: + # We need fake execution cause the inputs are fake, however, we cannot fakify the module + # - because we need to write to the tensor_meta of the real module. So we fakify to + # produce a result (L131 below), to extract tensor meta, and then keep going. + # + # If we were to fakify, we would write to the wrong node, and then downstream fusion + # would be missing the tensor_meta. + # + # See torch/_inductor/overrides.py for where this is called upstream of fusion. + self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode) + self.fake_mode = fake_mode + else: + self.fake_module = None + self.fake_mode = None + + self.real_module = self.module + + def run_node(self, n : Node) -> Any: + try: + if self.fake_module is not None: + # Hacky swap. Alternatively, we could do this with overriding + # call_module and get_attr. + self.module = self.fake_module + try: + if self.fake_mode is not None: + with self.fake_mode, enable_python_dispatcher(): + result = super().run_node(n) + else: + result = super().run_node(n) + finally: + self.module = self.real_module + except Exception as e: + traceback.print_exc() + raise RuntimeError( + f"ShapeProp error for: node={n.format_node()} with " + f"meta={n.meta}" + ) from e + + found_tensor = False + + def extract_tensor_meta(obj): + if isinstance(obj, torch.Tensor): + nonlocal found_tensor + found_tensor = True + return _extract_tensor_metadata(obj) + else: + return obj + + meta = map_aggregate(result, extract_tensor_meta) + if found_tensor: + n.meta['tensor_meta'] = meta + + n.meta['type'] = type(result) + return result + + def propagate(self, *args): + """ + Run `module` via interpretation and return the result and + record the shape and type of each node. + + Args: + *args (Tensor): the sample input. + + Returns: + Any: The value returned from executing the Module + """ + if self.fake_mode is not None: + fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args] + else: + fake_args = args + return super().run(*fake_args) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/split_module.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/split_module.py new file mode 100644 index 0000000000000000000000000000000000000000..1881beaf2ece1362b8d89b83c5e8cf891bcecf75 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/split_module.py @@ -0,0 +1,575 @@ +# mypy: allow-untyped-defs +import inspect +from typing import Any, Callable, Dict, List, Optional, Set +from collections import OrderedDict +import logging + +import torch +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node +from torch.fx._utils import lazy_format_graph_code + + +__all__ = ["Partition", "split_module"] +log = _LOGGER = logging.getLogger(__name__) + +@compatibility(is_backward_compatible=True) +class Partition: + def __init__(self, name: str): + self.name: str = name + self.submod_name = f"submod_{name}" + self.node_names: List[str] = [] + self.inputs: Dict[str, None] = {} + self.outputs: Dict[str, None] = {} + self.dependencies: Dict[str, None] = {} + self.dependents: Dict[str, None] = {} + self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + self.environment: Dict[Node, Node] = {} + self.targets: Dict[str, Any] = {} + + def __repr__(self) -> str: + return ( + f"name: {self.name},\n" + f" nodes: {self.node_names},\n" + f" inputs: {self.inputs},\n" + f" outputs: {self.outputs},\n" + f" partitions depended on: {self.dependencies},\n" + f" partition dependents: {self.dependents}" + ) + + +# Creates subgraphs out of main graph +@compatibility(is_backward_compatible=True) +def split_module( + m: GraphModule, + root_m: torch.nn.Module, + split_callback: Callable[[Node], int], + qualname_map: Optional[Dict[str, str]] = None, + keep_original_order: Optional[bool] = False, + keep_original_node_name: Optional[bool] = False, +): + """ + Creates subgraphs out of main graph + + Args: + m (GraphModule): Graph module to split + root_m (torch.nn.Module): root nn module. Not currently used. Included + because the root nn module is usually transformed via + torch.fx._symbolic_trace.symbolic_trace (see example below) + split_callback (Callable[[Node], int]): Callable function + that maps a given Node instance to a numeric partition identifier. + split_module will use this function as the policy for which operations + appear in which partitions in the output Module. + qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a + mapping from new target names in the module after split to old target + names in the original module. + keep_original_order: Optional[bool]: keep the original order of the GraphModule + or use the Topological order of the new constructed GraphModule + + + Returns: + GraphModule: the module after split. + + Example: + + This is a sample setup: + + import torch + from torch.fx.symbolic_trace import symbolic_trace + from torch.fx.graph_module import GraphModule + from torch.fx.node import Node + from torch.fx.passes.split_module import split_module + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x, y): + z = self.linear(x + self.param).clamp(min=0.0, max=1.0) + w = self.linear(y).clamp(min=0.0, max=1.0) + return z + w + + # symbolically trace model + my_module = MyModule() + my_module_traced = symbolic_trace(my_module) + + # random mod partitioning + partition_counter = 0 + NPARTITIONS = 3 + + def mod_partition(node: Node): + global partition_counter + partition = partition_counter % NPARTITIONS + partition_counter = (partition_counter + 1) % NPARTITIONS + return partition + + # split module in module with submodules + module_with_submodules = split_module( + my_module_traced, my_module, mod_partition + ) + + Output looks like this. Original graph is broken into partitions + + > print(module_with_submodules) + GraphModule( + (submod_0): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_1): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_2): GraphModule() + ) + + def forward(self, x, y): + param = self.param + submod_0 = self.submod_0(x, param, y); x = param = y = None + getitem = submod_0[0] + getitem_1 = submod_0[1]; submod_0 = None + submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None + getitem_2 = submod_1[0] + getitem_3 = submod_1[1]; submod_1 = None + submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None + return submod_2 + + Output of split module is the same as output of input traced module. + This is an example within a test setting: + + > orig_out = my_module_traced(x, y) + > submodules_out = module_with_submodules(x, y) + > self.assertEqual(orig_out, submodules_out) + True + """ + + log.debug( + "%s", + lazy_format_graph_code( + "pre split_module", m, colored=True + ), + ) + + def construct_graph( + node: Node, + base_mod_env: Dict[str, Node], + base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule], + ): + if node.op == "placeholder": + default_value = ( + node.args[0] if len(node.args) > 0 else inspect.Signature.empty + ) + if keep_original_node_name: + args = () if default_value is inspect.Signature.empty else (default_value,) + base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type) # type: ignore[arg-type] + else: + base_mod_env[node.name] = base_mod_graph.placeholder( + node.target, type_expr=node.type, default_value=default_value # type: ignore[arg-type] + ) + base_mod_env[node.name].meta = node.meta.copy() + elif node.op == "get_attr": + base_mod_env[node.name] = base_mod_graph.get_attr(node.target) # type: ignore[arg-type] + base_mod_env[node.name].meta = node.meta.copy() + attr_val = m + for atom in node.target.split("."): # type: ignore[union-attr] + if not hasattr(attr_val, atom): + raise AttributeError(f"Node target {node.target} not found!") + attr_val = getattr(attr_val, atom) + base_mod_attrs[node.target] = attr_val # type: ignore[index] + return base_mod_env, base_mod_attrs + + import sympy + + partitions: Dict[str, Partition] = {} + orig_nodes: Dict[str, Node] = {} + symbol_to_node: Dict[sympy.Symbol, Node] = {} + + def record_cross_partition_use( + def_node: Node, use_node: Optional[Node] + ): # noqa: B950 + from torch.fx.experimental.symbolic_shapes import free_symbols + + defined = getattr(def_node, "_fx_partition", None) + used = getattr(use_node, "_fx_partition", None) + + log.debug( + "record_cross_partition_use %s (%s) %s (%s)", + def_node.name, defined, use_node.name if use_node is not None else "-", used + ) + + if defined != used: + if defined is not None: + def_partition = partitions[defined] + def_partition.outputs.setdefault(def_node.name) + if used is not None: + def_partition.dependents.setdefault(used) + + if used is not None: + use_partition = partitions[used] + use_partition.inputs.setdefault(def_node.name) + # We have made def_node an input to the use_partition. If + # this input has symbolic symbols in its size, those also must + # be made as inputs to the partition + if (def_val := def_node.meta.get("example_value")) is not None: + for s in sorted(free_symbols(def_val), key=str): + s_node = symbol_to_node[s] + use_partition.inputs.setdefault(s_node.name) + if symbol_to_node[s].op != "placeholder": + # If the node that defines the symbol is not a + # placeholder, we must make it an output of the + # partition. Note that this may be in a different + # partition than defined! Although, this doesn't + # really make a difference for correctness, since + # defined is guaranteed to have the symbol in + # scope and can return it; you just get less + # optimal codegen in this case. + s_defined = getattr(s_node, "_fx_partition", None) + if s_defined is not None: + s_def_partition = partitions[s_defined] + s_def_partition.outputs.setdefault(s_node.name) + s_def_partition.dependents.setdefault(used) + if defined is not None: + use_partition.dependencies.setdefault(defined) + + def instantiate_node_partition_mapping(node): + partition_name = str(split_callback(node)) + log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name) + + # add node to partitions + partition = partitions.get(partition_name) + if partition is None: + partitions[partition_name] = partition = Partition(partition_name) + + partition.node_names.append(node.name) + node._fx_partition = partition_name + + # Global State Nodes are nodes which by their global state effects, + # "taint" all downstream nodes while they are active. + GLOBAL_STATE_NODES = [ + torch.amp._enter_autocast, + torch.amp._exit_autocast, + torch._C._set_grad_enabled + ] + + # For grad regions: + # ------------------------ + # 1. first region: we do nothing + # 2. subsequent regions: we insert the set_grad at the beginning + grad_regions: OrderedDict[Node, Set[int]] = OrderedDict() + + # For autocast regions: + # ------------------------ + # 1. first region: we will only insert the _exit at the end + # 2. intermediate regions: we will insert both the + # _enter at the beginning and _exit at the end + # 3. last region: we will only insert _enter at the beginning + # We will do so in the order in which the autocasts were instantiated. + autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict() + autocast_exits: Dict[Node, Optional[Node]] = {} + + active_grad = None + active_autocasts = set() + + for node in m.graph.nodes: + # This will prefer placeholder bindings, because those come first. + # This is a little dangerous though: it is possible that an unbacked + # symbol is used without any binding site for it, in which case we + # will get a KeyError not able to find it. I'd like to fix this by + # having passes.runtime_assert establish some invariants that I can + # rely on later, but this needs some extra work. Quick fix first. + # See https://github.com/pytorch/pytorch/issues/130534 + if ( + (val := node.meta.get("example_value")) is not None and + isinstance(val, torch.SymInt) and + isinstance(s0 := val.node.expr, sympy.Symbol) and + s0 not in symbol_to_node + ): + symbol_to_node[val.node.expr] = node + + if node.op in ["placeholder", "get_attr", "output"]: + continue + + instantiate_node_partition_mapping(node) + + if node.op == "call_function" and node.target in GLOBAL_STATE_NODES: + if node.target == torch._C._set_grad_enabled: + assert len(node.args) == 1 + assert isinstance(node.args[0], bool) + active_grad = node + grad_regions[active_grad] = set({split_callback(node)}) + elif node.target == torch.amp._enter_autocast: + # Should all be python constants + assert all(not isinstance(arg, Node) for arg in node.args) + active_autocasts.add(node) + autocast_regions[node] = set({split_callback(node)}) + autocast_exits[node] = None + elif node.target == torch.amp._exit_autocast: + assert len(node.args) == 1 + autocast_regions[node.args[0]].add(split_callback(node)) + active_autocasts.remove(node.args[0]) + autocast_exits[node.args[0]] = node + + if active_grad is not None: + grad_regions[active_grad].add(split_callback(node)) + + for a in active_autocasts: + autocast_regions[a].add(split_callback(node)) + + assert all(v is not None for v in autocast_exits.values()), "autocast must exit" + + autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()} + grad_regions = {k: sorted(v) for k, v in grad_regions.items()} + + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug("autocast_regions: %s", autocast_regions) + _LOGGER.debug("grad_regions: %s", grad_regions) + + assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions) + + # split nodes into partitions + highest_partition = -1 + for node in m.graph.nodes: + orig_nodes[node.name] = node + + # TODO currently placeholders/parameters aren't put into random partitions, + # rather they're added to the graphs where they are used down below + if node.op in ["placeholder", "get_attr"]: + continue + if node.op == "output": + torch.fx.graph.map_arg( + node.args[0], lambda n: record_cross_partition_use(n, None) + ) + continue + + if assert_monotonically_increasing: + pid = split_callback(node) + assert highest_partition <= pid, \ + ("autocast or set_grad_enabled require monotonically increasing partitions:" + f"highest: {highest_partition}, this node's: {pid}") + highest_partition = pid + + # do not capture cross-partition dependencies for global state nodes as they will be + # self-contained - their setup and unwind will be isolated to each partition submodule. + if node.target not in GLOBAL_STATE_NODES: + torch.fx.graph.map_arg( + node.args, lambda def_node: record_cross_partition_use(def_node, node) + ) + torch.fx.graph.map_arg( + node.kwargs, lambda def_node: record_cross_partition_use(def_node, node) + ) # noqa: B950 + + original_partition_order = list(partitions.keys()) + # find partitions with no dependencies + root_partitions: List[str] = [] + for partition_name, partition in partitions.items(): + if not len(partition.dependencies): + root_partitions.append(partition_name) + + # check partitions for circular dependencies and create topological partition ordering + sorted_partitions: List[str] = [] + while root_partitions: + root_partition = root_partitions.pop() + sorted_partitions.append(root_partition) + for dependent in partitions[root_partition].dependents: + partitions[dependent].dependencies.pop(root_partition) + if not partitions[dependent].dependencies: + root_partitions.append(dependent) + if len(sorted_partitions) != len(partitions): + raise RuntimeError("cycle exists between partitions!") + + # Enter prelude + for regions_mapping in [autocast_regions, grad_regions]: + for node, regions in regions_mapping.items(): + assert len(regions) > 0 + partitions[str(regions[0])].environment[node] = node + for r in regions[1:]: + partition = partitions[str(r)] + new_node = partition.graph.create_node( + op=node.op, + target=node.target, + args=tuple(arg for arg in node.args), + kwargs={}, + type_expr=node.type, + ) + new_node.meta = node.meta.copy() # is it really a good idea to copy this? + partition.environment[node] = new_node + + # add placeholders to partition inputs + for partition_name in sorted_partitions: + partition = partitions[partition_name] + for inp in partition.inputs: + placeholder = partition.graph.placeholder( + inp, + type_expr=orig_nodes[inp].type, + ) + placeholder.meta = orig_nodes[inp].meta.copy() + partition.environment[orig_nodes[inp]] = placeholder + + # Transform nodes and collect targets for partition's submodule + for node in m.graph.nodes: + if hasattr(node, "_fx_partition"): + partition = partitions[node._fx_partition] + + # swap out old graph nodes in kw/args with references to new nodes in this submodule + environment = partition.environment + gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) + gathered_kwargs = torch.fx.graph.map_arg( + node.kwargs, lambda n: environment[n] + ) + + if node.op not in ["call_module", "get_attr"]: + target = node.target + else: + target_atoms = node.target.split(".") + target_attr = m + for atom in target_atoms: + if not hasattr(target_attr, atom): + raise AttributeError(f"Operator target {node.target} not found!") + target_attr = getattr(target_attr, atom) + # target = target_atoms[-1] + target = "_".join(target_atoms) + partition.targets[target] = target_attr + # Fill in the passed-in mapping from new qualname to old qualname + if qualname_map is not None: + # When creating the split module later, the submodules will have + # path prefix matching the corresponding partition's submod_name + qualname = f"{partition.submod_name}.{target}" + qualname_map[qualname] = node.target + + assert isinstance(gathered_args, tuple) + assert isinstance(gathered_kwargs, dict) + name = node.name if keep_original_node_name else None + new_node = partition.graph.create_node( + op=node.op, + target=target, + args=gathered_args, + kwargs=gathered_kwargs, + type_expr=node.type, + name=name, + ) + new_node.meta = node.meta.copy() + partition.environment[node] = new_node + + # Exit epilogue + for regions_mapping in [autocast_regions]: + for node in reversed(regions_mapping): + regions = regions_mapping[node] + assert len(regions) > 0 + for r in regions[:-1]: + partition = partitions[str(r)] + exit_node = autocast_exits[node] + assert exit_node is not None, "Missing exit node" + new_node = partition.graph.create_node( + op=exit_node.op, + target=exit_node.target, + args=(partition.environment[node],), + kwargs={}, + type_expr=exit_node.type, + ) + new_node.meta = exit_node.meta.copy() # is it really a good idea to copy this? + + # original module environment dict mapping node names to nodes + orig_mod_env: Dict[str, Node] = {} + # Set up values to construct base module + base_mod_env: Dict[str, Node] = {} + base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} + if not keep_original_order: + for node in m.graph.nodes: + base_mod_env, base_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + + else: + # Go through the graph to construct the mapping dict + for node in m.graph.nodes: + orig_mod_env[node.name] = node + + # Do some things iterating over the partitions in topological order again: + # 1) Finish off submodule Graphs by setting corresponding outputs + # 2) Construct GraphModules for each submodule + # 3) Construct the base graph by emitting calls to those submodules in + # topological order or original order specified by keep_original_order + + construct_order_partitions = ( + sorted_partitions if not keep_original_order else original_partition_order + ) + + already_constructed_attr_nodes = set() + + # We actually need to insert the placeholder nodes in the original order + # otherwise graph signature will be wrong. + original_order = [node for node in m.graph.nodes if node.op == "placeholder"] + + for partition_name in construct_order_partitions: + partition = partitions[partition_name] + + # Set correct output values + output_vals = tuple( + partition.environment[orig_nodes[name]] for name in partition.outputs + ) + + # skip output node generation if there are no output values + num_output_vals = len(output_vals) + if num_output_vals == 1: + partition.graph.output(output_vals[0]) + elif num_output_vals > 1: + partition.graph.output(output_vals) + + if keep_original_order: + # first get the attr nodes required by this partition + orig_mod_attr_nodes: List[Node] = [ + orig_mod_env[key] for key in partition.inputs if key not in original_order + ] + + for node in original_order: + if node in already_constructed_attr_nodes: + continue # already added this attr to the base graph + base_mod_env, based_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + already_constructed_attr_nodes.add(node) + + # Construct GraphModule for this partition + for node in orig_mod_attr_nodes: # type: ignore[attr-defined] + if node in already_constructed_attr_nodes: + continue + base_mod_env, base_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + already_constructed_attr_nodes.add(node) + + base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule( + partition.targets, partition.graph + ) # noqa: B950 + + # Emit call in base graph to this submodule + output_val = base_mod_graph.call_module( + partition.submod_name, + tuple(base_mod_env[name] for name in partition.inputs), + ) + + num_outputs = len(partition.outputs) + if num_outputs > 1: + # Unpack multiple return values from submodule + output_val_proxy = torch.fx.proxy.Proxy(output_val) + for i, output_name in enumerate(partition.outputs): + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + elif num_outputs == 1: + base_mod_env[next(iter(partition.outputs))] = output_val + + for node in m.graph.nodes: + if node.op == "output": + base_mod_graph.output( + torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]) + ) # noqa: B950 + + ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) + log.debug( + "%s", + lazy_format_graph_code( + "post split_module", ret, colored=True + ), + ) + return ret diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/split_utils.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/split_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c003966983f3f90c1620a2acfe9742cfb4f6fa9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/split_utils.py @@ -0,0 +1,303 @@ +# mypy: allow-untyped-defs +import copy +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.graph import map_arg +from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module + +from .tools_common import NodeList + +__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"] + + +@compatibility(is_backward_compatible=False) +def getattr_recursive(obj, name): + for layer in name.split("."): + if hasattr(obj, layer): + obj = getattr(obj, layer) + else: + return None + return obj + + +@compatibility(is_backward_compatible=False) +def setattr_recursive(obj, attr, value): + if "." not in attr: + setattr(obj, attr, value) + else: + layer = attr.split(".") + setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value) + + +@compatibility(is_backward_compatible=False) +@dataclass +class Component: + """ + A component serves as a container for a subgraph we want to create afterwards. + """ + + graph: torch.fx.Graph + order: int + name: str + + # Stores the placeholder nodes in `graph`. + input_placeholders: List = field(default_factory=list) + + # Store the nodes in original graph that are placeholder in `graph`. + orig_inputs: List = field(default_factory=list) + + # Store the nodes in original graph that are outputs in `graph`. + orig_outputs: List = field(default_factory=list) + + # Mapping from get_attr node in original graph to get_attr node in `graph`. + getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict) + constructor_args: List[str] = field(default_factory=list) + gm: Optional[torch.fx.GraphModule] = None + + +@compatibility(is_backward_compatible=False) +def split_by_tags( + gm: torch.fx.GraphModule, + tags: List[str], + return_fqn_mapping: bool = False, + return_tuple: bool = False, + GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule, +) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]: + """ + Splits a GraphModule using tags on its graph nodes. We honor the order of + tags. For example, we have tags = ["a", "b", "c"], the function will create + the initial submodules in the order of "a", "b", "c". + + To set a tag: + gm.graph.nodes[idx].tag = "mytag" + + This will result in all nodes with the same tag being extracted and placed in their + own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder + and output nodes are created when needed while get_attr nodes get copied to submodules + where they are used. + + Given the following module def: + + class SimpleModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = torch.nn.Linear(...) + self.linear2 = torch.nn.Linear(...) + self.linear3 = torch.nn.Linear(...) + + def forward(self, in1, in2): + r1 = self.linear1(in1) + r2 = self.linear2(in2) + r3 = torch.cat([r1, r2]) + return self.linear3(r3) + + Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split: + + ro: + def forward(self, in1): + self = self.root + linear1 = self.linear1(in1) + return linear1 + + main: + def forward(self, in2, linear1): + self = self.root + linear2 = self.linear2(in2) + cat_1 = torch.cat([linear1, linear2]) + linear3 = self.linear3(cat_1) + return linear3 + + main: + def forward(self, in1, in2): + self = self.root + ro_0 = self.ro_0(in1) + main_1 = self.main_1(in2, ro_0) + return main_1 + + Returns: + split_gm: torch fx graph after split + orig_to_split_fqn_mapping: a map between the original fqn and the fqn + after split for call_module and get_attr. + """ + + def flatten(x: torch.fx.node.Argument) -> NodeList: + """ + Stores nodes in x to a list and returns the list. + """ + r: NodeList = [] + map_arg(x, r.append) + return r + + # Mapping from node in original module to node in created submodule. + node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + + # Mapping from node in original module or created submodules to + # corresponding component. + node_to_component: Dict[torch.fx.Node, Component] = {} + + # Mapping from tag to the corresponding component. + tag_to_component: Dict[str, Component] = {} + + # Stores all components. + all_components: List[Component] = [] + + # Stores nodes that will be used in main graph. + used_in_main: Dict[torch.fx.Node, None] = {} + + # Main graph after split. + main_g = torch.fx.Graph() + + # Mapping from node in original module to node in main graph after split. + main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} + + # Output node of original module. + output_node: Optional[torch.fx.Node] = None + + # Create a component for each tag, we don't expect to create other components afterwards. + for tag in tags: + comp = Component(torch.fx.Graph(), len(all_components), f"{tag}") + all_components.append(comp) + tag_to_component[tag] = comp + + # Traverse the nodes in original graph and take care of them. + for node in gm.graph.nodes: + if node.op == "output": + if output_node is not None: + raise RuntimeError("Multiple output nodes in graph!") + output_node = node + continue + + # Placeholders in the original graph get copied to main graph. + if node.op == "placeholder": + main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type) + main_remapping[node].meta = copy.copy(node.meta) + continue + + # Get_attr nodes are ignored because we are not tagging them. + # Instead, we copy them directly to the submodules use them afterwards. + if node.op == "get_attr": + continue + + # Now we process callable nodes which are nodes with op of call_module, + # call_function or call_method. Every callable nodes should be tagged. + assert hasattr(node, "tag") + + upstream_components = [ + node_to_component[x] + for x in flatten(node.args) + flatten(node.kwargs) + if x.op not in {"placeholder", "get_attr"} + ] + + comp = tag_to_component[node.tag] + node_to_component[node] = comp + + # Max order of upperstream components. + mx = max((c.order for c in upstream_components), default=0) + + # Expect the component for `node` has higher order then its upstream components. + assert comp.order >= mx + + # Map a input of `node` to nodes in the component's graph. + def remap_func(x): + # If input is a get_attr node, copy it to current component's graph. + # Returns the get_attr node in current component's graph. + if x.op == "get_attr": + if x not in comp.getattr_maps: + comp.getattr_maps[x] = comp.graph.get_attr( + x.target, type_expr=x.type + ) + return comp.getattr_maps[x] + + # If input is not a placeholder, it should have been put into a component + # already. If it's the current component then we return the corresponding + # node in the component. + if x.op != "placeholder" and node_to_component[x] == comp: + return node_remapping[x] + + # If input is a placeholder or it's in other components, we want to make it + # as a placeholder in current component's graph. + if x not in comp.orig_inputs: + comp.orig_inputs.append(x) + placeholder = comp.graph.placeholder(x.name, type_expr=x.type) + placeholder.meta = copy.copy(x.meta) + comp.input_placeholders.append(placeholder) + used_in_main[x] = None + + return comp.input_placeholders[comp.orig_inputs.index(x)] + + n = comp.graph.node_copy(node, remap_func) + n.tag = node.tag # type: ignore[attr-defined] + node_remapping[node] = n + node_to_component[n] = comp + + if output_node is None: + raise RuntimeError("Graph had no output node!") + + for x in flatten(output_node.args[0]): + if x.op == "get_attr": + # We don't need components mapping for nodes of type "get_attr" + # that are consumed by the output. Only need to make sure we create + # corresponding counterparts in the resulting graph. + main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type) + else: + # All component results consumed by the output node should be + # marked as "used in main". + used_in_main[x] = None + + # If a node is used in main graph then we mark it as an output in the component + # it belongs to. + for n in used_in_main: + if n.op != "placeholder": + node_to_component[n].orig_outputs.append(n) + + # Now we create a graphmodule for each component. + orig_to_split_fqn_mapping: Dict[str, str] = {} + for comp in all_components: + outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs)) + + if return_tuple: + comp.graph.output(outs) + else: + # Take care of the args of FX output node. If there's a single + # output then the output node args is like (output_single), else + # if there're multiple outputs then the output node args is like + # ((output_0, output_1, ...)). + comp.graph.output(outs[0] if len(outs) == 1 else outs) + + comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module( + gm, subgraph=comp.graph, comp_name=comp.name + ) + orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping) + + # Create a call_module node in main graph. + main_node = main_g.call_module( + comp.name, + args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)), + kwargs=None, + ) + + if len(outs) == 1 and not return_tuple: + main_remapping[comp.orig_outputs[0]] = main_node + else: + for i, o in enumerate(comp.orig_outputs): + # Use Proxy to record getitem access. + main_remapping[o] = torch.fx.Proxy(main_node)[i].node # type: ignore[index] + + main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__)) + main_root = HolderModule({comp.name: comp.gm for comp in all_components}) + main_g._codegen = gm.graph._codegen + + # If the output nodes consumes get_attr directly in the original graph, + # then we need to make sure get_attr is copied to the new graph. + for x in flatten(output_node.args[0]): + if x.op == "get_attr": + setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type] + + result_gm = GraphModuleCls(main_root, main_g) + if return_fqn_mapping: + return result_gm, orig_to_split_fqn_mapping + + return result_gm diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/splitter_base.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/splitter_base.py new file mode 100644 index 0000000000000000000000000000000000000000..70b117c8ca374f33585901831b5435e5c3d9fbcc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/splitter_base.py @@ -0,0 +1,898 @@ +# mypy: allow-untyped-defs +import argparse +import copy +from collections import defaultdict +from dataclasses import dataclass +from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple +import logging + +import torch +from torch.fx.passes.graph_manipulation import get_size_of_node +from torch.fx.node import map_arg +from torch.fx._compatibility import compatibility + +from .operator_support import ( + get_node_target, + OperatorSupportBase, +) +from .graph_drawer import FxGraphDrawer +from .shape_prop import ShapeProp +from .split_utils import split_by_tags +from .tools_common import ( + FxNetAccFusionsFinder, + CALLABLE_NODE_OPS, + Tensors, + NodeList, + NodeSet, + is_node_output_tensor, +) + + +__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules'] +_LOGGER = logging.getLogger(__name__) + +DEFAULT_MIN_ACC_MODULE_SIZE = 1 +DEFAULT_SKIP_FUSION = False +DEFAULT_ALLOW_NON_TENSOR = False + +class _SplitterSettingBase: + def __init__( + self, + min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, + skip_fusion=DEFAULT_SKIP_FUSION, + allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR, + max_acc_splits: int = -1, + ): + parser = argparse.ArgumentParser() + parser.add_argument( + "--min-acc-module-size", + "--min_acc_module_size", + required=False, + type=int, + help="Minimum size limit of an accelerator subgraph.", + ) + parser.add_argument( + "--max-acc-splits", + "--max_acc_splits", + required=False, + type=int, + help="Enforce a maximum number of split subgraphs.", + ) + parser.add_argument( + "--skip-fusion", + "--skip_fusion", + default=False, + action="store_true", + help="If true then no fusion groups. Fusion group is used to " + "enforce no non-tensor data flow between submodules. If we don't " + "have this constrain, setting this to false is recommended as it " + "can reduce overhead.", + ) + parser.add_argument( + "--allow-non-tensor", + "--allow_non_tensor", + default=False, + action="store_true", + help="For some backends non-tensor data flow between cpu and them " + "are not allowed. Therefore, if a node supported by accelerator but " + "it has non-tensor inputs or outputs to a cpu node we would want to " + "consider it as a cpu node during splitting. However, for some backends " + "we might not care about non-tensor data flow and we can set this option " + "to true to disable the functionality that prevent non-tensor data flow.", + ) + args, unknown = parser.parse_known_args() + + self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size + self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion + self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor + self.max_acc_splits: int = max_acc_splits + + +@compatibility(is_backward_compatible=False) +class FxNetAccNodesFinder: + """ + Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor + input/output to cpu nodes to prevent non-tensor data flow between backends and cpu. + + I.e. if we have a chain: + + ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1 + + where every ACC node produces non-tensor output, then they all should be treated as CPU nodes. + + This behavior can be turned off by passing allow_non_tensor=True. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + operator_support: OperatorSupportBase, + allow_non_tensor: bool, + ): + self.module = module + self.operator_support = operator_support + self.allow_non_tensor = allow_non_tensor + self.acc_nodes: NodeSet = set() + + def reduce_acc_nodes_non_tensor_input_helper( + self, cpu_worklist: NodeList + ): + """ + Transitively excludes nodes from ACC supported set. + For every node in the worklist: + - removes its downstream ACC nodes from ACC supported set, + - if any downstream ACC node produces non-tensor output, + then it gets added into the worklist. + """ + while cpu_worklist: + node = cpu_worklist.pop(0) + + for user in node.users: + if user in self.acc_nodes: + self.acc_nodes.remove(user) + if not is_node_output_tensor(user): + cpu_worklist.append(user) + + def reduce_acc_nodes_non_tensor_input(self): + """ + Excludes nodes from ACC supported set that have direct + upstream CPU nodes that produce non-tensor outputs. + """ + non_tensor_cpu_nodes: NodeList = [] + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + if node in self.acc_nodes: + continue + if is_node_output_tensor(node): + continue + non_tensor_cpu_nodes.append(node) + + self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) + + def reduce_acc_nodes_non_tensor_output(self): + """ + Excludes nodes from ACC supported set that produce non-tensor + outputs and have downstream CPU nodes. + """ + while True: + new_cpu_nodes: NodeList = [] + + for acc_node in self.acc_nodes: + if is_node_output_tensor(acc_node): + continue + for user in acc_node.users: + if user not in self.acc_nodes: + new_cpu_nodes.append(acc_node) + break + + if not new_cpu_nodes: + break + + for new_cpu_node in new_cpu_nodes: + self.acc_nodes.remove(new_cpu_node) + + self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes) + + def __call__(self) -> NodeSet: + submodules = dict(self.module.named_modules()) + self.acc_nodes = { + n + for n in self.module.graph.nodes + if n.op in CALLABLE_NODE_OPS + and self.operator_support.is_node_supported(submodules, n) + } + + if not self.allow_non_tensor: + self.reduce_acc_nodes_non_tensor_input() + self.reduce_acc_nodes_non_tensor_output() + + return self.acc_nodes + +@compatibility(is_backward_compatible=False) +class FxNetSplitterInternalError(Exception): + pass + +@compatibility(is_backward_compatible=False) +@dataclass +class Subgraph: + is_acc: bool + nodes: NodeList + device_ordinal: Optional[int] = None + +@compatibility(is_backward_compatible=False) +class SplitResult(NamedTuple): + """ + Stores the results of the splitter. + + Attributes: + split_module: root module after splitting. + submodule_inputs: a dict that maps submodule name to its inputs. + non_acc_submodule_prefix: the prefix for non acc submodules. For + acc submodule the prefix is alwasy "_run_on_acc_". + """ + + split_module: torch.fx.GraphModule + submodule_inputs: Dict[str, Any] + non_acc_submodule_prefix: str + + +@compatibility(is_backward_compatible=False) +def generate_inputs_for_submodules( + model: torch.nn.Module, + inputs: Sequence[Any], + target_submodules: Iterable[str], + deepcopy: bool = False, +) -> Dict[str, Any]: + """ + Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this + function doesn't work. + + Args: + model: root model. + inputs: inputs to the root model. + target_submodules: submodules that we want to generate inputs for. + + Returns: + A dict that maps from submodule name to its inputs. + """ + + handles = [] + results = {} + submodule_to_names = {mod: name for name, mod in model.named_modules()} + + def pre_forward(module, module_inputs): + results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs + + for name, mod in model.named_modules(): + if name in target_submodules: + handles.append(mod.register_forward_pre_hook(pre_forward)) + + def clean_up_handles(): + for h in handles: + h.remove() + + try: + with torch.no_grad(): + model(*inputs) + except Exception as e: + clean_up_handles() + raise e + + clean_up_handles() + return results + + +class _SplitterBase: + """ + Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator. + Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible. + Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator. + + Given the following graph: + ==> b ==> + // \\ + a d + \\ // + ==> c ==> + + class SimpleModule(torch.nn.Module): + def forward(self, a): + b = torch.sin(a) + c = torch.cos(a) + d = b + c + return d + + and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator, + we will get the following split result: + + main: + def forward(self, a): + run_on_acc_0_0 = self._run_on_acc_0_0(a) + getitem = run_on_acc_0_0[0] + getitem_1 = run_on_acc_0_0[1] + run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1) + return run_on_cpu_1_1 + + _run_on_acc_0_0: + def forward(self, a): + sin_1 = torch.sin(a) + cos_1 = torch.cos(a) + return (sin_1, cos_1) + + _run_on_cpu_1_1: + def forward(self, sin_1, cos_1): + add_1 = sin_1 + cos_1 + return add_1 + """ + + # PCIe bandwidth for the backend, default to 100 GB/s + PCIe_BW = 100 * 2 ** 30 + + def __init__( + self, + module: torch.fx.GraphModule, + sample_input: Sequence[Any], + operator_support: OperatorSupportBase, + settings: _SplitterSettingBase, + non_acc_submodule_name: str = "_run_on_cpu_", + return_tuple: bool = False, + ): + """ + Preprocesses graph before splitting: + - finds nodes supported by ACC, + - finds fusion groups for ACC nodes having non-tensor IO, + - builds a graph of direct dependencies, + - builds a map of fused nodes to their fusions. + As a result we get self.acc_nodes, self.deps and self.fusions. + """ + assert isinstance(module, torch.fx.GraphModule) + + self.module = module + ShapeProp(self.module).propagate(*sample_input) + + self.settings = settings + self.operator_support = operator_support + self.sample_input = sample_input + self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)() + + if self.settings.skip_fusion: + self.fusions = {} + else: + self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)() + + # Modify deps to add more deps for fused nodes + self.deps = self.find_deps() + self.update_deps_for_fusions() + + self.non_acc_submodule_name = non_acc_submodule_name + self._node_submodule_map: Dict[str, str] = {} + self._return_tuple = return_tuple + + self.tags: List[str] = [] + + # =============================================================== + # Helpers for ctor and initial state + # =============================================================== + + def get_node_submodule_map(self) -> Dict[str, str]: + """ Returns a map from node name to submodule name, e.g. + node: main_module_impl_impl_over_arch_unary_multiple_embedding + _pooling_embedding_pooling_sparse_entity_equivalence_key + _proxy_embedding_bag + maps to submodule name of: _run_on_acc_1 + """ + return self._node_submodule_map + + def find_deps(self) -> Dict[torch.fx.Node, NodeSet]: + """ + Builds a graph of node dependencies. Leaf nodes don't have any + dependencies and the "output" node doesn't have nodes depending on it. + + Resulting graph has only direct dependencies, i.e. there are no + transitive dependencies. + """ + deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set) + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + for user in node.users: + if user.op != "output": + deps[user].add(node) + return deps + + def update_deps_for_fusions(self): + """ + Updates graph of dependencies so that: + - nodes from the same fusion depend on the same set of outer nodes, + - outer nodes depending on a fusion depend on all nodes in that fusion. + """ + for node in self.fusions: + fusion = self.fusions[node] + for fused_neighbor in fusion: + self.deps[node].update(self.deps[fused_neighbor] - fusion) + + for user in fused_neighbor.users: + if user not in fusion: + self.deps[user].add(node) + + # =============================================================== + # Helpers for preview + # =============================================================== + + def _lower_model_to_backend( + self, mod: torch.fx.GraphModule, inputs: Tensors + ) -> torch.nn.Module: + """ + Lower the model to a backend. + """ + + return mod + + def _find_culprit( + self, mod: torch.fx.GraphModule, inputs: Tensors + ) -> str: + """ + When an error occurs during lowering or running the lowered mod, we use this + function to find culprits in the `mod` that causes the error. + """ + + return "Unable to find a culprit because _find_culprit() function is not implemented." + + def _draw_graph_based_on_node_support( + self, mod: torch.fx.GraphModule, supported_nodes: NodeList + ): + color_map = { + "default": "AliceBlue", + "supported": "chartreuse1", + "unsupported": "crimson", + } + + class CustomDrawer(FxGraphDrawer): + def _get_node_style(self, node): + template = super()._get_node_style(node) + if node in supported_nodes: + template["fillcolor"] = color_map["supported"] + elif node.op in CALLABLE_NODE_OPS: + template["fillcolor"] = color_map["unsupported"] + else: + template["fillcolor"] = color_map["default"] + + return template + + drawer = CustomDrawer(mod, "node_support", ignore_getattr=True) + dot_graph = drawer.get_main_dot_graph() + # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. + dot_graph.write_raw("node_support.dot") + + def node_support_preview(self, dump_graph: bool = False): + submodules = dict(self.module.named_modules()) + + supported_nodes: NodeList = [] + supported_node_types = defaultdict(set) + unsupported_node_types = defaultdict(set) + + def get_dtype(arg): + tensor_meta = arg.meta.get("tensor_meta") + return getattr(tensor_meta, "dtype", None) + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + target = get_node_target(submodules, node) + + # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None. + arg_dtypes = [ + get_dtype(arg) if isinstance(arg, torch.fx.Node) else None + for arg in node.args + ] + + # Find last non-None element. If all elements are None, return max_len. + last_index = len(arg_dtypes) - next( + ( + i + for i, dtype in enumerate(reversed(arg_dtypes)) + if dtype is not None + ), + len(arg_dtypes), + ) + + # Strip None elements at the end. + arg_dtypes_tuple = tuple(arg_dtypes[:last_index]) + kwarg_dtypes_tuple = tuple( + (k, get_dtype(arg)) + for k, arg in node.kwargs.items() + if isinstance(arg, torch.fx.Node) + ) + + if self.operator_support.is_node_supported(submodules, node): + supported_nodes.append(node) + supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) + else: + unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) + + if dump_graph: + self._draw_graph_based_on_node_support(self.module, supported_nodes) + + reports = "\nSupported node types in the model:\n" + for t, dtypes in supported_node_types.items(): + for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: + reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" + + reports += "\nUnsupported node types in the model:\n" + for t, dtypes in unsupported_node_types.items(): + for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: + reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" + + print(reports) + + # Return reports for testing purpose + return reports + + def split_preview(self, dump_graph: bool = False): + reports = "" + subgraphs = self.put_nodes_into_subgraphs() + acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) + cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num + reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" + reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" + + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) + cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num + reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" + reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" + + for i, subgraph in enumerate(subgraphs): + reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: " + reports += f"{len(subgraph.nodes)} node(s)\n" + + self.tag(subgraphs) + split_mod = self.split(remove_tag=True) + split_mod.eval() + + if dump_graph: + drawer = FxGraphDrawer( + split_mod, "preview", ignore_getattr=True + ) + dot_graphs = drawer.get_all_dot_graphs() + for name, dot_graph in dot_graphs.items(): + # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. + dot_graph.write_raw(f"{name}.dot") + + max_qps: float = self.PCIe_BW + bottleneck_module = "" + + for node in split_mod.graph.nodes: + if node.op == "call_module" and "acc" in node.target: + reports += f"\nProcessing acc submodule {node.target}\n" + + submod = getattr(split_mod, node.target) + + def get_submod_inputs(main_mod, submod, example_inputs): + sub_inputs = None + + def get_inputs(self, inputs): + nonlocal sub_inputs + sub_inputs = inputs + + handle = submod.register_forward_pre_hook(get_inputs) + main_mod(*example_inputs) + handle.remove() + return sub_inputs + + submod_inputs = get_submod_inputs( + split_mod, submod, self.sample_input + ) + ShapeProp(submod).propagate(*submod_inputs) + + total_input_bytes = 0 + total_output_bytes = 0 + + reports += "Checking inputs...\n" + for n in submod.graph.nodes: + if n.op == "placeholder": + if not is_node_output_tensor(n): + reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n" + else: + total_input_bytes += get_size_of_node(submod, n)[0] + if n.op == "output": + output_node = n + + reports += "Checking outputs...\n" + + def get_bytes(node: torch.fx.Node): + nonlocal total_output_bytes + nonlocal reports + if not is_node_output_tensor(node): + reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n" + else: + total_output_bytes += get_size_of_node(submod, node)[0] + + map_arg(output_node.args, get_bytes) # type: ignore[possibly-undefined] + qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes) + reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes}," + reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n" + + if qps < max_qps: + max_qps = qps + bottleneck_module = node.target + + try: + lowered_submod = self._lower_model_to_backend(submod, submod_inputs) + except RuntimeError: + reports += "Run into an error during lowering!\n" + reports += self._find_culprit(submod, submod_inputs) + continue + + try: + lowered_submod(*submod_inputs) + except RuntimeError: + reports += "Run into an error during inference!\n" + reports += self._find_culprit(submod, submod_inputs) + else: + reports += "Lowering and running succeed!\n" + + reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps}," + reports += f" bottleneck is submodule {bottleneck_module}." + print(reports) + + # return the reports for testing purposes + return reports + + # =============================================================== + # Helpers for extend_acc_subgraph() method + # =============================================================== + + def find_reverse_deps( + self, tag_id: Optional[int] = None + ) -> Dict[torch.fx.Node, NodeSet]: + """ + Builds reversed topological node dependencies, if tag_id is specified, + we ignore nodes that are in later subgraph i.e. nodes have greater tag_id. + """ + result: Dict[torch.fx.Node, NodeSet] = defaultdict(set) + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + + if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id): + result[node].add(user) + + return result + + def update_reverse_deps_for_fusions( + self, deps: Dict[torch.fx.Node, NodeSet] + ): + processed_node = set() + + for node, fusion in self.fusions.items(): + if node in processed_node: + continue + + new_dep = set() + + # Create a new dependency set which include all the + # dependencies of the nodes in the fusion group + for n in fusion: + new_dep.update(deps[n]) + + # Exclude nodes in the fusion + new_dep.difference_update(fusion) + + # Update dependency + for n in fusion: + deps[n] = new_dep + + for arg in n.all_input_nodes: + if arg not in fusion: + deps[arg].update(fusion) + + processed_node.add(n) + + def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet: + """ + Finds parent nodes of the `tag` subgraph. + + Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph + and is not a placeholder, we consider it as the parent node of the subgraph. + """ + parent_nodes = set() + + for node in self.module.graph.nodes: + if node.op in CALLABLE_NODE_OPS and node.tag == tag: + for arg in node.all_input_nodes: + if arg.op in CALLABLE_NODE_OPS and arg.tag != tag: + parent_nodes.add(arg) + + return parent_nodes + + def extend_acc_subgraph(self, tag: str): + """ + Extend the acc subgraph with `tag` going the reversed topological direction. + """ + # Dict that maps node to its users and ignore users that + # are in the subgraph that has greater tag + deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1])) + self.update_reverse_deps_for_fusions(deps) + + # Parent nodes of the subgraph + parent_nodes = self.find_parent_nodes_of_subgraph(tag) + + visited_nodes: NodeSet = set() + + while parent_nodes: + node = None + + # Find a acc node that depends on visited nodes only + for n in parent_nodes: + if deps[n] <= visited_nodes and n in self.acc_nodes: + node = n + break + + if node is None: + break + + # Put the node into `tag` subgraph + node.tag = tag # type: ignore[attr-defined] + parent_nodes.remove(node) + visited_nodes.add(node) + + # If node is in a fusion group, add all fusion buddies to parent nodes + if node in self.fusions: + for fusion_node in self.fusions[node]: + if fusion_node not in visited_nodes: + parent_nodes.add(fusion_node) + + # Add inputs of the node to parent nodes + for arg in node.all_input_nodes: + if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes: + parent_nodes.add(arg) + + # =============================================================== + # Helpers for split() method + # =============================================================== + + def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: + """ + Finds nodes that consume module inputs or get_attr nodes. + """ + starter_cpu_nodes: NodeSet = set() + starter_acc_nodes: NodeSet = set() + for node in self.module.graph.nodes: + if node.op not in {"placeholder", "get_attr"}: + continue + for user in node.users: + if user in self.acc_nodes: + starter_acc_nodes.add(user) + else: + starter_cpu_nodes.add(user) + return starter_cpu_nodes, starter_acc_nodes + + def put_nodes_into_subgraphs(self) -> List[Subgraph]: + # We start graph traversal from leaf nodes + current_cpu_nodes, current_acc_nodes = self.starter_nodes() + visited_nodes: NodeSet = set() + + # Determine which subgraph to start from based on which subgraph has + # 0-dep node + acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes) + + current_subgraph_nodes: NodeList = [] + + # Result accumulator + subgraphs: List[Subgraph] = [] + while current_cpu_nodes or current_acc_nodes: + # Find the first node that should belong to the current subgraph and has all dependencies resolved + current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes + node = next( + (n for n in current_nodes if self.deps[n] <= visited_nodes), + None, + ) + + # If nothing was found, then it's time to flip the mode and start a new subgraph + if node is None: + if not current_subgraph_nodes: + raise FxNetSplitterInternalError("Subgraph can't be empty") + + subgraphs.append( + Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) + ) + acc_subgraph = not acc_subgraph + current_subgraph_nodes = [] + continue + + current_nodes.remove(node) + visited_nodes.add(node) + current_subgraph_nodes.append(node) + + # Add fusion buddies + if node in self.fusions: + if node in self.acc_nodes: + current_acc_nodes.update(self.fusions[node] - visited_nodes) + else: + current_cpu_nodes.update(self.fusions[node] - visited_nodes) + + # Put depending nodes into the queue + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + + # Add downstream nodes + if user in self.acc_nodes: + current_acc_nodes.add(user) + else: + current_cpu_nodes.add(user) + + # Check if the last subgraph was not created + if current_subgraph_nodes: + subgraphs.append( + Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) + ) + + if not subgraphs: + raise FxNetSplitterInternalError("Couldn't create subgraphs") + + return subgraphs + + def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]: + """ + This pass finds ACC submodules with less than specified size and merges + them with adjacent CPU submodules. + """ + result: List[Subgraph] = [] + for subgraph in subgraphs: + if subgraph.is_acc: + if len(subgraph.nodes) >= self.settings.min_acc_module_size: + result.append(subgraph) + else: + print( + "Eliminating acc subgraph because it's smaller than the threshold: " + f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}" + ) + if result: + result[-1].nodes.extend(subgraph.nodes) + else: + subgraph.is_acc = False + result.append(subgraph) + else: + if result and not result[-1].is_acc: + result[-1].nodes.extend(subgraph.nodes) + else: + result.append(subgraph) + return result + + def tag(self, subgraphs: List[Subgraph]): + self.tags = [] + for subgraph in subgraphs: + tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}" + self.tags.append(tag) + for node in subgraph.nodes: + if hasattr(node, "tag"): + raise FxNetSplitterInternalError(f"Node {node} was already tagged") + + node.tag = tag # type: ignore[attr-defined] + self._node_submodule_map[node.name] = tag + + def split(self, remove_tag: bool = False) -> torch.fx.GraphModule: + split_module = split_by_tags(self.module, self.tags, return_tuple=self._return_tuple) + if remove_tag: + for node in self.module.graph.nodes: + if hasattr(node, "tag"): + del node.tag + return split_module # type: ignore[return-value] + + def __call__(self) -> torch.fx.GraphModule: + subgraphs = self.put_nodes_into_subgraphs() + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + acc_subgraphs_count = len([s for s in subgraphs if s.is_acc]) + non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count + print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs") + self.tag(subgraphs) + return self.split() + + def generate_split_results(self) -> SplitResult: + split_module = self() + submodule_names = [] + for name, mod in split_module.named_children(): + submodule_names.append(name) + if ( + self.settings.max_acc_splits > 0 + and len(submodule_names) > self.settings.max_acc_splits + ): + raise ValueError( + "Cannot fulfill max_acc_splits limit. " + "This may cause split fragmentation and " + "result in performance issues." + ) + + submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names) + return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name) diff --git a/.venv/lib/python3.11/site-packages/torch/fx/passes/tools_common.py b/.venv/lib/python3.11/site-packages/torch/fx/passes/tools_common.py new file mode 100644 index 0000000000000000000000000000000000000000..aac071ace8c2daff5c727e462f34bb47ff0f820a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/torch/fx/passes/tools_common.py @@ -0,0 +1,303 @@ +# mypy: allow-untyped-defs +from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional +import collections +from dataclasses import dataclass +import operator + +import torch +import torch.fx +from torch.fx.node import _get_qualified_name +from torch.fx._compatibility import compatibility + +__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph'] + +Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]] +TensorOrTensors = Union[torch.Tensor, Tensors] +NodeList = List[torch.fx.Node] +NodeSet = Set[torch.fx.Node] +Names = List[str] +CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"} + + +@compatibility(is_backward_compatible=False) +def get_acc_ops_name(k): + if isinstance(k, str): + return k + elif k.__module__ and "acc_ops" in k.__module__: + return f"acc_ops.{k.__name__}" + else: + module = k.__module__.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module + return f"{module if module else ''}.{k.__name__}" + + +@compatibility(is_backward_compatible=False) +def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str: + """ + Given a `node` returns its target typename. + + For "call_method" node, return node.target which is the name of that method being called. + This could potential lead to conflict but should be okay because normally it's on a tensor. + + For "call_function" node, return typename of node.target. + + For "call_module" node, return typename of the module that node.target point to. + + If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by + "torch". e.g. _VariableFunctionsClass.relu would become torch.relu. + """ + + assert node.op in CALLABLE_NODE_OPS, ( + "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}" + ) + + if node.op == "call_module": + assert isinstance(node.target, str) + submod = submodules[node.target] + submod_type = getattr(submod, "_base_class_origin", type(submod)) + return get_acc_ops_name(submod_type) + elif node.op == "call_function": + target: Any = node.target + return ( + f"acc_ops.{target.__name__}" + if target.__module__ is not None and "acc_ops" in target.__module__ + else _get_qualified_name(target) + ) + else: + assert isinstance(node.target, str) + return node.target + +@compatibility(is_backward_compatible=False) +def is_node_output_tensor(node: torch.fx.Node) -> bool: + """Checks if the node output produces a Tensor or not. + + NOTE: This requires to run `ShapeProp` on the containing fx graph before + calling this function. This is because it works by checking the `type` + metadata on the node. This metadata is produced by the `ShapeProp`. + """ + type_ = node.meta.get("type", None) + return type_ is not None and issubclass(type_, torch.Tensor) + +@compatibility(is_backward_compatible=False) +class FxNetAccFusionsFinder: + """ + Finds groups of connected ACC nodes that pass non-tensor data between each other. + Such groups are called fusion groups. + """ + + def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet): + self.module = module + self.nodes = list(module.graph.nodes) + self.acc_nodes = acc_nodes + + @dataclass + class FusionGroup: + # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model. + top_node_idx: int + + # Nodes in this fusion group. + nodes: NodeSet + + # Inputs to this fusion group. + inputs: NodeSet + + # Nodes that in the fusion group that haven't been processed yet. + nodes_need_process: NodeSet + + def add_node(self, node): + """ + Add a node to fusion group. + """ + if node in self.nodes: + return + + self.nodes_need_process.add(node) + self.nodes.add(node) + self.inputs.discard(node) + self.inputs.update( + { + n + for n in node.all_input_nodes + if n.op in CALLABLE_NODE_OPS and n not in self.nodes + } + ) + + def recursive_add_node( + self, + fusion_group: "FxNetAccFusionsFinder.FusionGroup", + inputs: Union[NodeSet, NodeList], + visited: Optional[NodeSet] = None, + ): + """ + Start from inputs and going reverse topological order. If any upstream node + is in the fusion group, add all the nodes in this path to fusion group. + """ + for arg in inputs: + # skip the node if already seen + if visited is not None: + if arg in visited: + continue + visited.add(arg) + + # Skip placeholder and get_attr because they won't be in the fusion group. + if arg.op not in CALLABLE_NODE_OPS: + continue + + # If the node has smaller idx, it's already an upstream node of the fusion + # group. We don't need to check it anymore. + if self.nodes.index(arg) < fusion_group.top_node_idx: + continue + + # If the node is in the fusion group, return True. + if arg in fusion_group.nodes: + return True + + # Check the upstream nodes of the node, if any of them is in the fusion group + # we'll add this node to fusion group and return True. + if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited): + fusion_group.add_node(arg) + return True + + return False + + def __call__(self) -> Dict[torch.fx.Node, NodeSet]: + result: Dict[torch.fx.Node, NodeSet] = {} + acc_nodes = list(self.acc_nodes) + + for node in acc_nodes: + if node in result: + continue + if node.op not in CALLABLE_NODE_OPS: + continue + if "tensor_meta" in node.meta: + continue + if node not in self.acc_nodes: + continue + + fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup( + top_node_idx=self.nodes.index(node), + nodes={node}, + inputs=set(node.all_input_nodes), + nodes_need_process={node}, + ) + while fusion_group.nodes_need_process: + node = fusion_group.nodes_need_process.pop() + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + # Optionally add downstream nodes + if "tensor_meta" not in node.meta: + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + if user in fusion_group.nodes: + continue + + fusion_group.add_node(user) + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + # Add some upstream nodes + for arg in node.all_input_nodes: + if arg.op not in CALLABLE_NODE_OPS: + continue + if "tensor_meta" in arg.meta: + continue + if arg in fusion_group.nodes: + continue + + fusion_group.add_node(arg) + fusion_group.top_node_idx = min( + fusion_group.top_node_idx, self.nodes.index(arg) + ) + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + if not (set(fusion_group.nodes) <= self.acc_nodes): + self.acc_nodes -= fusion_group.nodes + else: + for n in fusion_group.nodes: + result[n] = fusion_group.nodes + + return result + + +@compatibility(is_backward_compatible=False) +def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Replace the graph of the given GraphModule with one that contains the same nodes as the + original, but in topologically sorted order. + + This is used by the merge_matmul transformation below, which disturbs the topologically sorted + order of its input GraphModule, so that this order is restored before further transformation. + + Arguments: + gm: The graph module to topologically sort. It is modified in-place. + + Returns: + The graph module in-place sorted + """ + + # These operators are used for making runtime assertions before any + # data-dependent operators occur. We want to prioritize sorting these to + # ensure that these assertions appear before any data-dependent operations + # in the graph. + PRIORITIZED_OPS = [ + operator.add, + operator.mul, + operator.sub, + operator.floordiv, + operator.truediv, + operator.mod, + operator.le, + operator.lt, + operator.ge, + operator.gt, + operator.eq, + operator.ne, + torch.ops.aten.sym_constrain_range.default, + torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten._assert_async.msg, + torch.ops.aten.scalar_tensor.default, + torch.ops.aten._assert_scalar.default, + ] + + indeg = dict.fromkeys(gm.graph.nodes, 0) + new_graph = torch.fx.Graph() + # Track how many unfulfilled dependencies each node has + for node in gm.graph.nodes: + for user in node.users: + indeg[user] += 1 + queue: collections.deque = collections.deque() + # Add all nodes with no dependencies to the queue + for node in gm.graph.nodes: + if indeg[node] == 0: + queue.append(node) + env: Dict[torch.fx.Node, torch.fx.Node] = {} + # Pop nodes from the queue, and add nodes that have had all their + # dependencies fulfilled + while len(queue) > 0: + cur = queue.popleft() + env[cur] = new_graph.node_copy(cur, lambda x: env[x]) + for user in cur.users: + indeg[user] -= 1 + if indeg[user] == 0: + if user.op == "call_function" and user.target in PRIORITIZED_OPS: + queue.appendleft(user) + else: + queue.append(user) + # If the new graph's size is not as large as the old one, then there must be + # a cycle (i.e. some node's dependencies were not satisfied.) + if len(new_graph.nodes) < len(gm.graph.nodes): + raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}") + new_graph._codegen = gm.graph._codegen + gm.graph = new_graph + return gm