Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py +558 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +1281 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py +1040 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py +14 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py +349 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py +53 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py +29 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py +4 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/core.py +119 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/dispatch.py +6 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/match.py +122 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/more.py +118 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py +3 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py +121 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/core.py +84 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py +427 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py +126 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py +92 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/unification_tools.py +396 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/utils.py +106 -0
- .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/variable.py +86 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py +12 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py +44 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py +0 -0
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (194 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (216 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc
ADDED
|
Binary file (28.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc
ADDED
|
Binary file (72.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc
ADDED
|
Binary file (52.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc
ADDED
|
Binary file (521 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc
ADDED
|
Binary file (16.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc
ADDED
|
Binary file (2.43 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-311.pyc
ADDED
|
Binary file (1.53 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py
ADDED
|
@@ -0,0 +1,558 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \
|
| 3 |
+
op_mod, op_gt, op_lt, op_neq, op_eq
|
| 4 |
+
from torch.fx.tensor_type import TensorType, Dyn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Constraint:
|
| 8 |
+
pass
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Conj(Constraint):
|
| 12 |
+
def __init__(self, conjuncts):
|
| 13 |
+
"""
|
| 14 |
+
:param conjuncts: Conjunction of constraints
|
| 15 |
+
"""
|
| 16 |
+
self.conjucts = conjuncts
|
| 17 |
+
|
| 18 |
+
def __eq__(self, other):
|
| 19 |
+
if isinstance(other, Conj):
|
| 20 |
+
return self.conjucts == other.conjucts and self.conjucts == other.conjucts
|
| 21 |
+
else:
|
| 22 |
+
return False
|
| 23 |
+
|
| 24 |
+
def __repr__(self):
|
| 25 |
+
return f'And({self.conjucts})'
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Disj(Constraint):
|
| 29 |
+
def __init__(self, disjuncts):
|
| 30 |
+
"""
|
| 31 |
+
:param disjuncts: Disjunction of constraints
|
| 32 |
+
"""
|
| 33 |
+
self.disjuncts = disjuncts
|
| 34 |
+
|
| 35 |
+
def __eq__(self, other):
|
| 36 |
+
if isinstance(other, Disj):
|
| 37 |
+
return self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts
|
| 38 |
+
else:
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
def __repr__(self):
|
| 42 |
+
return f'Or({self.disjuncts})'
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Prod(Constraint):
|
| 46 |
+
def __init__(self, products):
|
| 47 |
+
"""
|
| 48 |
+
:param products: lists of dimensions to multiply
|
| 49 |
+
"""
|
| 50 |
+
self.products = products
|
| 51 |
+
|
| 52 |
+
def __eq__(self, other):
|
| 53 |
+
if isinstance(other, Prod):
|
| 54 |
+
return self.products == other.products and self.products == other.products
|
| 55 |
+
else:
|
| 56 |
+
return False
|
| 57 |
+
|
| 58 |
+
def __repr__(self):
|
| 59 |
+
return f'Product({self.products})'
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class T(Constraint):
|
| 63 |
+
"""
|
| 64 |
+
True
|
| 65 |
+
"""
|
| 66 |
+
def __init__(self) -> None:
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
def __eq__(self, other):
|
| 70 |
+
return isinstance(other, T)
|
| 71 |
+
|
| 72 |
+
def __repr__(self):
|
| 73 |
+
return 'True'
|
| 74 |
+
|
| 75 |
+
class F(Constraint):
|
| 76 |
+
"""
|
| 77 |
+
False
|
| 78 |
+
"""
|
| 79 |
+
def __init__(self) -> None:
|
| 80 |
+
pass
|
| 81 |
+
|
| 82 |
+
def __eq__(self, other):
|
| 83 |
+
return isinstance(other, F)
|
| 84 |
+
|
| 85 |
+
def __repr__(self):
|
| 86 |
+
return 'False'
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class BinaryConstraint(Constraint):
|
| 90 |
+
"""
|
| 91 |
+
Represents all binary operations
|
| 92 |
+
"""
|
| 93 |
+
def __init__(self, lhs, rhs, op):
|
| 94 |
+
"""
|
| 95 |
+
:param lhs: lhs of the constraint
|
| 96 |
+
:param rhs: rhs of the constraint
|
| 97 |
+
:param op: string representing the operation
|
| 98 |
+
"""
|
| 99 |
+
self.lhs = lhs
|
| 100 |
+
self.rhs = rhs
|
| 101 |
+
self.op = op
|
| 102 |
+
|
| 103 |
+
def __eq__(self, other):
|
| 104 |
+
if isinstance(other, BinaryConstraint):
|
| 105 |
+
return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op
|
| 106 |
+
else:
|
| 107 |
+
return False
|
| 108 |
+
|
| 109 |
+
def __repr__(self):
|
| 110 |
+
return f'({self.lhs} {self.op} {self.rhs})'
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class BinConstraintT(BinaryConstraint):
|
| 114 |
+
"""
|
| 115 |
+
Binary constraints about tensors
|
| 116 |
+
"""
|
| 117 |
+
def __init__(self, lhs, rhs, op):
|
| 118 |
+
assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and \
|
| 119 |
+
(isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn)
|
| 120 |
+
super().__init__(lhs, rhs, op)
|
| 121 |
+
|
| 122 |
+
def __eq__(self, other):
|
| 123 |
+
return super().__eq__(other)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class BinConstraintD(BinaryConstraint):
|
| 127 |
+
"""
|
| 128 |
+
Binary constraints about dimensions
|
| 129 |
+
"""
|
| 130 |
+
def __init__(self, lhs, rhs, op):
|
| 131 |
+
assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs)
|
| 132 |
+
assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs)
|
| 133 |
+
|
| 134 |
+
super().__init__(lhs, rhs, op)
|
| 135 |
+
|
| 136 |
+
def __eq__(self, other):
|
| 137 |
+
return super().__eq__(other)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class TGreatestUpperBound(Constraint):
|
| 142 |
+
"""
|
| 143 |
+
Greatest Upper bound for tensors with dynamic type
|
| 144 |
+
"""
|
| 145 |
+
def __init__(self, res, rhs1, rhs2):
|
| 146 |
+
"""
|
| 147 |
+
:param res: tensor variable that stores the result of the outout
|
| 148 |
+
:param rhs1: tensor or tensor variable
|
| 149 |
+
:param rhs2: tensor or tensor variabke
|
| 150 |
+
"""
|
| 151 |
+
self.res = res
|
| 152 |
+
self.rhs1 = rhs1
|
| 153 |
+
self.rhs2 = rhs2
|
| 154 |
+
|
| 155 |
+
def __repr__(self):
|
| 156 |
+
return f'{self.res} = {self.rhs1}\u2294*{self.rhs2}'
|
| 157 |
+
|
| 158 |
+
def __eq__(self, other):
|
| 159 |
+
if isinstance(other, TGreatestUpperBound):
|
| 160 |
+
return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
|
| 161 |
+
else:
|
| 162 |
+
return False
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class DGreatestUpperBound(Constraint):
|
| 166 |
+
"""
|
| 167 |
+
Greatest Upper bound for dimensions
|
| 168 |
+
"""
|
| 169 |
+
def __init__(self, res, rhs1, rhs2):
|
| 170 |
+
"""
|
| 171 |
+
:param res: Dimension variable to store the result
|
| 172 |
+
:param rhs1: dimension variable 1
|
| 173 |
+
:param rhs2: dimension variable 2
|
| 174 |
+
"""
|
| 175 |
+
assert is_dim(res)
|
| 176 |
+
assert is_dim(rhs1)
|
| 177 |
+
assert is_dim(rhs2)
|
| 178 |
+
|
| 179 |
+
self.res = res
|
| 180 |
+
self.rhs1 = rhs1
|
| 181 |
+
self.rhs2 = rhs2
|
| 182 |
+
|
| 183 |
+
def __repr__(self):
|
| 184 |
+
return f'{self.res} = {self.rhs1}\u2294{self.rhs2}'
|
| 185 |
+
|
| 186 |
+
def __eq__(self, other):
|
| 187 |
+
if isinstance(other, DGreatestUpperBound):
|
| 188 |
+
return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
|
| 189 |
+
else:
|
| 190 |
+
return False
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class CanReshape(Constraint):
|
| 194 |
+
"""
|
| 195 |
+
can_reshape constraint
|
| 196 |
+
"""
|
| 197 |
+
def __init__(self, src, target):
|
| 198 |
+
"""
|
| 199 |
+
:param src: tensor variable
|
| 200 |
+
:param target: tensor
|
| 201 |
+
"""
|
| 202 |
+
self.src = src
|
| 203 |
+
self.target = target
|
| 204 |
+
|
| 205 |
+
def __repr__(self):
|
| 206 |
+
return f'can-reshape({self.src}, {self.target})'
|
| 207 |
+
|
| 208 |
+
def __eq__(self, other):
|
| 209 |
+
if isinstance(other, CanReshape):
|
| 210 |
+
return self.src == other.src and self.target == other.target
|
| 211 |
+
else:
|
| 212 |
+
return False
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class IndexSelect(Constraint):
|
| 216 |
+
|
| 217 |
+
def __init__(self, tensor_size, input_var, dim_replace, index, output):
|
| 218 |
+
"""
|
| 219 |
+
Args:
|
| 220 |
+
input_var: input to index_select
|
| 221 |
+
tensor_size: tensor size we are considering
|
| 222 |
+
dim_replace: the dimension of the output at "index"
|
| 223 |
+
index: location of the dimensions to replace in the input
|
| 224 |
+
output: variable to store the result
|
| 225 |
+
"""
|
| 226 |
+
assert isinstance(input_var, TVar)
|
| 227 |
+
assert isinstance(output, TVar)
|
| 228 |
+
assert isinstance(dim_replace, DVar) or dim_replace == Dyn
|
| 229 |
+
assert isinstance(index, int)
|
| 230 |
+
|
| 231 |
+
self.input_var = input_var
|
| 232 |
+
self.tensor_size = tensor_size
|
| 233 |
+
self.dim_replace = dim_replace
|
| 234 |
+
self.index = index
|
| 235 |
+
self.output = output
|
| 236 |
+
|
| 237 |
+
def __repr__(self):
|
| 238 |
+
|
| 239 |
+
return f' {self.output} = ' \
|
| 240 |
+
f'IndexSelect({self.input_var}, ' \
|
| 241 |
+
f'tensor_size: {self.tensor_size}, ' \
|
| 242 |
+
f'{self.dim_replace}, ' \
|
| 243 |
+
f'{self.index})'
|
| 244 |
+
|
| 245 |
+
def __eq__(self, other):
|
| 246 |
+
if isinstance(other, IndexSelect):
|
| 247 |
+
return self.tensor_size == other.tensor_size and \
|
| 248 |
+
self.dim_replace == other.dim_replace and \
|
| 249 |
+
self.index == other.index and \
|
| 250 |
+
self.output == other.output and \
|
| 251 |
+
self.input_var == other.input_var
|
| 252 |
+
else:
|
| 253 |
+
return False
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class Transpose(Constraint):
|
| 257 |
+
|
| 258 |
+
def __init__(self, tensor_size, input_var, index1, index2, output):
|
| 259 |
+
"""
|
| 260 |
+
Args:
|
| 261 |
+
tensor_size: current tensor size
|
| 262 |
+
input_var: variable to hold input
|
| 263 |
+
index1: dimension 1
|
| 264 |
+
index2: dimension 2
|
| 265 |
+
output: output that stores result
|
| 266 |
+
"""
|
| 267 |
+
assert isinstance(input_var, TVar)
|
| 268 |
+
assert isinstance(output, TVar)
|
| 269 |
+
assert isinstance(index1, int)
|
| 270 |
+
assert isinstance(index2, int)
|
| 271 |
+
|
| 272 |
+
self.input_var = input_var
|
| 273 |
+
self.tensor_size = tensor_size
|
| 274 |
+
self.index1 = index1
|
| 275 |
+
self.index2 = index2
|
| 276 |
+
self.output = output
|
| 277 |
+
|
| 278 |
+
def __repr__(self):
|
| 279 |
+
|
| 280 |
+
return f' {self.output} = ' \
|
| 281 |
+
f'Transpose({self.input_var}, ' \
|
| 282 |
+
f'tensor_size: {self.tensor_size}, ' \
|
| 283 |
+
f'{self.index1}, ' \
|
| 284 |
+
f'{self.index2})'
|
| 285 |
+
|
| 286 |
+
def __eq__(self, other):
|
| 287 |
+
if isinstance(other, Transpose):
|
| 288 |
+
return self.tensor_size == other.tensor_size and \
|
| 289 |
+
self.index1 == other.index1 and \
|
| 290 |
+
self.index2 == other.index2 and \
|
| 291 |
+
self.output == other.output and \
|
| 292 |
+
self.input_var == other.input_var
|
| 293 |
+
else:
|
| 294 |
+
return False
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class GetItem(Constraint):
|
| 298 |
+
|
| 299 |
+
def __init__(self, tensor_size, index, res, input_var):
|
| 300 |
+
"""
|
| 301 |
+
Constraint for getting item given a tensor size
|
| 302 |
+
:param tensor_size: actual number
|
| 303 |
+
:param index: actual number representing the index
|
| 304 |
+
:param res: dimension variable to carry the item we get
|
| 305 |
+
:param input_var: a tensor variable from which we will get item
|
| 306 |
+
"""
|
| 307 |
+
assert isinstance(res, DVar)
|
| 308 |
+
|
| 309 |
+
self.res = res
|
| 310 |
+
self.tensor_size = tensor_size
|
| 311 |
+
self.index = index
|
| 312 |
+
self.input_var = input_var
|
| 313 |
+
|
| 314 |
+
def __repr__(self):
|
| 315 |
+
return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})'
|
| 316 |
+
|
| 317 |
+
def __eq__(self, other):
|
| 318 |
+
if isinstance(other, GetItem):
|
| 319 |
+
return self.res == other.res and \
|
| 320 |
+
self.tensor_size == other.tensor_size and \
|
| 321 |
+
self.index == other.index and \
|
| 322 |
+
self.input_var == other.input_var
|
| 323 |
+
else:
|
| 324 |
+
return False
|
| 325 |
+
|
| 326 |
+
class GetItemTensor(Constraint):
|
| 327 |
+
|
| 328 |
+
def __init__(self, tensor_size, index_tuple, res, input_var):
|
| 329 |
+
"""
|
| 330 |
+
Constraint for getting item given a tensor size
|
| 331 |
+
However, when the argument is a tuple, we will
|
| 332 |
+
expect a tensor
|
| 333 |
+
:param tensor_size: actual number representing the rank
|
| 334 |
+
:param index_tuple: tuple for indexing
|
| 335 |
+
:param res: tensor variable to carry the item we get
|
| 336 |
+
:param input_var: a tensor variable from which we will get item
|
| 337 |
+
"""
|
| 338 |
+
assert isinstance(res, TVar)
|
| 339 |
+
|
| 340 |
+
self.res = res
|
| 341 |
+
self.tensor_size = tensor_size
|
| 342 |
+
self.index_tuple = index_tuple
|
| 343 |
+
self.input_var = input_var
|
| 344 |
+
|
| 345 |
+
def __repr__(self):
|
| 346 |
+
return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})'
|
| 347 |
+
|
| 348 |
+
def __eq__(self, other):
|
| 349 |
+
if isinstance(other, GetItemTensor):
|
| 350 |
+
return self.res == other.res and \
|
| 351 |
+
self.tensor_size == other.tensor_size and \
|
| 352 |
+
self.index_tuple == other.index_tuple and \
|
| 353 |
+
self.input_var == other.input_var
|
| 354 |
+
else:
|
| 355 |
+
return False
|
| 356 |
+
|
| 357 |
+
class CalcConv(Constraint):
|
| 358 |
+
|
| 359 |
+
def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars):
|
| 360 |
+
"""
|
| 361 |
+
:param conv_result: the convolution result
|
| 362 |
+
:param input_var: input to convolution
|
| 363 |
+
:param c_out: output chanel type
|
| 364 |
+
:param kernel: kernel tuple
|
| 365 |
+
"""
|
| 366 |
+
self.conv_result = conv_result
|
| 367 |
+
self.input_var = input_var
|
| 368 |
+
self.c_out = c_out
|
| 369 |
+
self.kernel = kernel
|
| 370 |
+
self.padding = padding
|
| 371 |
+
self.stride = stride
|
| 372 |
+
self.dilation = dilation
|
| 373 |
+
self.matching_constraint = matching_constraint_vars
|
| 374 |
+
|
| 375 |
+
def __repr__(self):
|
| 376 |
+
return f'{self.conv_result} =' \
|
| 377 |
+
f' calc-conv({self.input_var},' \
|
| 378 |
+
f' {self.c_out}, {self.kernel}, ' \
|
| 379 |
+
f'{self.padding}, {self.stride},' \
|
| 380 |
+
f' {self.dilation})'
|
| 381 |
+
|
| 382 |
+
def __eq__(self, other):
|
| 383 |
+
if isinstance(other, CalcConv):
|
| 384 |
+
return self.conv_result == other.conv_result and self.input_var == other.input_var and \
|
| 385 |
+
self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \
|
| 386 |
+
and self.stride == other.stride and self.dilation == other.dilation \
|
| 387 |
+
and self.matching_constraint == other.matching_constraint
|
| 388 |
+
else:
|
| 389 |
+
return False
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class CalcMaxPool(Constraint):
|
| 393 |
+
|
| 394 |
+
def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars):
|
| 395 |
+
"""
|
| 396 |
+
:param maxpool_result: the result of maxpool
|
| 397 |
+
:param input_var: input to convolution
|
| 398 |
+
:param kernel: kernel tuple
|
| 399 |
+
"""
|
| 400 |
+
self.maxpool_result = maxpool_result
|
| 401 |
+
self.input_var = input_var
|
| 402 |
+
self.kernel = kernel
|
| 403 |
+
self.padding = padding
|
| 404 |
+
self.stride = stride
|
| 405 |
+
self.dilation = dilation
|
| 406 |
+
self.matching_constraint = matching_constraint_vars
|
| 407 |
+
|
| 408 |
+
def __repr__(self):
|
| 409 |
+
return f'{self.maxpool_result} =' \
|
| 410 |
+
f' calc-maxpool({self.input_var},' \
|
| 411 |
+
f' {self.kernel}, ' \
|
| 412 |
+
f'{self.padding}, {self.stride},' \
|
| 413 |
+
f' {self.dilation})'
|
| 414 |
+
|
| 415 |
+
def __eq__(self, other):
|
| 416 |
+
if isinstance(other, CalcMaxPool):
|
| 417 |
+
return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \
|
| 418 |
+
and self.kernel == other.kernel and self.padding == other.padding \
|
| 419 |
+
and self.stride == other.stride and self.dilation == other.dilation \
|
| 420 |
+
and self.matching_constraint == other.matching_constraint
|
| 421 |
+
else:
|
| 422 |
+
return False
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class ApplyBroadcasting(Constraint):
|
| 426 |
+
def __init__(self, res1, res2, input1, input2):
|
| 427 |
+
"""
|
| 428 |
+
:param res1: resulting tensor 1
|
| 429 |
+
:param res2: resulting tensor 2
|
| 430 |
+
:param input1: tensor variable 1
|
| 431 |
+
:param input2: tensor variable 2
|
| 432 |
+
"""
|
| 433 |
+
self.res1 = res1
|
| 434 |
+
self.res2 = res2
|
| 435 |
+
self.input1 = input1
|
| 436 |
+
self.input2 = input2
|
| 437 |
+
|
| 438 |
+
def __eq__(self, other):
|
| 439 |
+
if isinstance(other, ApplyBroadcasting):
|
| 440 |
+
return self.res1 == other.res1 \
|
| 441 |
+
and self.res2 == other.res2 \
|
| 442 |
+
and self.input1 == other.input1 \
|
| 443 |
+
and self.input2 == other.input2
|
| 444 |
+
else:
|
| 445 |
+
return False
|
| 446 |
+
|
| 447 |
+
def __repr__(self):
|
| 448 |
+
return f'{self.res1}, {self.res2} ='f' apply-broadcasting({self.input1},' f' {self.input2})'
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class CalcProduct(Constraint):
|
| 452 |
+
"""
|
| 453 |
+
Given correct dimensions, calculate the product for flatten accounting for Dyn
|
| 454 |
+
"""
|
| 455 |
+
def __init__(self, start, end, flattened, dims_to_flatten):
|
| 456 |
+
"""
|
| 457 |
+
:param start: start index
|
| 458 |
+
:param end: end index
|
| 459 |
+
:param flattened: variable to store the product
|
| 460 |
+
:param dims_to_flatten: the type which we will flatten
|
| 461 |
+
"""
|
| 462 |
+
assert isinstance(dims_to_flatten, list)
|
| 463 |
+
assert isinstance(flattened, TVar)
|
| 464 |
+
assert isinstance(start, int)
|
| 465 |
+
assert isinstance(end, int)
|
| 466 |
+
|
| 467 |
+
self.start = start
|
| 468 |
+
self.end = end
|
| 469 |
+
self.dims_to_flatten = dims_to_flatten
|
| 470 |
+
self.flattened = flattened
|
| 471 |
+
|
| 472 |
+
def __eq__(self, other):
|
| 473 |
+
if isinstance(other, CalcProduct):
|
| 474 |
+
return self.start == other.start and self.end == other.end and \
|
| 475 |
+
self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened
|
| 476 |
+
|
| 477 |
+
else:
|
| 478 |
+
return False
|
| 479 |
+
|
| 480 |
+
def __repr__(self):
|
| 481 |
+
return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})'
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
class TVar:
|
| 485 |
+
"""
|
| 486 |
+
Tensor variable with no tensor constructor
|
| 487 |
+
"""
|
| 488 |
+
def __init__(self, tvar):
|
| 489 |
+
"""
|
| 490 |
+
:param tvar: tensor variable
|
| 491 |
+
"""
|
| 492 |
+
self.tvar = tvar
|
| 493 |
+
|
| 494 |
+
def __repr__(self):
|
| 495 |
+
return f'TV({self.tvar})'
|
| 496 |
+
|
| 497 |
+
def __eq__(self, other):
|
| 498 |
+
if isinstance(other, TVar):
|
| 499 |
+
return self.tvar == other.tvar
|
| 500 |
+
else:
|
| 501 |
+
return False
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
class DVar:
|
| 505 |
+
"""
|
| 506 |
+
Dimension variable
|
| 507 |
+
"""
|
| 508 |
+
def __init__(self, c):
|
| 509 |
+
"""
|
| 510 |
+
:param c: character or number
|
| 511 |
+
"""
|
| 512 |
+
self.c = c
|
| 513 |
+
|
| 514 |
+
def __repr__(self):
|
| 515 |
+
return f'DV({self.c})'
|
| 516 |
+
|
| 517 |
+
def __eq__(self, other):
|
| 518 |
+
if isinstance(other, DVar):
|
| 519 |
+
return self.c == other.c
|
| 520 |
+
else:
|
| 521 |
+
return False
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
class BVar:
|
| 525 |
+
"""
|
| 526 |
+
Boolean variable
|
| 527 |
+
"""
|
| 528 |
+
def __init__(self, c):
|
| 529 |
+
"""
|
| 530 |
+
:param c: character or number
|
| 531 |
+
"""
|
| 532 |
+
self.c = c
|
| 533 |
+
|
| 534 |
+
def __repr__(self):
|
| 535 |
+
return f'BV({self.c})'
|
| 536 |
+
|
| 537 |
+
def __eq__(self, other):
|
| 538 |
+
if isinstance(other, BVar):
|
| 539 |
+
return self.c == other.c
|
| 540 |
+
else:
|
| 541 |
+
return False
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def is_algebraic_expression(constraint):
|
| 545 |
+
if isinstance(constraint, BinConstraintD):
|
| 546 |
+
return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod]
|
| 547 |
+
else:
|
| 548 |
+
return isinstance(constraint, Prod)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def is_bool_expr(constraint):
|
| 552 |
+
if isinstance(constraint, BinConstraintD):
|
| 553 |
+
return constraint.op in [op_gt, op_lt, op_neq, op_eq]
|
| 554 |
+
else:
|
| 555 |
+
return isinstance(constraint, (BVar, Conj, Disj))
|
| 556 |
+
|
| 557 |
+
def is_dim(d):
|
| 558 |
+
return isinstance(d, (DVar, int)) or d == Dyn
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
ADDED
|
@@ -0,0 +1,1281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-decorators
|
| 2 |
+
# mypy: allow-untyped-defs
|
| 3 |
+
import torch
|
| 4 |
+
import operator
|
| 5 |
+
import warnings
|
| 6 |
+
from typing import Callable, Dict, Iterable
|
| 7 |
+
|
| 8 |
+
from torch.fx._symbolic_trace import _assert_is_none
|
| 9 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \
|
| 10 |
+
Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \
|
| 11 |
+
TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound
|
| 12 |
+
from torch.fx.experimental.migrate_gradual_types.operation import \
|
| 13 |
+
op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul
|
| 14 |
+
from torch.fx.node import Target, Node
|
| 15 |
+
from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \
|
| 16 |
+
gen_bvar
|
| 17 |
+
|
| 18 |
+
from torch.fx.tensor_type import Dyn, TensorType
|
| 19 |
+
from torch.nn.modules.conv import Conv2d
|
| 20 |
+
from torch.nn.modules.batchnorm import BatchNorm2d
|
| 21 |
+
|
| 22 |
+
_INFERENCE_RULES: Dict[Target, Callable] = {}
|
| 23 |
+
|
| 24 |
+
MAX_TENSOR_RANK = 4
|
| 25 |
+
|
| 26 |
+
def register_inference_rule(call_target):
|
| 27 |
+
def register(fn):
|
| 28 |
+
if call_target in _INFERENCE_RULES:
|
| 29 |
+
raise RuntimeError(f'Inference rule already registered for {call_target}!')
|
| 30 |
+
_INFERENCE_RULES[call_target] = fn
|
| 31 |
+
return fn
|
| 32 |
+
return register
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter):
|
| 36 |
+
d, counter = gen_tensor_dims(n, counter)
|
| 37 |
+
c1 = BinConstraintT(input, TensorType(d), op_eq)
|
| 38 |
+
start_dim = n if start_dim == -1 else abs(start_dim)
|
| 39 |
+
end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1
|
| 40 |
+
c2 = CalcProduct(start_dim, end_dim, flattened, d)
|
| 41 |
+
nat_constraints = gen_nat_constraints(d)
|
| 42 |
+
return Conj([c1, c2, *nat_constraints]), counter
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@register_inference_rule(getattr)
|
| 46 |
+
def get_attr_inference_rule(n: Node, symbols, constraints, counter):
|
| 47 |
+
"""
|
| 48 |
+
If the attribute is "device" then the tensor shape is preserved
|
| 49 |
+
"""
|
| 50 |
+
assert isinstance(n.args[0], Node)
|
| 51 |
+
assert isinstance(n.args[1], str)
|
| 52 |
+
output, counter = gen_tvar(counter)
|
| 53 |
+
symbols[n] = output
|
| 54 |
+
|
| 55 |
+
input = symbols[n.args[0]]
|
| 56 |
+
attr = n.args[1]
|
| 57 |
+
|
| 58 |
+
if attr == 'device':
|
| 59 |
+
return [BinConstraintT(input, output, op_eq)], counter
|
| 60 |
+
else:
|
| 61 |
+
raise NotImplementedError('Not yet implemented')
|
| 62 |
+
|
| 63 |
+
@register_inference_rule(torch.bmm)
|
| 64 |
+
def bmm_inference_rule(n: Node, symbols, constraints, counter):
|
| 65 |
+
"""
|
| 66 |
+
Constraints that match the input to a size 3 tensor
|
| 67 |
+
and switch the dimensions according to the rules
|
| 68 |
+
of batch multiplication
|
| 69 |
+
"""
|
| 70 |
+
assert isinstance(n.args[0], Node)
|
| 71 |
+
assert isinstance(n.args[1], Node)
|
| 72 |
+
|
| 73 |
+
bmm_output, counter = gen_tvar(counter)
|
| 74 |
+
symbols[n] = bmm_output
|
| 75 |
+
|
| 76 |
+
bmm_input1 = symbols[n.args[0]]
|
| 77 |
+
bmm_input2 = symbols[n.args[1]]
|
| 78 |
+
|
| 79 |
+
dims_input1, counter = gen_tensor_dims(3, counter)
|
| 80 |
+
dims_input2, counter = gen_tensor_dims(3, counter)
|
| 81 |
+
|
| 82 |
+
inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
|
| 83 |
+
BinConstraintT(bmm_input2, Dyn, op_eq),
|
| 84 |
+
BinConstraintT(bmm_output, Dyn, op_eq)])
|
| 85 |
+
|
| 86 |
+
input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
|
| 87 |
+
BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
|
| 88 |
+
BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)])
|
| 89 |
+
|
| 90 |
+
input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq),
|
| 91 |
+
BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
|
| 92 |
+
BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)])
|
| 93 |
+
|
| 94 |
+
consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)]
|
| 95 |
+
|
| 96 |
+
batch_size, counter = gen_dvar(counter)
|
| 97 |
+
|
| 98 |
+
inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
|
| 99 |
+
BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
|
| 100 |
+
BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq),
|
| 101 |
+
*consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])])
|
| 102 |
+
|
| 103 |
+
return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@register_inference_rule("index_select")
|
| 107 |
+
def index_select_inference_rule(n: Node, symbols, constraints, counter):
|
| 108 |
+
"""
|
| 109 |
+
We constrain the second argument to a vector or Dyn.
|
| 110 |
+
The output replaces the input with the shape of the vector
|
| 111 |
+
at the position given by the index (first argument)
|
| 112 |
+
"""
|
| 113 |
+
# print(n.args)
|
| 114 |
+
assert isinstance(n.args[0], Node)
|
| 115 |
+
assert isinstance(n.args[1], int)
|
| 116 |
+
assert isinstance(n.args[2], Node)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
index_select, counter = gen_tvar(counter)
|
| 121 |
+
symbols[n] = index_select
|
| 122 |
+
|
| 123 |
+
dims, counter = gen_tensor_dims(1, counter)
|
| 124 |
+
|
| 125 |
+
# equality constraint
|
| 126 |
+
is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq)
|
| 127 |
+
is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq)
|
| 128 |
+
|
| 129 |
+
c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select)
|
| 130 |
+
for i in range(MAX_TENSOR_RANK)])])
|
| 131 |
+
c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
|
| 132 |
+
for i in range(MAX_TENSOR_RANK)])])
|
| 133 |
+
|
| 134 |
+
return [Disj([c2, c3])], counter
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@register_inference_rule("expand")
|
| 138 |
+
def expand_inference_rule(n: Node, symbols, constraints, counter):
|
| 139 |
+
"""
|
| 140 |
+
We generate the exact constraints as we do for tensor additions but we constraint
|
| 141 |
+
the rank of this expression to be equal to len(n.args[1:]) so that only
|
| 142 |
+
those cases get considered for the output
|
| 143 |
+
"""
|
| 144 |
+
assert isinstance(n.args[0], Node)
|
| 145 |
+
|
| 146 |
+
# define the output for expand
|
| 147 |
+
expand, counter = gen_tvar(counter)
|
| 148 |
+
symbols[n] = expand
|
| 149 |
+
|
| 150 |
+
# since we do not have two nodes here, we will construct an argument variable
|
| 151 |
+
e1 = symbols[n.args[0]]
|
| 152 |
+
e2, counter = gen_tvar(counter)
|
| 153 |
+
|
| 154 |
+
e2_nat_constraints = []
|
| 155 |
+
for arg in n.args[1:]:
|
| 156 |
+
assert isinstance(arg, (Node, int))
|
| 157 |
+
if isinstance(arg, Node):
|
| 158 |
+
assert isinstance(symbols[arg], DVar)
|
| 159 |
+
e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq))
|
| 160 |
+
|
| 161 |
+
e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq)
|
| 162 |
+
|
| 163 |
+
constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand)
|
| 164 |
+
|
| 165 |
+
# constraint the output size
|
| 166 |
+
dims, counter = gen_tensor_dims(len(n.args[1:]), counter)
|
| 167 |
+
nat_constraints = gen_nat_constraints(dims)
|
| 168 |
+
c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints]
|
| 169 |
+
constraints += c
|
| 170 |
+
|
| 171 |
+
return constraints, counter
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
@register_inference_rule(torch.nn.functional.gelu)
|
| 175 |
+
@register_inference_rule(torch.nn.functional.dropout)
|
| 176 |
+
@register_inference_rule(torch.nn.functional.softmax)
|
| 177 |
+
@register_inference_rule("detach")
|
| 178 |
+
@register_inference_rule("to")
|
| 179 |
+
@register_inference_rule("int")
|
| 180 |
+
@register_inference_rule("long")
|
| 181 |
+
@register_inference_rule("contiguous")
|
| 182 |
+
@register_inference_rule(torch.ones)
|
| 183 |
+
@register_inference_rule(torch.zeros)
|
| 184 |
+
def equality_inference_rule(n: Node, symbols, constraints, counter):
|
| 185 |
+
"""
|
| 186 |
+
We generate the constraint: input = output
|
| 187 |
+
"""
|
| 188 |
+
output, counter = gen_tvar(counter)
|
| 189 |
+
symbols[n] = output
|
| 190 |
+
|
| 191 |
+
if isinstance(n.args[0], Node):
|
| 192 |
+
input = symbols[n.args[0]]
|
| 193 |
+
if isinstance(input, TVar):
|
| 194 |
+
return [BinConstraintT(input, output, op_eq)], counter
|
| 195 |
+
|
| 196 |
+
# then we have dimension variables
|
| 197 |
+
else:
|
| 198 |
+
for arg in n.args:
|
| 199 |
+
assert isinstance(symbols[arg], DVar)
|
| 200 |
+
my_size = [symbols[arg] for arg in n.args]
|
| 201 |
+
return [BinConstraintT(output, TensorType(my_size), op_eq)], counter
|
| 202 |
+
|
| 203 |
+
elif isinstance(n.args[0], tuple):
|
| 204 |
+
# then the tuple is the size
|
| 205 |
+
assert len(n.args[0]) <= 4
|
| 206 |
+
my_size = [symbols[arg] for arg in n.args[0]]
|
| 207 |
+
return [BinConstraintT(output, TensorType(my_size), op_eq)], counter
|
| 208 |
+
else:
|
| 209 |
+
raise NotImplementedError('Method not yet implemented')
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@register_inference_rule("transpose")
|
| 213 |
+
def transpose_inference_rule(n: Node, symbols, constraints, counter):
|
| 214 |
+
"""
|
| 215 |
+
Can be considered as a sequence of two index selects, so we generate constraints accordingly
|
| 216 |
+
"""
|
| 217 |
+
assert isinstance(n.args[0], Node)
|
| 218 |
+
assert isinstance(n.args[1], int)
|
| 219 |
+
assert isinstance(n.args[2], int)
|
| 220 |
+
|
| 221 |
+
output, counter = gen_tvar(counter)
|
| 222 |
+
symbols[n] = output
|
| 223 |
+
|
| 224 |
+
from_arg = symbols[n.args[0]]
|
| 225 |
+
assert isinstance(from_arg, TVar)
|
| 226 |
+
|
| 227 |
+
# input and output are dyn
|
| 228 |
+
is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)])
|
| 229 |
+
|
| 230 |
+
# or input is a tensor and we actually do the replacement
|
| 231 |
+
c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)])
|
| 232 |
+
|
| 233 |
+
return [Disj([is_dyn, c3])], counter
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@register_inference_rule("type_as")
|
| 237 |
+
def type_inference_rule(n: Node, symbols, constraints, counter):
|
| 238 |
+
"""
|
| 239 |
+
We generate the constraint: input = output
|
| 240 |
+
"""
|
| 241 |
+
assert isinstance(n.args[0], Node)
|
| 242 |
+
assert isinstance(n.args[1], Node)
|
| 243 |
+
|
| 244 |
+
output, counter = gen_tvar(counter)
|
| 245 |
+
symbols[n] = output
|
| 246 |
+
|
| 247 |
+
from_arg = symbols[n.args[0]]
|
| 248 |
+
to_arg = symbols[n.args[1]]
|
| 249 |
+
|
| 250 |
+
assert isinstance(from_arg, TVar)
|
| 251 |
+
assert isinstance(to_arg, TVar)
|
| 252 |
+
|
| 253 |
+
return [BinConstraintT(from_arg, to_arg, op_consistency),
|
| 254 |
+
BinConstraintT(output, to_arg, op_eq)], counter
|
| 255 |
+
|
| 256 |
+
@register_inference_rule("masked_fill_")
|
| 257 |
+
def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
|
| 258 |
+
"""
|
| 259 |
+
Similar to addition. For now we implement the constraints when
|
| 260 |
+
the argument is a boolean tensor. There is also a case for when
|
| 261 |
+
it is a condition. We will leave this out for now.
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
assert isinstance(n.args[0], Node)
|
| 265 |
+
assert isinstance(n.args[1], Node)
|
| 266 |
+
|
| 267 |
+
# We will retrieve the type variables from the symbol table
|
| 268 |
+
# and confirm they are tensor variables
|
| 269 |
+
|
| 270 |
+
e1 = symbols[n.args[0]]
|
| 271 |
+
e2 = symbols[n.args[1]]
|
| 272 |
+
|
| 273 |
+
if isinstance(e1, TVar) and isinstance(e2, TVar):
|
| 274 |
+
masked_fill_tensor, counter = gen_tvar(counter)
|
| 275 |
+
symbols[n] = masked_fill_tensor
|
| 276 |
+
return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor)
|
| 277 |
+
else:
|
| 278 |
+
raise NotImplementedError('Not yet implemented')
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
@register_inference_rule(torch.nn.functional.embedding)
|
| 282 |
+
def embedding_inference_rule_functional(n: Node, symbols, constraints, counter):
|
| 283 |
+
assert isinstance(n.args[0], Node)
|
| 284 |
+
|
| 285 |
+
embedding_dim_weights = symbols[n.args[1]]
|
| 286 |
+
|
| 287 |
+
# will treat this as a static shape. So we will not use matching.
|
| 288 |
+
weight_dims, counter = gen_tensor_dims(2, counter)
|
| 289 |
+
equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq)
|
| 290 |
+
embedding_dim = weight_dims[1]
|
| 291 |
+
constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter)
|
| 292 |
+
return [equality_constraint] + constraints, counter
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
@register_inference_rule(torch.nn.modules.sparse.Embedding)
|
| 296 |
+
def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
|
| 297 |
+
"""
|
| 298 |
+
The output shape differs from the input shape in the last dimension
|
| 299 |
+
"""
|
| 300 |
+
assert isinstance(n.args[0], Node)
|
| 301 |
+
return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
|
| 305 |
+
|
| 306 |
+
embedding_output, counter = gen_tvar(counter)
|
| 307 |
+
symbols[n] = embedding_output
|
| 308 |
+
embedding_input = symbols[n.args[0]]
|
| 309 |
+
|
| 310 |
+
input_dyn = BinConstraintT(embedding_input, Dyn, op_eq)
|
| 311 |
+
output_dyn = BinConstraintT(embedding_output, Dyn, op_eq)
|
| 312 |
+
|
| 313 |
+
c1 = Conj([input_dyn, output_dyn])
|
| 314 |
+
c2 = []
|
| 315 |
+
|
| 316 |
+
for i in range(1, MAX_TENSOR_RANK):
|
| 317 |
+
new_dims, counter = gen_tensor_dims(i, counter)
|
| 318 |
+
nat_constraints = gen_nat_constraints(new_dims)
|
| 319 |
+
|
| 320 |
+
# we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases
|
| 321 |
+
c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
|
| 322 |
+
BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] +
|
| 323 |
+
nat_constraints)
|
| 324 |
+
c2.append(c_tensor_i)
|
| 325 |
+
|
| 326 |
+
return [Disj([c1, Disj(c2)])], counter
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
@register_inference_rule(torch.tensor)
|
| 330 |
+
def tensor_inference_rule(n: Node, symbols, constraints, counter):
|
| 331 |
+
"""
|
| 332 |
+
If the tensor is a scalar, we will skip it since we
|
| 333 |
+
do not support scalars yet. We will add support in the future
|
| 334 |
+
if it's needed. For our examples so far, scalars are not needed.
|
| 335 |
+
"""
|
| 336 |
+
return [], counter
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
@register_inference_rule("reshape")
|
| 340 |
+
@register_inference_rule("view")
|
| 341 |
+
def view_inference_rule(n: Node, symbols, constraints, counter):
|
| 342 |
+
"""
|
| 343 |
+
Similar to reshape but with an extra condition on the strides
|
| 344 |
+
"""
|
| 345 |
+
assert isinstance(n.args[0], Node)
|
| 346 |
+
|
| 347 |
+
# generate the new variable
|
| 348 |
+
my_view, counter = gen_tvar(counter)
|
| 349 |
+
symbols[n] = my_view
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
src_var = symbols[n.args[0]]
|
| 353 |
+
t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]] # target shape
|
| 354 |
+
t2_type = []
|
| 355 |
+
num_constraints = []
|
| 356 |
+
|
| 357 |
+
for t in t2:
|
| 358 |
+
if t == -1:
|
| 359 |
+
var, counter = gen_dvar(counter)
|
| 360 |
+
t2_type.append(var)
|
| 361 |
+
num_constraints.append(BinConstraintD(var, Dyn, op_neq))
|
| 362 |
+
|
| 363 |
+
else:
|
| 364 |
+
num_constraints.append(BinConstraintD(t, Dyn, op_neq))
|
| 365 |
+
t2_type.append(t)
|
| 366 |
+
|
| 367 |
+
t2_type = TensorType(t2_type) # type: ignore[assignment]
|
| 368 |
+
|
| 369 |
+
c1 = BinConstraintT(my_view, t2_type, op_eq)
|
| 370 |
+
c2 = CanReshape(src_var, t2_type)
|
| 371 |
+
|
| 372 |
+
# TODO: add the extra check mentioned here:
|
| 373 |
+
# https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
|
| 374 |
+
|
| 375 |
+
return [c1, c2] + num_constraints, counter # type: ignore[operator]
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
@register_inference_rule("size")
|
| 379 |
+
def size_inference_rule(n: Node, symbols, constraints, counter):
|
| 380 |
+
"""
|
| 381 |
+
The constraint is just lhs = rhs.
|
| 382 |
+
Ex: size = input_ids.size()
|
| 383 |
+
"""
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
if len(n.args) == 1:
|
| 387 |
+
# generate the new variable
|
| 388 |
+
size, counter = gen_tvar(counter)
|
| 389 |
+
symbols[n] = size
|
| 390 |
+
input = symbols[n.args[0]]
|
| 391 |
+
c = BinConstraintT(input, size, op_eq)
|
| 392 |
+
return [c], counter
|
| 393 |
+
|
| 394 |
+
elif len(n.args) == 2:
|
| 395 |
+
# TODO: review this rule; should input = dyn; output = dyn be included here?
|
| 396 |
+
if isinstance(n.args[1], int):
|
| 397 |
+
# generate the new variable
|
| 398 |
+
size_index, counter = gen_dvar(counter)
|
| 399 |
+
symbols[n] = size_index
|
| 400 |
+
input = symbols[n.args[0]]
|
| 401 |
+
c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)]
|
| 402 |
+
c3 = BinConstraintD(0, size_index, op_leq)
|
| 403 |
+
|
| 404 |
+
input_dyn = BinConstraintT(input, Dyn, op_eq)
|
| 405 |
+
output_dyn = BinConstraintD(size_index, Dyn, op_eq)
|
| 406 |
+
c1 = Conj([input_dyn, output_dyn])
|
| 407 |
+
|
| 408 |
+
return [Disj([c1, Conj([Disj(c2), c3])])], counter
|
| 409 |
+
|
| 410 |
+
else:
|
| 411 |
+
raise NotImplementedError
|
| 412 |
+
|
| 413 |
+
else:
|
| 414 |
+
raise NotImplementedError
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def range_check(i, n):
|
| 418 |
+
"""
|
| 419 |
+
Checks if an index i is within range of a size n list
|
| 420 |
+
Args:
|
| 421 |
+
i: index
|
| 422 |
+
n: list size
|
| 423 |
+
|
| 424 |
+
Returns: Boolean
|
| 425 |
+
"""
|
| 426 |
+
if i >= 0:
|
| 427 |
+
return T() if i < n else F()
|
| 428 |
+
else:
|
| 429 |
+
return T() if i >= n else F()
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
@register_inference_rule(torch.cumsum)
|
| 433 |
+
def cumsum_inference_rule(n: Node, symbols, constraints, counter):
|
| 434 |
+
"""
|
| 435 |
+
Input and output shapes should be equal
|
| 436 |
+
We should verify that the index is valid
|
| 437 |
+
"""
|
| 438 |
+
assert isinstance(n.args[0], Node)
|
| 439 |
+
arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"]
|
| 440 |
+
assert isinstance(arg_1, int)
|
| 441 |
+
|
| 442 |
+
output, counter = gen_tvar(counter)
|
| 443 |
+
symbols[n] = output
|
| 444 |
+
input = symbols[n.args[0]]
|
| 445 |
+
|
| 446 |
+
input_dyn = BinConstraintT(input, Dyn, op_eq)
|
| 447 |
+
output_dyn = BinConstraintT(output, Dyn, op_eq)
|
| 448 |
+
c1 = Conj([input_dyn, output_dyn])
|
| 449 |
+
c2 = []
|
| 450 |
+
for i in range(1, MAX_TENSOR_RANK + 1):
|
| 451 |
+
new_dims, counter = gen_tensor_dims(i, counter)
|
| 452 |
+
|
| 453 |
+
nat_constraints = gen_nat_constraints(new_dims)
|
| 454 |
+
|
| 455 |
+
c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq),
|
| 456 |
+
BinConstraintT(output, TensorType(new_dims), op_eq)] +
|
| 457 |
+
[range_check(arg_1, i)] + nat_constraints)
|
| 458 |
+
|
| 459 |
+
c2.append(c_tensor_i)
|
| 460 |
+
dyn_or_tensor = Disj([c1, Disj(c2)])
|
| 461 |
+
return [dyn_or_tensor], counter
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
@register_inference_rule(_assert_is_none)
|
| 465 |
+
def assert_inference_rule(n: Node, symbols, constraints, counter):
|
| 466 |
+
assert len(n.users) == 0
|
| 467 |
+
return [], counter
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
@register_inference_rule(operator.getitem)
|
| 471 |
+
def getitem_inference_rule(n: Node, symbols, constraints, counter):
|
| 472 |
+
assert isinstance(n.args[0], Node)
|
| 473 |
+
|
| 474 |
+
# dimension output case
|
| 475 |
+
if isinstance(n.args[1], int):
|
| 476 |
+
# create and store the new dimension variable
|
| 477 |
+
get_item_output, counter = gen_dvar(counter)
|
| 478 |
+
symbols[n] = get_item_output
|
| 479 |
+
|
| 480 |
+
# retrieve arg variables
|
| 481 |
+
get_item_arg = symbols[n.args[0]]
|
| 482 |
+
assert isinstance(get_item_arg, TVar)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
# if the input is dynamic, we accept any index and return
|
| 486 |
+
# a dynamic dimension as output
|
| 487 |
+
input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
|
| 488 |
+
output_dyn = BinConstraintD(get_item_output, Dyn, op_eq)
|
| 489 |
+
c1 = Conj([input_dyn, output_dyn])
|
| 490 |
+
|
| 491 |
+
# if the input is a tensor,
|
| 492 |
+
# generate a getItem constraint which will be expanded based on the
|
| 493 |
+
# tensor dimension.
|
| 494 |
+
|
| 495 |
+
c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)]
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
# since the output is a dimension, we make sure it's a natural number
|
| 499 |
+
# added as a conjunction to the disjunction of c2
|
| 500 |
+
c3 = BinConstraintD(0, get_item_output, op_leq)
|
| 501 |
+
return [Disj([c1, Conj([Disj(c2), c3])])], counter
|
| 502 |
+
|
| 503 |
+
# tensor output case
|
| 504 |
+
elif isinstance(n.args[1], tuple):
|
| 505 |
+
# create and store the new tensor variable
|
| 506 |
+
get_item_output, counter = gen_tvar(counter)
|
| 507 |
+
symbols[n] = get_item_output
|
| 508 |
+
|
| 509 |
+
# retrieve arg variables
|
| 510 |
+
if n.args[0] in symbols:
|
| 511 |
+
get_item_arg = symbols[n.args[0]]
|
| 512 |
+
assert isinstance(get_item_arg, TVar)
|
| 513 |
+
|
| 514 |
+
input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
|
| 515 |
+
output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment]
|
| 516 |
+
c1 = Conj([input_dyn, output_dyn])
|
| 517 |
+
|
| 518 |
+
c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc]
|
| 519 |
+
for i in range(MAX_TENSOR_RANK)]
|
| 520 |
+
else:
|
| 521 |
+
# TODO: we should figure out why there is a key-error here.
|
| 522 |
+
return [], counter
|
| 523 |
+
|
| 524 |
+
return [Disj([c1, *c2])], counter
|
| 525 |
+
|
| 526 |
+
else:
|
| 527 |
+
raise RuntimeError('Method not yet implemented')
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
@register_inference_rule(operator.gt)
|
| 531 |
+
def gt_inference_rule(n: Node, symbols, constraints, counter):
|
| 532 |
+
assert isinstance(n.args[0], (Node, int))
|
| 533 |
+
assert isinstance(n.args[1], (Node, int))
|
| 534 |
+
|
| 535 |
+
# We make sure this node will not be used again. We do not
|
| 536 |
+
# generate a constraint about that node. Only about the operands.
|
| 537 |
+
|
| 538 |
+
e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
|
| 539 |
+
e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
|
| 540 |
+
|
| 541 |
+
if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
|
| 542 |
+
if isinstance(e1, TVar) and isinstance(e2, TVar):
|
| 543 |
+
gt_tensor, counter = gen_tvar(counter)
|
| 544 |
+
symbols[n] = gt_tensor
|
| 545 |
+
return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor)
|
| 546 |
+
|
| 547 |
+
elif isinstance(e1, DVar) and isinstance(e2, DVar):
|
| 548 |
+
# This is meant to be used for flow analysis only
|
| 549 |
+
gt_constraint = BinConstraintD(e1, e2, op_gt)
|
| 550 |
+
|
| 551 |
+
my_gt, counter = gen_bvar(counter)
|
| 552 |
+
equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
|
| 553 |
+
return [equality_constraint], counter
|
| 554 |
+
|
| 555 |
+
else:
|
| 556 |
+
raise RuntimeError('Sort Mismatch')
|
| 557 |
+
|
| 558 |
+
elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
|
| 559 |
+
if isinstance(e1, DVar):
|
| 560 |
+
# This is meant to be used for flow analysis only
|
| 561 |
+
gt_constraint = BinConstraintD(e1, e2, op_gt)
|
| 562 |
+
|
| 563 |
+
my_gt, counter = gen_bvar(counter)
|
| 564 |
+
equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
|
| 565 |
+
return [equality_constraint], counter
|
| 566 |
+
|
| 567 |
+
elif isinstance(e1, TVar) and isinstance(e2, int):
|
| 568 |
+
# then we made the wrong assumption about the argument being a tensor
|
| 569 |
+
# so we should fix the assumption
|
| 570 |
+
warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.')
|
| 571 |
+
|
| 572 |
+
new_e1, counter = gen_dvar(counter)
|
| 573 |
+
symbols[n.args[0]] = new_e1
|
| 574 |
+
symbols[n.args[0]]
|
| 575 |
+
|
| 576 |
+
gt_constraint = BinConstraintD(new_e1, e2, op_gt)
|
| 577 |
+
|
| 578 |
+
my_gt, counter = gen_bvar(counter)
|
| 579 |
+
equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
|
| 580 |
+
return [equality_constraint], counter
|
| 581 |
+
|
| 582 |
+
else:
|
| 583 |
+
raise NotImplementedError('Method not yet implemented')
|
| 584 |
+
|
| 585 |
+
else:
|
| 586 |
+
raise NotImplementedError('Method not yet implemented')
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
@register_inference_rule(operator.eq)
|
| 590 |
+
def eq_inference_rule(n: Node, symbols, constraints, counter):
|
| 591 |
+
assert isinstance(n.args[0], (Node, int))
|
| 592 |
+
assert isinstance(n.args[1], (Node, int))
|
| 593 |
+
|
| 594 |
+
e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
|
| 595 |
+
e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
|
| 596 |
+
|
| 597 |
+
if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
|
| 598 |
+
if isinstance(e1, TVar) and isinstance(e2, TVar):
|
| 599 |
+
eq_tensor, counter = gen_tvar(counter)
|
| 600 |
+
symbols[n] = eq_tensor
|
| 601 |
+
return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor)
|
| 602 |
+
|
| 603 |
+
elif isinstance(e1, DVar) and isinstance(e2, DVar):
|
| 604 |
+
# This is meant to be used for flow analysis only
|
| 605 |
+
eq_constraint = BinConstraintD(e1, e2, op_eq)
|
| 606 |
+
|
| 607 |
+
my_eq, counter = gen_bvar(counter)
|
| 608 |
+
equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
|
| 609 |
+
return [equality_constraint], counter
|
| 610 |
+
|
| 611 |
+
else:
|
| 612 |
+
raise RuntimeError('Sort Mismatch')
|
| 613 |
+
|
| 614 |
+
elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
|
| 615 |
+
if isinstance(e1, DVar):
|
| 616 |
+
# This is meant to be used for flow analysis only
|
| 617 |
+
eq_constraint = BinConstraintD(e1, e2, op_eq)
|
| 618 |
+
|
| 619 |
+
my_eq, counter = gen_bvar(counter)
|
| 620 |
+
equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
|
| 621 |
+
return [equality_constraint], counter
|
| 622 |
+
else:
|
| 623 |
+
raise NotImplementedError('Method not yet implemented')
|
| 624 |
+
else:
|
| 625 |
+
raise NotImplementedError('Method not yet implemented')
|
| 626 |
+
|
| 627 |
+
@register_inference_rule(operator.ne)
|
| 628 |
+
def neq_inference_rule(n: Node, symbols, constraints, counter):
|
| 629 |
+
"""
|
| 630 |
+
Translates to inconsistent in gradual types.
|
| 631 |
+
To prove inequality, we should prove that
|
| 632 |
+
tensors are either different sizes or
|
| 633 |
+
disagree on at least one dimension
|
| 634 |
+
|
| 635 |
+
This is a WIP (works when the condition
|
| 636 |
+
is false. We are working on making this operation work
|
| 637 |
+
when the condition is true as well)
|
| 638 |
+
"""
|
| 639 |
+
assert isinstance(n.args[0], Node)
|
| 640 |
+
assert isinstance(n.args[1], tuple)
|
| 641 |
+
|
| 642 |
+
# implementing for size 3 and 4
|
| 643 |
+
if len(n.args[1]) == 3:
|
| 644 |
+
|
| 645 |
+
assert isinstance(n.args[1][0], (Node, int))
|
| 646 |
+
assert isinstance(n.args[1][1], (Node, int))
|
| 647 |
+
assert isinstance(n.args[1][2], (Node, int))
|
| 648 |
+
|
| 649 |
+
lhs = symbols[n.args[0]]
|
| 650 |
+
|
| 651 |
+
b, counter = gen_tensor_dims(4, counter)
|
| 652 |
+
input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq)
|
| 653 |
+
|
| 654 |
+
d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
|
| 655 |
+
d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
|
| 656 |
+
d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]
|
| 657 |
+
|
| 658 |
+
# dimensions not equal
|
| 659 |
+
my_ne, counter = gen_bvar(counter)
|
| 660 |
+
neq_1 = BinConstraintD(d1, b[0], op_neq)
|
| 661 |
+
neq_2 = BinConstraintD(d2, b[1], op_neq)
|
| 662 |
+
neq_3 = BinConstraintD(d3, b[2], op_neq)
|
| 663 |
+
|
| 664 |
+
# dimensions inconsistent
|
| 665 |
+
dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1])
|
| 666 |
+
dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2])
|
| 667 |
+
dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3])
|
| 668 |
+
|
| 669 |
+
dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3])
|
| 670 |
+
|
| 671 |
+
# we are covering size 3 and 4 only for now
|
| 672 |
+
ne_constraint = Conj([input_is_size3, dims_inconsistent])
|
| 673 |
+
|
| 674 |
+
my_ne, counter = gen_bvar(counter)
|
| 675 |
+
equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)
|
| 676 |
+
|
| 677 |
+
elif len(n.args[1]) == 4:
|
| 678 |
+
|
| 679 |
+
assert isinstance(n.args[1][0], (Node, int))
|
| 680 |
+
assert isinstance(n.args[1][1], (Node, int))
|
| 681 |
+
assert isinstance(n.args[1][2], (Node, int))
|
| 682 |
+
assert isinstance(n.args[1][3], (Node, int))
|
| 683 |
+
|
| 684 |
+
lhs = symbols[n.args[0]]
|
| 685 |
+
|
| 686 |
+
b1, counter = gen_dvar(counter)
|
| 687 |
+
b2, counter = gen_dvar(counter)
|
| 688 |
+
b3, counter = gen_dvar(counter)
|
| 689 |
+
b4, counter = gen_dvar(counter)
|
| 690 |
+
|
| 691 |
+
input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq)
|
| 692 |
+
|
| 693 |
+
d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
|
| 694 |
+
d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
|
| 695 |
+
d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]
|
| 696 |
+
d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]]
|
| 697 |
+
|
| 698 |
+
# dimensions not equal
|
| 699 |
+
my_ne, counter = gen_bvar(counter)
|
| 700 |
+
neq_1 = BinConstraintD(d1, b1, op_neq)
|
| 701 |
+
neq_2 = BinConstraintD(d2, b2, op_neq)
|
| 702 |
+
neq_3 = BinConstraintD(d3, b3, op_neq)
|
| 703 |
+
neq_4 = BinConstraintD(d4, b4, op_neq)
|
| 704 |
+
|
| 705 |
+
# dimensions to inconsistent
|
| 706 |
+
dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1])
|
| 707 |
+
dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2])
|
| 708 |
+
dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3])
|
| 709 |
+
dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4])
|
| 710 |
+
|
| 711 |
+
dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4])
|
| 712 |
+
|
| 713 |
+
ne_constraint = Conj([input_is_size4, dims_inconsistent])
|
| 714 |
+
|
| 715 |
+
my_ne, counter = gen_bvar(counter)
|
| 716 |
+
|
| 717 |
+
equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)
|
| 718 |
+
|
| 719 |
+
else:
|
| 720 |
+
raise NotImplementedError('Method not yet implemented')
|
| 721 |
+
|
| 722 |
+
return [equality_constraint], counter
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
@register_inference_rule(operator.lt)
|
| 726 |
+
def lt_inference_rule(n: Node, symbols, constraints, counter):
|
| 727 |
+
assert isinstance(n.args[0], (Node, int))
|
| 728 |
+
assert isinstance(n.args[1], (Node, int))
|
| 729 |
+
|
| 730 |
+
# We make sure this node will not be used again. We do not
|
| 731 |
+
# generate a constraint about that node. Only about the operands.
|
| 732 |
+
|
| 733 |
+
e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
|
| 734 |
+
e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]
|
| 735 |
+
|
| 736 |
+
if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
|
| 737 |
+
if isinstance(e1, TVar) and isinstance(e2, TVar):
|
| 738 |
+
lt_tensor, counter = gen_tvar(counter)
|
| 739 |
+
symbols[n] = lt_tensor
|
| 740 |
+
return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor)
|
| 741 |
+
|
| 742 |
+
elif isinstance(e1, DVar) and isinstance(e2, DVar):
|
| 743 |
+
# This is meant to be used for flow analysis only
|
| 744 |
+
lt_constraint = BinConstraintD(e1, e2, op_lt)
|
| 745 |
+
|
| 746 |
+
my_lt, counter = gen_bvar(counter)
|
| 747 |
+
equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
|
| 748 |
+
return [equality_constraint], counter
|
| 749 |
+
|
| 750 |
+
else:
|
| 751 |
+
raise RuntimeError('Sort Mismatch')
|
| 752 |
+
|
| 753 |
+
elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
|
| 754 |
+
if isinstance(e1, DVar):
|
| 755 |
+
# This is meant to be used for flow analysis only
|
| 756 |
+
lt_constraint = BinConstraintD(e1, e2, op_lt)
|
| 757 |
+
|
| 758 |
+
my_lt, counter = gen_bvar(counter)
|
| 759 |
+
equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
|
| 760 |
+
return [equality_constraint], counter
|
| 761 |
+
else:
|
| 762 |
+
raise NotImplementedError('Method not yet implemented')
|
| 763 |
+
|
| 764 |
+
else:
|
| 765 |
+
raise NotImplementedError('Method not yet implemented')
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
@register_inference_rule(torch.full)
|
| 769 |
+
def full_inference_rule(n: Node, symbols, constraints, counter):
|
| 770 |
+
full, counter = gen_tvar(counter)
|
| 771 |
+
symbols[n] = full
|
| 772 |
+
res = []
|
| 773 |
+
|
| 774 |
+
assert isinstance(n.args[0], Iterable)
|
| 775 |
+
for arg in n.args[0]:
|
| 776 |
+
dim = arg if isinstance(arg, int) else symbols[arg]
|
| 777 |
+
res.append(dim)
|
| 778 |
+
c = BinConstraintT(full, TensorType(list(res)), op_eq) # type: ignore[arg-type]
|
| 779 |
+
return [c], counter
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
# TODO normalize index
|
| 783 |
+
@register_inference_rule(torch.arange)
|
| 784 |
+
def arange_inference_rule(n: Node, symbols, constraints, counter):
|
| 785 |
+
start = 0
|
| 786 |
+
step = 1
|
| 787 |
+
|
| 788 |
+
if len(n.args) == 1:
|
| 789 |
+
end = symbols[n.args[0]]
|
| 790 |
+
else:
|
| 791 |
+
raise NotImplementedError('Not yet implemented')
|
| 792 |
+
|
| 793 |
+
# int((end - start) / step)
|
| 794 |
+
d1, counter = gen_dvar(counter)
|
| 795 |
+
size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq)
|
| 796 |
+
arange, counter = gen_tvar(counter)
|
| 797 |
+
symbols[n] = arange
|
| 798 |
+
|
| 799 |
+
# either the a parameter is a number or it is Dyn
|
| 800 |
+
c1 = Disj([BinConstraintD(end, Dyn, op_eq),
|
| 801 |
+
BinConstraintD(start, Dyn, op_eq),
|
| 802 |
+
BinConstraintD(step, Dyn, op_eq)])
|
| 803 |
+
c2 = BinConstraintD(d1, Dyn, op_eq)
|
| 804 |
+
both_dyn = Conj([c1, c2])
|
| 805 |
+
|
| 806 |
+
c11 = Conj([BinConstraintD(end, Dyn, op_neq),
|
| 807 |
+
BinConstraintD(start, Dyn, op_neq),
|
| 808 |
+
BinConstraintD(step, Dyn, op_neq)])
|
| 809 |
+
c22 = BinConstraintD(d1, Dyn, op_neq)
|
| 810 |
+
both_numbers = Conj([c11, c22, size_constraint])
|
| 811 |
+
|
| 812 |
+
return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter
|
| 813 |
+
|
| 814 |
+
def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var):
|
| 815 |
+
# additional vars that don't correspond to expressions
|
| 816 |
+
e11, counter = gen_tvar(counter)
|
| 817 |
+
e22, counter = gen_tvar(counter)
|
| 818 |
+
|
| 819 |
+
# generate constraints
|
| 820 |
+
c1 = TGreatestUpperBound(output_var, e11, e22)
|
| 821 |
+
c2 = ApplyBroadcasting(e11, e22, e1, e2)
|
| 822 |
+
c3 = BinConstraintT(e11, e22, op_consistency)
|
| 823 |
+
return [c1, c2, c3], counter
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
@register_inference_rule(operator.mul)
|
| 827 |
+
@register_inference_rule(torch.ne)
|
| 828 |
+
@register_inference_rule("ne")
|
| 829 |
+
@register_inference_rule(torch.add)
|
| 830 |
+
@register_inference_rule(operator.add)
|
| 831 |
+
def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
|
| 832 |
+
|
| 833 |
+
op_code = None
|
| 834 |
+
if n.target == operator.add or n.target == torch.add:
|
| 835 |
+
op_code = op_add
|
| 836 |
+
elif n.target == operator.mul:
|
| 837 |
+
op_code = op_mul
|
| 838 |
+
|
| 839 |
+
if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
|
| 840 |
+
if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar):
|
| 841 |
+
my_output, counter = gen_tvar(counter)
|
| 842 |
+
symbols[n] = my_output
|
| 843 |
+
e1 = symbols[n.args[0]]
|
| 844 |
+
e2 = symbols[n.args[1]]
|
| 845 |
+
|
| 846 |
+
return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output)
|
| 847 |
+
else:
|
| 848 |
+
raise NotImplementedError('Method not yet implemented')
|
| 849 |
+
|
| 850 |
+
elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)):
|
| 851 |
+
if isinstance(symbols[n.args[0]], TVar):
|
| 852 |
+
my_output, counter = gen_tvar(counter)
|
| 853 |
+
symbols[n] = my_output
|
| 854 |
+
e1 = symbols[n.args[0]]
|
| 855 |
+
return [BinConstraintT(my_output, e1, op_eq)], counter
|
| 856 |
+
elif isinstance(symbols[n.args[0]], DVar):
|
| 857 |
+
my_output, counter = gen_dvar(counter)
|
| 858 |
+
symbols[n] = my_output
|
| 859 |
+
e1 = symbols[n.args[0]]
|
| 860 |
+
|
| 861 |
+
# we will propagate the runtime value here since this is regular addition
|
| 862 |
+
c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq),
|
| 863 |
+
BinConstraintD(0, my_output, op_leq)])
|
| 864 |
+
return [c], counter
|
| 865 |
+
|
| 866 |
+
elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)):
|
| 867 |
+
if isinstance(symbols[n.args[1]], TVar):
|
| 868 |
+
my_output, counter = gen_tvar(counter)
|
| 869 |
+
symbols[n] = my_output
|
| 870 |
+
e2 = symbols[n.args[1]]
|
| 871 |
+
return [BinConstraintT(my_output, e2, op_eq)], counter
|
| 872 |
+
elif isinstance(symbols[n.args[1]], DVar):
|
| 873 |
+
my_output, counter = gen_dvar(counter)
|
| 874 |
+
symbols[n] = my_output
|
| 875 |
+
e2 = symbols[n.args[1]]
|
| 876 |
+
|
| 877 |
+
# we will propagate the runtime value here since this is regular addition
|
| 878 |
+
c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq),
|
| 879 |
+
BinConstraintD(0, my_output, op_leq)])
|
| 880 |
+
return [c], counter
|
| 881 |
+
|
| 882 |
+
else:
|
| 883 |
+
raise NotImplementedError('Method not yet implemented')
|
| 884 |
+
|
| 885 |
+
else:
|
| 886 |
+
# TODO generate add constraints for scalar addition
|
| 887 |
+
raise NotImplementedError('Addition not yet implemented')
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
@register_inference_rule(torch.flatten)
|
| 891 |
+
def flatten_inference_rule(n: Node, symbols, constraints, counter):
|
| 892 |
+
assert isinstance(n.args[0], Node)
|
| 893 |
+
|
| 894 |
+
# generate the new variable
|
| 895 |
+
flattened, counter = gen_tvar(counter)
|
| 896 |
+
symbols[n] = flattened
|
| 897 |
+
|
| 898 |
+
input = symbols[n.args[0]]
|
| 899 |
+
|
| 900 |
+
# set the default start and end dims
|
| 901 |
+
start_dim = 1
|
| 902 |
+
end_dim = -1
|
| 903 |
+
|
| 904 |
+
if len(n.args) > 1:
|
| 905 |
+
assert isinstance(n.args[1], int)
|
| 906 |
+
start_dim = n.args[1]
|
| 907 |
+
|
| 908 |
+
if len(n.args) > 2:
|
| 909 |
+
assert isinstance(n.args[2], int)
|
| 910 |
+
end_dim = n.args[2]
|
| 911 |
+
|
| 912 |
+
c1 = BinConstraintT(input, Dyn, op_eq)
|
| 913 |
+
c2 = BinConstraintT(flattened, Dyn, op_eq)
|
| 914 |
+
both_dyn = Conj([c1, c2])
|
| 915 |
+
|
| 916 |
+
const = []
|
| 917 |
+
for i in range(1, MAX_TENSOR_RANK + 1):
|
| 918 |
+
c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter)
|
| 919 |
+
const.append(c)
|
| 920 |
+
|
| 921 |
+
return [Disj([both_dyn, *const])], counter
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
@register_inference_rule(torch.nn.functional.layer_norm)
|
| 925 |
+
def layer_norm_functional(n: Node, symbols, constraints, counter):
|
| 926 |
+
"""
|
| 927 |
+
We generate the constraint: input = output
|
| 928 |
+
"""
|
| 929 |
+
assert isinstance(n.args[0], Node)
|
| 930 |
+
return gen_layer_norm_constraints(n, n.args[1], symbols, counter)
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
@register_inference_rule(torch.nn.LayerNorm)
|
| 934 |
+
def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
|
| 935 |
+
"""
|
| 936 |
+
Input and output shapes should be equal.
|
| 937 |
+
Input should be consistent with the normalized_shape
|
| 938 |
+
"""
|
| 939 |
+
assert isinstance(n.args[0], Node)
|
| 940 |
+
return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter)
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
|
| 944 |
+
output, counter = gen_tvar(counter)
|
| 945 |
+
symbols[n] = output
|
| 946 |
+
input = symbols[n.args[0]]
|
| 947 |
+
|
| 948 |
+
input_dyn = BinConstraintT(input, Dyn, op_eq)
|
| 949 |
+
output_dyn = BinConstraintT(output, Dyn, op_eq)
|
| 950 |
+
|
| 951 |
+
c1 = Conj([input_dyn, output_dyn])
|
| 952 |
+
|
| 953 |
+
c2 = []
|
| 954 |
+
for i in range(1, MAX_TENSOR_RANK + 1):
|
| 955 |
+
new_dims_rhs, counter = gen_tensor_dims(i, counter)
|
| 956 |
+
nat_constraints = gen_nat_constraints(new_dims_rhs)
|
| 957 |
+
|
| 958 |
+
c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
|
| 959 |
+
BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] +
|
| 960 |
+
add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) +
|
| 961 |
+
nat_constraints)
|
| 962 |
+
c2.append(c_tensor_i)
|
| 963 |
+
return [Disj([c1, Disj(c2)])], counter
|
| 964 |
+
|
| 965 |
+
@register_inference_rule(torch.nn.Dropout)
|
| 966 |
+
@register_inference_rule(torch.nn.ReLU)
|
| 967 |
+
def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):
|
| 968 |
+
"""
|
| 969 |
+
Input and output shapes should be equal.
|
| 970 |
+
"""
|
| 971 |
+
assert isinstance(n.args[0], Node)
|
| 972 |
+
output, counter = gen_tvar(counter)
|
| 973 |
+
symbols[n] = output
|
| 974 |
+
input = symbols[n.args[0]]
|
| 975 |
+
assert isinstance(input, TVar)
|
| 976 |
+
return [BinConstraintT(input, output, op_eq)], counter
|
| 977 |
+
|
| 978 |
+
|
| 979 |
+
@register_inference_rule(torch.nn.Linear)
|
| 980 |
+
def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
|
| 981 |
+
"""
|
| 982 |
+
Input and output sizes should be the same except for the last dimension
|
| 983 |
+
If the input is Dyn, then so should the output
|
| 984 |
+
"""
|
| 985 |
+
assert isinstance(n.args[0], Node)
|
| 986 |
+
return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter)
|
| 987 |
+
|
| 988 |
+
|
| 989 |
+
@register_inference_rule("dim") # type: ignore[attr-defined]
|
| 990 |
+
def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
|
| 991 |
+
assert isinstance(n.args[0], Node)
|
| 992 |
+
my_dim, counter = gen_dvar(counter)
|
| 993 |
+
symbols[n] = my_dim
|
| 994 |
+
input = symbols[n.args[0]]
|
| 995 |
+
|
| 996 |
+
input_dyn = BinConstraintT(input, Dyn, op_eq)
|
| 997 |
+
output_dyn = BinConstraintD(my_dim, Dyn, op_eq)
|
| 998 |
+
|
| 999 |
+
c1 = []
|
| 1000 |
+
|
| 1001 |
+
for i in range(1, MAX_TENSOR_RANK + 1):
|
| 1002 |
+
new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
|
| 1003 |
+
|
| 1004 |
+
c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq),
|
| 1005 |
+
BinConstraintD(my_dim, i, op_eq)])
|
| 1006 |
+
c1.append(c_tensor_i)
|
| 1007 |
+
|
| 1008 |
+
return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter
|
| 1009 |
+
|
| 1010 |
+
|
| 1011 |
+
@register_inference_rule(torch._C._nn.linear) # type: ignore[attr-defined]
|
| 1012 |
+
def torch_linear_inference_rule(n: Node, symbols, constraints, counter):
|
| 1013 |
+
assert isinstance(n.args[0], Node)
|
| 1014 |
+
weight_dims, counter = gen_tensor_dims(2, counter)
|
| 1015 |
+
equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq)
|
| 1016 |
+
constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter)
|
| 1017 |
+
return [equality_constraint] + constraints, counter
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
def linear_constraints(n: Node, in_features, out_features, symbols, counter):
|
| 1021 |
+
linear_output, counter = gen_tvar(counter)
|
| 1022 |
+
symbols[n] = linear_output
|
| 1023 |
+
linear_input = symbols[n.args[0]]
|
| 1024 |
+
|
| 1025 |
+
input_dyn = BinConstraintT(linear_input, Dyn, op_eq)
|
| 1026 |
+
output_dyn = BinConstraintT(linear_output, Dyn, op_eq)
|
| 1027 |
+
|
| 1028 |
+
c1 = Conj([input_dyn, output_dyn])
|
| 1029 |
+
|
| 1030 |
+
c2 = []
|
| 1031 |
+
for i in range(1, MAX_TENSOR_RANK + 1):
|
| 1032 |
+
new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
|
| 1033 |
+
new_dims_rhs_2, counter = gen_tensor_dims(i, counter)
|
| 1034 |
+
|
| 1035 |
+
nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)
|
| 1036 |
+
|
| 1037 |
+
c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
|
| 1038 |
+
BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] +
|
| 1039 |
+
add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) +
|
| 1040 |
+
nat_constraints)
|
| 1041 |
+
c2.append(c_tensor_i)
|
| 1042 |
+
return [Disj([c1, Disj(c2)])], counter
|
| 1043 |
+
|
| 1044 |
+
def add_layer_norm_constraints(input_dim, normalized_dim):
|
| 1045 |
+
"""
|
| 1046 |
+
The constraints say that the type has te form: [*, 1024, 1024]
|
| 1047 |
+
while the normalized_dim have the form [1024, 1024]
|
| 1048 |
+
Args:
|
| 1049 |
+
input_dim: Input shape of layer norm
|
| 1050 |
+
normalized_dim: normalized_dim parameter of the module instance
|
| 1051 |
+
|
| 1052 |
+
"""
|
| 1053 |
+
|
| 1054 |
+
# in this case we return false since there's a pattern mismatch
|
| 1055 |
+
if len(normalized_dim) > len(input_dim):
|
| 1056 |
+
return [F()]
|
| 1057 |
+
|
| 1058 |
+
else:
|
| 1059 |
+
constraints = []
|
| 1060 |
+
for i, n in zip(reversed(input_dim), reversed(normalized_dim)):
|
| 1061 |
+
constraints.append(BinConstraintD(i, n, op_consistency))
|
| 1062 |
+
return constraints
|
| 1063 |
+
|
| 1064 |
+
|
| 1065 |
+
def add_linear_constraints(dims1, dims2, in_features, out_features):
|
| 1066 |
+
assert len(dims1) == len(dims2)
|
| 1067 |
+
constraints = []
|
| 1068 |
+
for i in range(len(dims1)):
|
| 1069 |
+
if i == len(dims1) - 1:
|
| 1070 |
+
constraints.append(BinConstraintD(dims1[i], in_features, op_consistency))
|
| 1071 |
+
constraints.append(BinConstraintD(dims2[i], out_features, op_eq))
|
| 1072 |
+
else:
|
| 1073 |
+
constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq))
|
| 1074 |
+
|
| 1075 |
+
return constraints
|
| 1076 |
+
|
| 1077 |
+
|
| 1078 |
+
@register_inference_rule(torch.reshape)
|
| 1079 |
+
def reshape_inference_rule(n: Node, symbols, constraints, counter):
|
| 1080 |
+
assert isinstance(n.args[0], Node)
|
| 1081 |
+
|
| 1082 |
+
# generate the new variable
|
| 1083 |
+
my_reshape, counter = gen_tvar(counter)
|
| 1084 |
+
symbols[n] = my_reshape
|
| 1085 |
+
|
| 1086 |
+
src_var = symbols[n.args[0]]
|
| 1087 |
+
t2 = n.args[1]
|
| 1088 |
+
t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) # type: ignore[union-attr]
|
| 1089 |
+
c1 = BinConstraintT(my_reshape, t2_type, op_eq) # type: ignore[union-attr]
|
| 1090 |
+
c2 = CanReshape(src_var, t2_type)
|
| 1091 |
+
|
| 1092 |
+
return [c1, c2], counter
|
| 1093 |
+
|
| 1094 |
+
|
| 1095 |
+
@register_inference_rule(BatchNorm2d)
|
| 1096 |
+
def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
|
| 1097 |
+
assert isinstance(n.args[0], Node)
|
| 1098 |
+
|
| 1099 |
+
# generate the new variable
|
| 1100 |
+
batchnorm_output, counter = gen_tvar(counter)
|
| 1101 |
+
symbols[n] = batchnorm_output
|
| 1102 |
+
batchnorm_input = symbols[n.args[0]]
|
| 1103 |
+
|
| 1104 |
+
# dim vars
|
| 1105 |
+
d1, counter = gen_dvar(counter)
|
| 1106 |
+
d2, counter = gen_dvar(counter)
|
| 1107 |
+
d3, counter = gen_dvar(counter)
|
| 1108 |
+
d4, counter = gen_dvar(counter)
|
| 1109 |
+
|
| 1110 |
+
nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
|
| 1111 |
+
|
| 1112 |
+
c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching)
|
| 1113 |
+
c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq)
|
| 1114 |
+
return [c1, c2, *nat_constraints], counter
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
|
| 1118 |
+
def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter):
|
| 1119 |
+
assert isinstance(n.args[0], Node)
|
| 1120 |
+
|
| 1121 |
+
avg_pool, counter = gen_tvar(counter)
|
| 1122 |
+
|
| 1123 |
+
symbols[n] = avg_pool
|
| 1124 |
+
input_var = symbols[n.args[0]]
|
| 1125 |
+
|
| 1126 |
+
# dim vars
|
| 1127 |
+
d1, counter = gen_dvar(counter)
|
| 1128 |
+
d2, counter = gen_dvar(counter)
|
| 1129 |
+
d3, counter = gen_dvar(counter)
|
| 1130 |
+
d4, counter = gen_dvar(counter)
|
| 1131 |
+
nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
|
| 1132 |
+
c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
|
| 1133 |
+
c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq)
|
| 1134 |
+
|
| 1135 |
+
return [c1, c2, *nat_constraints], counter
|
| 1136 |
+
|
| 1137 |
+
|
| 1138 |
+
@register_inference_rule(Conv2d)
|
| 1139 |
+
def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter):
|
| 1140 |
+
assert isinstance(n.args[0], Node)
|
| 1141 |
+
|
| 1142 |
+
my_conv, counter = gen_tvar(counter)
|
| 1143 |
+
symbols[n] = my_conv
|
| 1144 |
+
input_var = symbols[n.args[0]]
|
| 1145 |
+
|
| 1146 |
+
# dim vars
|
| 1147 |
+
[d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)
|
| 1148 |
+
|
| 1149 |
+
# c1 = Matching(input_var, TensorType([d1, d2, d3, d4]))
|
| 1150 |
+
c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
|
| 1151 |
+
|
| 1152 |
+
# c2 = DConsistency(module_instance.in_channels, d2)
|
| 1153 |
+
c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency)
|
| 1154 |
+
|
| 1155 |
+
c3 = CalcConv(my_conv, input_var,
|
| 1156 |
+
module_instance.out_channels,
|
| 1157 |
+
module_instance.kernel_size,
|
| 1158 |
+
module_instance.padding,
|
| 1159 |
+
module_instance.stride,
|
| 1160 |
+
module_instance.dilation, [d1, d2, d3, d4])
|
| 1161 |
+
|
| 1162 |
+
nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
|
| 1163 |
+
|
| 1164 |
+
return [c1, c2, c3, *nat_constraints], counter
|
| 1165 |
+
|
| 1166 |
+
|
| 1167 |
+
@register_inference_rule(torch.nn.MaxPool2d)
|
| 1168 |
+
def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter):
|
| 1169 |
+
assert isinstance(n.args[0], Node)
|
| 1170 |
+
maxpool, counter = gen_tvar(counter)
|
| 1171 |
+
symbols[n] = maxpool
|
| 1172 |
+
input_var = symbols[n.args[0]]
|
| 1173 |
+
|
| 1174 |
+
# dim vars
|
| 1175 |
+
[d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)
|
| 1176 |
+
|
| 1177 |
+
c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
|
| 1178 |
+
|
| 1179 |
+
c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding,
|
| 1180 |
+
module_instance.stride, module_instance.dilation, [d1, d2, d3, d4])
|
| 1181 |
+
|
| 1182 |
+
nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
|
| 1183 |
+
|
| 1184 |
+
return [c1, c2, *nat_constraints], counter
|
| 1185 |
+
|
| 1186 |
+
|
| 1187 |
+
class ConstraintGenerator:
|
| 1188 |
+
def __init__(self, traced, graph=None):
|
| 1189 |
+
self.traced = traced # traced or tracer.root
|
| 1190 |
+
self.traced_params = dict(self.traced.named_parameters())
|
| 1191 |
+
self.constraints = []
|
| 1192 |
+
self.symbol_dict = {}
|
| 1193 |
+
self.graph = traced.graph if hasattr(traced, 'graph') else graph
|
| 1194 |
+
|
| 1195 |
+
|
| 1196 |
+
def generate_constraints(self, counter=0):
|
| 1197 |
+
"""
|
| 1198 |
+
Iterate through every node and generate constraints
|
| 1199 |
+
Effect: self.constraints will be populated with the final constraints
|
| 1200 |
+
"""
|
| 1201 |
+
graph = self.graph
|
| 1202 |
+
|
| 1203 |
+
all_constraints = []
|
| 1204 |
+
|
| 1205 |
+
for n in graph.nodes:
|
| 1206 |
+
(constraints, counter) = self.generate_constraints_node(n, counter)
|
| 1207 |
+
all_constraints += constraints
|
| 1208 |
+
|
| 1209 |
+
return Conj(all_constraints), counter
|
| 1210 |
+
|
| 1211 |
+
def generate_constraints_node(self, n: Node, counter):
|
| 1212 |
+
"""
|
| 1213 |
+
Generate constraints the given node:
|
| 1214 |
+
Currently supported operations:
|
| 1215 |
+
- Reshape
|
| 1216 |
+
- Add
|
| 1217 |
+
- conv2d
|
| 1218 |
+
"""
|
| 1219 |
+
|
| 1220 |
+
if n.op == 'placeholder':
|
| 1221 |
+
x, counter = gen_tvar(counter)
|
| 1222 |
+
self.symbol_dict[n] = x
|
| 1223 |
+
|
| 1224 |
+
my_type = n.type
|
| 1225 |
+
|
| 1226 |
+
if n.type != Dyn and (not isinstance(n.type, TensorType)):
|
| 1227 |
+
if n.type == torch.nn.parameter.Parameter:
|
| 1228 |
+
# since we have a parameter, the shape must be static
|
| 1229 |
+
assert 'example_value' in n.meta
|
| 1230 |
+
my_type = TensorType(n.meta['example_value'].size())
|
| 1231 |
+
else:
|
| 1232 |
+
my_type = Dyn
|
| 1233 |
+
|
| 1234 |
+
c1 = BinConstraintT(my_type, x, op_precision)
|
| 1235 |
+
c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq)
|
| 1236 |
+
return [c1, c2], counter
|
| 1237 |
+
|
| 1238 |
+
elif n.op == 'call_function':
|
| 1239 |
+
if n.target in _INFERENCE_RULES:
|
| 1240 |
+
return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
|
| 1241 |
+
else:
|
| 1242 |
+
raise RuntimeError(f'No inference rule registered for target {n.target}!')
|
| 1243 |
+
|
| 1244 |
+
elif n.op == 'call_module':
|
| 1245 |
+
|
| 1246 |
+
module_instance = self.traced.get_submodule(n.target)
|
| 1247 |
+
if type(module_instance) in _INFERENCE_RULES:
|
| 1248 |
+
return _INFERENCE_RULES[type(module_instance)](n,
|
| 1249 |
+
module_instance,
|
| 1250 |
+
self.symbol_dict,
|
| 1251 |
+
self.constraints, counter)
|
| 1252 |
+
else:
|
| 1253 |
+
raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')
|
| 1254 |
+
|
| 1255 |
+
elif n.op == 'call_method':
|
| 1256 |
+
if n.target in _INFERENCE_RULES:
|
| 1257 |
+
return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
|
| 1258 |
+
else:
|
| 1259 |
+
raise RuntimeError(f'No inference rule registered for target {n.target}!')
|
| 1260 |
+
|
| 1261 |
+
elif n.op == 'get_attr':
|
| 1262 |
+
t = self.traced_params.get(n.target, None)
|
| 1263 |
+
|
| 1264 |
+
if isinstance(t, torch.Tensor):
|
| 1265 |
+
if len(t.shape) > 0:
|
| 1266 |
+
res = list(t.shape)
|
| 1267 |
+
attr_type = TensorType(res)
|
| 1268 |
+
output, counter = gen_tvar(counter)
|
| 1269 |
+
self.symbol_dict[n] = output
|
| 1270 |
+
return [BinConstraintT(output, attr_type, op_eq)], counter
|
| 1271 |
+
else:
|
| 1272 |
+
# scalar?
|
| 1273 |
+
return [], counter
|
| 1274 |
+
else:
|
| 1275 |
+
return [], counter
|
| 1276 |
+
|
| 1277 |
+
elif n.op == 'output':
|
| 1278 |
+
return [], counter
|
| 1279 |
+
|
| 1280 |
+
else:
|
| 1281 |
+
raise NotImplementedError(f"Method {n.op} not yet implemented")
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py
ADDED
|
@@ -0,0 +1,1040 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: ignore-errors
|
| 2 |
+
import copy
|
| 3 |
+
import itertools
|
| 4 |
+
from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK
|
| 5 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \
|
| 6 |
+
Transpose
|
| 7 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound
|
| 8 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound
|
| 9 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool
|
| 10 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape
|
| 11 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect
|
| 12 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching
|
| 13 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq
|
| 14 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod
|
| 15 |
+
from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar
|
| 16 |
+
from torch.fx.tensor_type import TensorType, Dyn
|
| 17 |
+
from typing import Callable, Dict, List
|
| 18 |
+
|
| 19 |
+
_TRANSFORMATION_RULES: Dict[Constraint, Callable] = {}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def register_transformation_rule(call_target):
|
| 23 |
+
def register(fn):
|
| 24 |
+
if call_target in _TRANSFORMATION_RULES:
|
| 25 |
+
raise RuntimeError(f'Transformation rule already registered for {call_target}!')
|
| 26 |
+
_TRANSFORMATION_RULES[call_target] = fn
|
| 27 |
+
return fn
|
| 28 |
+
return register
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def valid_index(index, dims):
|
| 32 |
+
"""
|
| 33 |
+
Given a list of dimensions, checks if an index is valid in the list
|
| 34 |
+
"""
|
| 35 |
+
try:
|
| 36 |
+
dims[index]
|
| 37 |
+
return T()
|
| 38 |
+
except IndexError:
|
| 39 |
+
return F()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@register_transformation_rule(Transpose)
|
| 43 |
+
def transform_transpose(constraint, counter):
|
| 44 |
+
"""
|
| 45 |
+
Similar to a sequence of two index-selects
|
| 46 |
+
"""
|
| 47 |
+
dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
|
| 48 |
+
is_valid_index1 = valid_index(constraint.index1, dims)
|
| 49 |
+
is_valid_index2 = valid_index(constraint.index2, dims)
|
| 50 |
+
new_dims = copy.deepcopy(dims)
|
| 51 |
+
nat_constraints = gen_nat_constraints(dims)
|
| 52 |
+
|
| 53 |
+
if is_valid_index1 == T() and is_valid_index2 == T():
|
| 54 |
+
new_dims[constraint.index1] = dims[constraint.index2]
|
| 55 |
+
new_dims[constraint.index2] = dims[constraint.index1]
|
| 56 |
+
|
| 57 |
+
transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
|
| 58 |
+
*nat_constraints,
|
| 59 |
+
is_valid_index1, is_valid_index2,
|
| 60 |
+
BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
|
| 61 |
+
return transformed_constraint, counter
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@register_transformation_rule(IndexSelect)
|
| 65 |
+
def transform_index_select(constraint, counter):
|
| 66 |
+
"""
|
| 67 |
+
The constraints consider the given tensor size, checks if the index is valid
|
| 68 |
+
and if so, generates a constraint for replacing the input dimension
|
| 69 |
+
with the required dimension
|
| 70 |
+
"""
|
| 71 |
+
dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
|
| 72 |
+
is_valid_index = valid_index(constraint.index, dims)
|
| 73 |
+
nat_constraints = gen_nat_constraints(dims)
|
| 74 |
+
|
| 75 |
+
# if the index is valid then replace the input dimension with the new dimension
|
| 76 |
+
# otherwise the dimension will not be replaced and the clause will contain False
|
| 77 |
+
if is_valid_index == T():
|
| 78 |
+
new_dims = copy.deepcopy(dims)
|
| 79 |
+
new_dims[constraint.index] = constraint.dim_replace
|
| 80 |
+
|
| 81 |
+
transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
|
| 82 |
+
*nat_constraints,
|
| 83 |
+
is_valid_index,
|
| 84 |
+
BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
|
| 85 |
+
|
| 86 |
+
# print(constraints)
|
| 87 |
+
return transformed_constraint, counter
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@register_transformation_rule(GetItem)
|
| 91 |
+
def transform_get_item(constraint, counter):
|
| 92 |
+
"""
|
| 93 |
+
generate an equality of the form:
|
| 94 |
+
t = [a1, ..., an]
|
| 95 |
+
then generate constraints that check if the given index is valid
|
| 96 |
+
given this particular tensor size.
|
| 97 |
+
If the index is valid, generate a constraint to get the item
|
| 98 |
+
Note that we already handled the Dyn input case in the previous
|
| 99 |
+
step.
|
| 100 |
+
Args:
|
| 101 |
+
constraint: GetItem which assumes we are getting an item from a tensor (not Dyn)
|
| 102 |
+
counter: variable tracking
|
| 103 |
+
Returns: simplified constraints for GetItem
|
| 104 |
+
|
| 105 |
+
"""
|
| 106 |
+
dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
|
| 107 |
+
nat_constraints = gen_nat_constraints(dims)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
is_valid_index = valid_index(constraint.index, dims)
|
| 111 |
+
|
| 112 |
+
all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
|
| 113 |
+
*nat_constraints,
|
| 114 |
+
is_valid_index]
|
| 115 |
+
|
| 116 |
+
# if the index is valid, we generate a constraint for getting an item
|
| 117 |
+
# otherwise this clause will have been UNSAT due to the wrong index
|
| 118 |
+
if is_valid_index == T():
|
| 119 |
+
all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq))
|
| 120 |
+
|
| 121 |
+
return Conj(all_constraints), counter
|
| 122 |
+
|
| 123 |
+
def valid_index_tensor(index, dims):
|
| 124 |
+
"""
|
| 125 |
+
if the slice instances exceed the length of the dimensions
|
| 126 |
+
then this is a type error so we return False
|
| 127 |
+
"""
|
| 128 |
+
slice_count = 0
|
| 129 |
+
for s in index:
|
| 130 |
+
if isinstance(s, slice):
|
| 131 |
+
slice_count += 1
|
| 132 |
+
if slice_count > len(dims):
|
| 133 |
+
return F()
|
| 134 |
+
else:
|
| 135 |
+
return T()
|
| 136 |
+
|
| 137 |
+
@register_transformation_rule(GetItemTensor)
|
| 138 |
+
def transform_get_item_tensor(constraint, counter):
|
| 139 |
+
"""
|
| 140 |
+
When the index is a tuple, then the output will be a tensor
|
| 141 |
+
TODO: we have to check if this is the case for all HF models
|
| 142 |
+
|
| 143 |
+
The cases we are covering here are a tuple with one of:
|
| 144 |
+
- slice with default argument
|
| 145 |
+
- None
|
| 146 |
+
|
| 147 |
+
None appends 1 to the input tensor dimensions
|
| 148 |
+
so each occurrence of 'None' increases the rank by 1
|
| 149 |
+
|
| 150 |
+
slice with default arguments does not change the rank
|
| 151 |
+
"""
|
| 152 |
+
assert isinstance(constraint.index_tuple, tuple)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# generate a result tensor of the expected size
|
| 156 |
+
dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
|
| 157 |
+
nat_constraints = gen_nat_constraints(dims)
|
| 158 |
+
|
| 159 |
+
# generate a place-holder list of the right rank
|
| 160 |
+
# where "slice" does not contribute to the rank and "None" does
|
| 161 |
+
none_c = constraint.index_tuple.count(None)
|
| 162 |
+
resulting_tensor_dims = (none_c + len(dims)) * [None]
|
| 163 |
+
|
| 164 |
+
dim_index = 0
|
| 165 |
+
for i in range(len(constraint.index_tuple)):
|
| 166 |
+
|
| 167 |
+
# append 1 to the right location of the resulting tensor
|
| 168 |
+
if constraint.index_tuple[i] is None:
|
| 169 |
+
resulting_tensor_dims[i] = 1
|
| 170 |
+
|
| 171 |
+
elif constraint.index_tuple[i] == slice(None, None, None):
|
| 172 |
+
pass
|
| 173 |
+
|
| 174 |
+
else:
|
| 175 |
+
raise NotImplementedError('Method not yet implemented')
|
| 176 |
+
|
| 177 |
+
# append the remaining dimensions to the right location
|
| 178 |
+
dim_index = 0
|
| 179 |
+
for i in range(len(resulting_tensor_dims)):
|
| 180 |
+
if resulting_tensor_dims[i] is None:
|
| 181 |
+
resulting_tensor_dims[i] = dims[dim_index]
|
| 182 |
+
dim_index += 1
|
| 183 |
+
|
| 184 |
+
# check if the index is valid
|
| 185 |
+
is_valid_index = valid_index_tensor(constraint.index_tuple, dims)
|
| 186 |
+
|
| 187 |
+
# check if the resulting tensor is within bounds
|
| 188 |
+
if len(resulting_tensor_dims) > 4:
|
| 189 |
+
return F(), counter
|
| 190 |
+
|
| 191 |
+
else:
|
| 192 |
+
constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
|
| 193 |
+
BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
|
| 194 |
+
*nat_constraints,
|
| 195 |
+
is_valid_index]
|
| 196 |
+
return Conj(constraints), counter
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@register_transformation_rule(BinConstraintT)
|
| 200 |
+
def generate_binconstraint_t(constraint, counter):
|
| 201 |
+
"""
|
| 202 |
+
Transform binary constraints for tensors
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
# precision constraints
|
| 206 |
+
if constraint.op == op_precision:
|
| 207 |
+
if constraint.lhs == Dyn:
|
| 208 |
+
return T(), counter
|
| 209 |
+
elif isinstance(constraint.lhs, TensorType):
|
| 210 |
+
is_fully_static = all(d != Dyn for d in constraint.lhs.__args__)
|
| 211 |
+
if is_fully_static:
|
| 212 |
+
return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter
|
| 213 |
+
else:
|
| 214 |
+
new_dims = []
|
| 215 |
+
|
| 216 |
+
for _ in range(len(constraint.lhs.__args__)):
|
| 217 |
+
dim, counter = gen_dvar(counter)
|
| 218 |
+
new_dims.append(dim)
|
| 219 |
+
|
| 220 |
+
new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for
|
| 221 |
+
new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \
|
| 222 |
+
[BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \
|
| 223 |
+
[BinConstraintD(1, new_dim, op_leq) for
|
| 224 |
+
new_dim in new_dims]
|
| 225 |
+
return Conj(new_dim_constraints), counter
|
| 226 |
+
|
| 227 |
+
# matching
|
| 228 |
+
elif constraint.op == op_matching:
|
| 229 |
+
assert isinstance(constraint.rhs, TensorType)
|
| 230 |
+
d1 = constraint.rhs.__args__[0]
|
| 231 |
+
d2 = constraint.rhs.__args__[1]
|
| 232 |
+
d3 = constraint.rhs.__args__[2]
|
| 233 |
+
d4 = constraint.rhs.__args__[3]
|
| 234 |
+
|
| 235 |
+
conj = [BinConstraintT(constraint.lhs, Dyn, op_eq),
|
| 236 |
+
BinConstraintD(d1, Dyn, op_eq),
|
| 237 |
+
BinConstraintD(d2, Dyn, op_eq),
|
| 238 |
+
BinConstraintD(d3, Dyn, op_eq),
|
| 239 |
+
BinConstraintD(d4, Dyn, op_eq)]
|
| 240 |
+
return Disj([Conj(conj),
|
| 241 |
+
BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter
|
| 242 |
+
|
| 243 |
+
elif constraint.op == op_consistency:
|
| 244 |
+
c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)])
|
| 245 |
+
[c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter)
|
| 246 |
+
|
| 247 |
+
return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter
|
| 248 |
+
|
| 249 |
+
elif constraint.op == op_leq:
|
| 250 |
+
assert isinstance(constraint.rhs, int)
|
| 251 |
+
disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
|
| 252 |
+
for i in range(1, constraint.rhs + 1):
|
| 253 |
+
dims = []
|
| 254 |
+
for j in range(1, i + 1):
|
| 255 |
+
dim_var, counter = gen_dvar(counter)
|
| 256 |
+
dims.append(dim_var)
|
| 257 |
+
disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))
|
| 258 |
+
return Disj(disj), counter
|
| 259 |
+
else:
|
| 260 |
+
return constraint, counter
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
@register_transformation_rule(BinConstraintD)
|
| 264 |
+
def generate_binconstraint_d(constraint, counter):
|
| 265 |
+
"""
|
| 266 |
+
Transform binary constraints for dimensions
|
| 267 |
+
"""
|
| 268 |
+
if constraint.op == op_precision:
|
| 269 |
+
if isinstance(constraint.lhs, int):
|
| 270 |
+
return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter
|
| 271 |
+
elif constraint.lhs == Dyn:
|
| 272 |
+
return T(), counter
|
| 273 |
+
|
| 274 |
+
elif constraint.op == op_consistency:
|
| 275 |
+
return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq),
|
| 276 |
+
BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter
|
| 277 |
+
|
| 278 |
+
else:
|
| 279 |
+
return constraint, counter
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
@register_transformation_rule(Conj)
|
| 283 |
+
def generate_conj(constraint, counter):
|
| 284 |
+
"""
|
| 285 |
+
Transform conjunctions
|
| 286 |
+
"""
|
| 287 |
+
new = []
|
| 288 |
+
for c in constraint.conjucts:
|
| 289 |
+
new_c, counter = transform_constraint(c, counter)
|
| 290 |
+
new.append(new_c)
|
| 291 |
+
return Conj(new), counter
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@register_transformation_rule(Disj)
|
| 295 |
+
def generate_disj(constraint, counter):
|
| 296 |
+
"""
|
| 297 |
+
Transform disjunctions
|
| 298 |
+
"""
|
| 299 |
+
new = []
|
| 300 |
+
for c in constraint.disjuncts:
|
| 301 |
+
new_c, counter = transform_constraint(c, counter)
|
| 302 |
+
new.append(new_c)
|
| 303 |
+
return Disj(new), counter
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
@register_transformation_rule(TGreatestUpperBound)
|
| 307 |
+
def generate_gub(constraint, counter):
|
| 308 |
+
"""
|
| 309 |
+
Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound
|
| 310 |
+
on dimensions
|
| 311 |
+
"""
|
| 312 |
+
c1 = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq),
|
| 313 |
+
BinConstraintT(constraint.rhs2, Dyn, op_eq)]), BinConstraintT(constraint.res, Dyn, op_eq)])
|
| 314 |
+
|
| 315 |
+
[c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter)
|
| 316 |
+
|
| 317 |
+
return Disj([c1, c2, c3, c4, c5]), counter
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@register_transformation_rule(DGreatestUpperBound)
|
| 321 |
+
def generate_d_gub(constraint, counter):
|
| 322 |
+
"""
|
| 323 |
+
Transform greatest upper bound for dimensions into equality constraints
|
| 324 |
+
"""
|
| 325 |
+
c1 = Conj([BinConstraintD(constraint.rhs1, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs2, op_eq)])
|
| 326 |
+
c2 = Conj([BinConstraintD(constraint.rhs2, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)])
|
| 327 |
+
c3 = Conj([BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)])
|
| 328 |
+
return Disj([c1, c2, c3]), counter
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
@register_transformation_rule(CalcConv)
|
| 332 |
+
def generate_calc_conv(constraint, counter):
|
| 333 |
+
d, counter = gen_tensor_dims(4, counter)
|
| 334 |
+
conv_result = TensorType([d[0], d[1], d[2], d[3]])
|
| 335 |
+
|
| 336 |
+
# the convolution result is a tensor of size 4
|
| 337 |
+
c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq)
|
| 338 |
+
|
| 339 |
+
# the second dimension of the output is equal to the output channels
|
| 340 |
+
c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)])
|
| 341 |
+
|
| 342 |
+
# the input corresponds to the output in the first dimension of the convolution
|
| 343 |
+
c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)
|
| 344 |
+
|
| 345 |
+
c4, c5 = calc_last_two_dims(constraint, d)
|
| 346 |
+
|
| 347 |
+
leq_constraints = Conj([BinConstraintD(0, d[0], op_leq),
|
| 348 |
+
BinConstraintD(0, d[1], op_leq),
|
| 349 |
+
BinConstraintD(0, d[2], op_leq),
|
| 350 |
+
BinConstraintD(0, d[3], op_leq)])
|
| 351 |
+
|
| 352 |
+
return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
@register_transformation_rule(CalcMaxPool)
|
| 356 |
+
def generate_calc_maxpool(constraint, counter):
|
| 357 |
+
"""
|
| 358 |
+
Transform maxpool constraints
|
| 359 |
+
"""
|
| 360 |
+
d, counter = gen_tensor_dims(4, counter)
|
| 361 |
+
maxpool_result = TensorType([d[0], d[1], d[2], d[3]])
|
| 362 |
+
|
| 363 |
+
# the maxpool result is a tensor of size 4
|
| 364 |
+
c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq)
|
| 365 |
+
|
| 366 |
+
# the input corresponds to the output in the first and second dimension of maxpool
|
| 367 |
+
c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq)
|
| 368 |
+
c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)
|
| 369 |
+
c4, c5 = calc_last_two_dims(constraint, d)
|
| 370 |
+
|
| 371 |
+
leq_constraints = Conj([BinConstraintD(0, d[0], op_leq),
|
| 372 |
+
BinConstraintD(0, d[1], op_leq),
|
| 373 |
+
BinConstraintD(0, d[2], op_leq),
|
| 374 |
+
BinConstraintD(0, d[3], op_leq)])
|
| 375 |
+
|
| 376 |
+
return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
@register_transformation_rule(CalcProduct)
|
| 380 |
+
def generate_calc_product(constraint, counter):
|
| 381 |
+
"""
|
| 382 |
+
Transform flatten constraints
|
| 383 |
+
"""
|
| 384 |
+
start = constraint.start
|
| 385 |
+
end = constraint.end
|
| 386 |
+
dims = constraint.dims_to_flatten
|
| 387 |
+
flattened = constraint.flattened
|
| 388 |
+
n = len(constraint.dims_to_flatten)
|
| 389 |
+
|
| 390 |
+
# this will be evaluated right here
|
| 391 |
+
boundary_check = (0 <= start and start < end and end <= n)
|
| 392 |
+
|
| 393 |
+
c_boundary = T() if boundary_check else F()
|
| 394 |
+
|
| 395 |
+
lhs = dims[0:start]
|
| 396 |
+
rhs = dims[end:]
|
| 397 |
+
mid = dims[start:end]
|
| 398 |
+
|
| 399 |
+
all_possibilities = generate_all_int_dyn_dim_possibilities(mid)
|
| 400 |
+
|
| 401 |
+
all_constraints = []
|
| 402 |
+
|
| 403 |
+
for p in all_possibilities:
|
| 404 |
+
p = list(p)
|
| 405 |
+
# this tells us there is a dynamic variable
|
| 406 |
+
contains_dyn = not all(constraint.op == op_neq for constraint in p)
|
| 407 |
+
if contains_dyn:
|
| 408 |
+
mid_var = [Dyn]
|
| 409 |
+
total_constraints = lhs + mid_var + rhs
|
| 410 |
+
if len(total_constraints) > 4:
|
| 411 |
+
all_constraints.append(F())
|
| 412 |
+
else:
|
| 413 |
+
all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq)] + p))
|
| 414 |
+
else:
|
| 415 |
+
new_var, counter = gen_dvar(counter)
|
| 416 |
+
mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)])
|
| 417 |
+
mid_var = [new_var]
|
| 418 |
+
total_constraints = lhs + mid_var + rhs
|
| 419 |
+
if len(total_constraints) > 4:
|
| 420 |
+
all_constraints.append(F())
|
| 421 |
+
else:
|
| 422 |
+
all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p))
|
| 423 |
+
|
| 424 |
+
return Conj([Disj(all_constraints), c_boundary]), counter
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
@register_transformation_rule(CanReshape)
|
| 428 |
+
def generate_reshape(constraint, counter):
|
| 429 |
+
"""
|
| 430 |
+
Transform reshape constraints
|
| 431 |
+
"""
|
| 432 |
+
d, counter = gen_tensor_dims(4, counter)
|
| 433 |
+
|
| 434 |
+
d1 = d[0]
|
| 435 |
+
d2 = d[1]
|
| 436 |
+
d3 = d[2]
|
| 437 |
+
d4 = d[3]
|
| 438 |
+
|
| 439 |
+
target = constraint.target.__args__
|
| 440 |
+
|
| 441 |
+
is_fully_static = all(d != Dyn for d in target)
|
| 442 |
+
|
| 443 |
+
# dynamic tensor
|
| 444 |
+
c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)
|
| 445 |
+
c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq)
|
| 446 |
+
c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq)
|
| 447 |
+
c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq)
|
| 448 |
+
c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq)
|
| 449 |
+
|
| 450 |
+
d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq)
|
| 451 |
+
d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq)
|
| 452 |
+
|
| 453 |
+
d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq)
|
| 454 |
+
d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq)
|
| 455 |
+
|
| 456 |
+
d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
|
| 457 |
+
d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq)
|
| 458 |
+
|
| 459 |
+
d4_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
|
| 460 |
+
d4_neq_dyn = BinConstraintD(d3, Dyn, op_neq)
|
| 461 |
+
|
| 462 |
+
nat_d1 = BinConstraintD(0, d1, op_leq)
|
| 463 |
+
nat_d2 = BinConstraintD(0, d2, op_leq)
|
| 464 |
+
nat_d3 = BinConstraintD(0, d3, op_leq)
|
| 465 |
+
nat_d4 = BinConstraintD(0, d4, op_leq)
|
| 466 |
+
|
| 467 |
+
if is_fully_static:
|
| 468 |
+
# size 1 tensor
|
| 469 |
+
c3_tensor1 = Disj([d1_eq_dyn,
|
| 470 |
+
(Conj([d1_neq_dyn,
|
| 471 |
+
BinConstraintD(d1, Prod(target), op_eq)]))])
|
| 472 |
+
all_tensor_1 = Conj([c2_tensor1, c3_tensor1])
|
| 473 |
+
|
| 474 |
+
# size 2 tensor
|
| 475 |
+
all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)])
|
| 476 |
+
|
| 477 |
+
# size 3 tensor
|
| 478 |
+
all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)])
|
| 479 |
+
|
| 480 |
+
# size 4 tensor
|
| 481 |
+
all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)])
|
| 482 |
+
|
| 483 |
+
return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
|
| 484 |
+
nat_d1, nat_d2, nat_d3, nat_d4]), counter
|
| 485 |
+
|
| 486 |
+
# then there must be exactly one occurrence of dyn
|
| 487 |
+
else:
|
| 488 |
+
new_target = []
|
| 489 |
+
|
| 490 |
+
for n in target:
|
| 491 |
+
if n != Dyn:
|
| 492 |
+
new_target.append(n)
|
| 493 |
+
|
| 494 |
+
# tensor 1
|
| 495 |
+
c3_tensor1 = Disj([d1_eq_dyn,
|
| 496 |
+
(Conj([d1_neq_dyn,
|
| 497 |
+
is_dim_div_by_target(new_target, d1)]))])
|
| 498 |
+
all_tensor_1 = Conj([c2_tensor1, c3_tensor1])
|
| 499 |
+
|
| 500 |
+
# tensor 2
|
| 501 |
+
c21 = Disj([d1_eq_dyn, d2_eq_dyn])
|
| 502 |
+
c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))])
|
| 503 |
+
all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])])
|
| 504 |
+
|
| 505 |
+
# tensor 3
|
| 506 |
+
c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn])
|
| 507 |
+
c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))])
|
| 508 |
+
all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])])
|
| 509 |
+
|
| 510 |
+
# tensor 4
|
| 511 |
+
c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn])
|
| 512 |
+
c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))])
|
| 513 |
+
all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])])
|
| 514 |
+
|
| 515 |
+
return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
|
| 516 |
+
nat_d1, nat_d2, nat_d3, nat_d4]), counter
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
@register_transformation_rule(ApplyBroadcasting)
|
| 520 |
+
def generate_broadcasting(constraint, counter):
|
| 521 |
+
"""
|
| 522 |
+
Transform broadcasting constraints
|
| 523 |
+
"""
|
| 524 |
+
e11, e12 = constraint.res1, constraint.res2
|
| 525 |
+
e1, e2 = constraint.input1, constraint.input2
|
| 526 |
+
|
| 527 |
+
e1_dyn = BinConstraintT(e1, Dyn, op_eq)
|
| 528 |
+
e2_dyn = BinConstraintT(e2, Dyn, op_eq)
|
| 529 |
+
|
| 530 |
+
# Introduce dimensions
|
| 531 |
+
e1_equal_e11 = BinConstraintT(e1, e11, op_eq)
|
| 532 |
+
e2_equal_e12 = BinConstraintT(e2, e12, op_eq)
|
| 533 |
+
|
| 534 |
+
# dyn possibility
|
| 535 |
+
e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12])
|
| 536 |
+
e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12])
|
| 537 |
+
|
| 538 |
+
# tensor possibility
|
| 539 |
+
# generate dimensions to create tensors of size 1
|
| 540 |
+
final_tensor_1_constraint, _, _, nat_dims_1, counter = \
|
| 541 |
+
gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter)
|
| 542 |
+
|
| 543 |
+
# generate dimensions to create tensors of size 2
|
| 544 |
+
final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \
|
| 545 |
+
final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \
|
| 546 |
+
gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter)
|
| 547 |
+
|
| 548 |
+
# generate dimensions to create tensors of size 3
|
| 549 |
+
final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \
|
| 550 |
+
final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \
|
| 551 |
+
gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter)
|
| 552 |
+
|
| 553 |
+
# generate dimensions to create tensors of size 4
|
| 554 |
+
final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \
|
| 555 |
+
final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \
|
| 556 |
+
gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter)
|
| 557 |
+
|
| 558 |
+
final_result = Disj([
|
| 559 |
+
e1_dyn_constraint,
|
| 560 |
+
e2_dyn_constraint,
|
| 561 |
+
final_tensor_1_constraint,
|
| 562 |
+
final_tensor_2_constraint_no_padding,
|
| 563 |
+
final_tensor_2_constraint_padding_arg1,
|
| 564 |
+
final_tensor_2_constraint_padding_arg2,
|
| 565 |
+
final_tensor_3_constraint_no_padding,
|
| 566 |
+
final_tensor_3_constraint_padding_arg1,
|
| 567 |
+
final_tensor_3_constraint_padding_arg2,
|
| 568 |
+
final_tensor_4_constraint_no_padding,
|
| 569 |
+
final_tensor_4_constraint_padding_arg1,
|
| 570 |
+
final_tensor_4_constraint_padding_arg2
|
| 571 |
+
])
|
| 572 |
+
|
| 573 |
+
return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def transform_constraint(constraint: Constraint, counter: int):
|
| 577 |
+
"""
|
| 578 |
+
Transforms a constraint into a simpler constraint.
|
| 579 |
+
Ex: precision and consistency are transformed to equality
|
| 580 |
+
Args:
|
| 581 |
+
constraint: constraint to be transformed
|
| 582 |
+
counter: for variable tracking
|
| 583 |
+
|
| 584 |
+
Returns: Constraint
|
| 585 |
+
|
| 586 |
+
"""
|
| 587 |
+
if type(constraint) in _TRANSFORMATION_RULES:
|
| 588 |
+
return _TRANSFORMATION_RULES[type(constraint)](constraint, counter)
|
| 589 |
+
|
| 590 |
+
else:
|
| 591 |
+
return constraint, counter
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def calc_last_two_dims(constraint, d: List[DVar]):
|
| 597 |
+
"""
|
| 598 |
+
Generates constraints for the last two dimensions of a convolution or a maxpool output
|
| 599 |
+
Args:
|
| 600 |
+
constraint: CalcConv or CalcMaxPool
|
| 601 |
+
d: The list of output dimensions
|
| 602 |
+
|
| 603 |
+
Returns: Constraints for calculating the last two dimensions of the output
|
| 604 |
+
|
| 605 |
+
"""
|
| 606 |
+
|
| 607 |
+
assert isinstance(constraint, (CalcConv, CalcMaxPool))
|
| 608 |
+
|
| 609 |
+
b3 = constraint.matching_constraint[2]
|
| 610 |
+
b4 = constraint.matching_constraint[3]
|
| 611 |
+
|
| 612 |
+
b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)])
|
| 613 |
+
b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)])
|
| 614 |
+
|
| 615 |
+
d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)])
|
| 616 |
+
d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)])
|
| 617 |
+
|
| 618 |
+
# transform parameters into tuples incase they are not already
|
| 619 |
+
padding = (constraint.padding, constraint.padding) \
|
| 620 |
+
if isinstance(constraint.padding, int) else constraint.padding
|
| 621 |
+
kernel = (constraint.kernel, constraint.kernel) \
|
| 622 |
+
if isinstance(constraint.kernel, int) else constraint.kernel
|
| 623 |
+
stride = (constraint.stride, constraint.stride) \
|
| 624 |
+
if isinstance(constraint.stride, int) else constraint.stride
|
| 625 |
+
dilation = (constraint.dilation, constraint.dilation) \
|
| 626 |
+
if isinstance(constraint.dilation, int) else constraint.dilation
|
| 627 |
+
|
| 628 |
+
f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add)
|
| 629 |
+
f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul)
|
| 630 |
+
f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div)
|
| 631 |
+
f4 = BinConstraintD(f3, 1, op_add)
|
| 632 |
+
|
| 633 |
+
c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])])
|
| 634 |
+
|
| 635 |
+
f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add)
|
| 636 |
+
f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul)
|
| 637 |
+
f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div)
|
| 638 |
+
f44 = BinConstraintD(f33, 1, op_add)
|
| 639 |
+
|
| 640 |
+
c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])])
|
| 641 |
+
|
| 642 |
+
return c4, c5
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
|
| 646 |
+
"""
|
| 647 |
+
Generate all possibilities of being equal or not equal to dyn for my_list
|
| 648 |
+
Args:
|
| 649 |
+
my_list: List of tensor dimensions
|
| 650 |
+
|
| 651 |
+
Returns: A list of a list of constraints. Each list of constraints corresponds to
|
| 652 |
+
one possibility about the values of the dimension variables
|
| 653 |
+
"""
|
| 654 |
+
# generate all possibilities of being equal or not equal to dyn for my_list
|
| 655 |
+
eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))]
|
| 656 |
+
neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))]
|
| 657 |
+
d_possibilities = []
|
| 658 |
+
|
| 659 |
+
for i in zip(eq_possibilities, neq_possibilities):
|
| 660 |
+
d_possibilities.append(list(i))
|
| 661 |
+
all_possibilities = list(itertools.product(*d_possibilities))
|
| 662 |
+
return all_possibilities
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def is_target_div_by_dim(target: List[int], dim: List[DVar]):
|
| 666 |
+
"""
|
| 667 |
+
Generate constraints to check if the target dimensions are divisible by the input dimensions
|
| 668 |
+
Args:
|
| 669 |
+
target: Target dimensions
|
| 670 |
+
dim: Input dimensions
|
| 671 |
+
|
| 672 |
+
Returns: Constraints to check divisibility
|
| 673 |
+
|
| 674 |
+
"""
|
| 675 |
+
return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq)
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
def is_dim_div_by_target(target: List[int], dim: List[DVar]):
|
| 679 |
+
"""
|
| 680 |
+
Generate constraints to check if the input dimensions is divisible by the target dimensions
|
| 681 |
+
Args:
|
| 682 |
+
target: Target dimensions
|
| 683 |
+
dim: Input dimensions
|
| 684 |
+
|
| 685 |
+
Returns: Constraints to check divisibility
|
| 686 |
+
|
| 687 |
+
"""
|
| 688 |
+
return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq)
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def gen_all_reshape_possibilities(list_of_dims, target):
|
| 692 |
+
"""
|
| 693 |
+
Consider all possibilities what the input dimensions could be (number or dynamic)
|
| 694 |
+
Then generate the appropriate constraints using multiplication or mod depending on the possibility
|
| 695 |
+
The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn
|
| 696 |
+
for the input. Target is fixed because at most one dimension could be dyn.
|
| 697 |
+
We have different cases for this.
|
| 698 |
+
|
| 699 |
+
Args:
|
| 700 |
+
list_of_dims: The input list of dimensions
|
| 701 |
+
target: The tensor we want to reshape to
|
| 702 |
+
|
| 703 |
+
Returns: A disjunction of transformed reshape constraints
|
| 704 |
+
|
| 705 |
+
"""
|
| 706 |
+
all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims)
|
| 707 |
+
|
| 708 |
+
all_constraints = []
|
| 709 |
+
|
| 710 |
+
for p in all_possibilities:
|
| 711 |
+
to_multiply = []
|
| 712 |
+
|
| 713 |
+
p = list(p)
|
| 714 |
+
|
| 715 |
+
for constraint in p:
|
| 716 |
+
assert isinstance(constraint, BinConstraintD)
|
| 717 |
+
if constraint.op == op_neq:
|
| 718 |
+
to_multiply.append(constraint.lhs)
|
| 719 |
+
|
| 720 |
+
if not to_multiply:
|
| 721 |
+
all_constraints.append(Conj(p))
|
| 722 |
+
|
| 723 |
+
elif len(to_multiply) < len(list_of_dims):
|
| 724 |
+
all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))]))
|
| 725 |
+
else:
|
| 726 |
+
all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims),
|
| 727 |
+
Prod(target), op_eq)]))
|
| 728 |
+
|
| 729 |
+
return Disj(all_constraints)
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False):
|
| 733 |
+
"""
|
| 734 |
+
Apply broadcasting to the 'index' dimension of tensor_input1.
|
| 735 |
+
Args:
|
| 736 |
+
tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1
|
| 737 |
+
tensor_input2: represents the second input
|
| 738 |
+
res1: broadcasted result 1
|
| 739 |
+
res2: broadcasted result 2
|
| 740 |
+
index: the index to broadcast
|
| 741 |
+
padding: If padding was used, then tensor_input1[index] does not exist
|
| 742 |
+
|
| 743 |
+
Returns:
|
| 744 |
+
|
| 745 |
+
"""
|
| 746 |
+
if tensor_input1[index] is None:
|
| 747 |
+
assert padding
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
if not padding:
|
| 751 |
+
# then the inputs are the same length so they all have dimensions at "index"
|
| 752 |
+
return Conj([BinConstraintD(tensor_input1[index], 1, op_eq),
|
| 753 |
+
BinConstraintD(res1[index], res2[index], op_eq),
|
| 754 |
+
BinConstraintD(res2[index], tensor_input2[index], op_eq)])
|
| 755 |
+
|
| 756 |
+
else:
|
| 757 |
+
# we don't set the input dimension to 1, since it doesn't exist.
|
| 758 |
+
return Conj([BinConstraintD(res1[index], res2[index], op_eq),
|
| 759 |
+
BinConstraintD(res2[index], tensor_input2[index], op_eq)])
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
def apply_padding(e1_var: TVar,
|
| 763 |
+
e11: BinConstraintT,
|
| 764 |
+
e2: BinConstraintT,
|
| 765 |
+
e12: BinConstraintT,
|
| 766 |
+
d2: List[DVar],
|
| 767 |
+
d11: List[DVar],
|
| 768 |
+
d12: List[DVar],
|
| 769 |
+
counter: int):
|
| 770 |
+
"""
|
| 771 |
+
We are considering the possibility where one input has less dimensions than
|
| 772 |
+
another input, so we apply padding to the broadcasted results
|
| 773 |
+
|
| 774 |
+
Args:
|
| 775 |
+
e1_var: Variable representing the first input where padding will be
|
| 776 |
+
e11: constraint of the form e11 = Tensortype[d1, ..., dn]
|
| 777 |
+
e2: constraint of the form e2 = Tensortype[d1, ..., dn]
|
| 778 |
+
e12: constraint of the form e11 = Tensortype[d1, ..., dn]
|
| 779 |
+
d2: Tensor variables for the second input
|
| 780 |
+
d11: Tensor variables for the broadcasted first input
|
| 781 |
+
d12: Tensor variables for the broadcasted second input
|
| 782 |
+
counter: variable tracking
|
| 783 |
+
|
| 784 |
+
Returns: A new constraint whose goal is to apply padding to the broadcasted result
|
| 785 |
+
|
| 786 |
+
"""
|
| 787 |
+
|
| 788 |
+
res = []
|
| 789 |
+
|
| 790 |
+
# pad the shorter input with None so we can pass it to the broadcasting helper function
|
| 791 |
+
for i in range(1, len(d2)):
|
| 792 |
+
|
| 793 |
+
d1, counter = gen_tensor_dims(i, counter)
|
| 794 |
+
|
| 795 |
+
nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)
|
| 796 |
+
|
| 797 |
+
e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)
|
| 798 |
+
|
| 799 |
+
simulate_padding = [None] * (len(d2) - i)
|
| 800 |
+
|
| 801 |
+
assert len(simulate_padding + d1) == len(d2)
|
| 802 |
+
|
| 803 |
+
broadcast_padding = []
|
| 804 |
+
|
| 805 |
+
# for every padding size, we also consider broadcasting
|
| 806 |
+
for j in range(len(d2) - i):
|
| 807 |
+
broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True))
|
| 808 |
+
|
| 809 |
+
# we consider the possibilities for broadcasting for every dimension. Since we already
|
| 810 |
+
# padded d1, we do not consider it while broadcasting
|
| 811 |
+
all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1,
|
| 812 |
+
d2[(len(d2) - i):],
|
| 813 |
+
d11[(len(d2) - i):],
|
| 814 |
+
d12[(len(d2) - i):])
|
| 815 |
+
# combine all constraints into a conjunction
|
| 816 |
+
c = Conj([e1, e11, e2, e12,
|
| 817 |
+
*broadcast_padding,
|
| 818 |
+
all_broadcasting_possibilities,
|
| 819 |
+
*nat_constraints
|
| 820 |
+
])
|
| 821 |
+
res.append(c)
|
| 822 |
+
|
| 823 |
+
return Disj(res), counter
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
def no_broadcast_dim_with_index(d1: List[DVar],
|
| 827 |
+
d2: List[DVar],
|
| 828 |
+
d3: List[DVar],
|
| 829 |
+
d4: List[DVar],
|
| 830 |
+
i: int):
|
| 831 |
+
"""
|
| 832 |
+
Args:
|
| 833 |
+
d1: input 1
|
| 834 |
+
d2: input 2
|
| 835 |
+
d3: simulated broadcasting for input 1
|
| 836 |
+
d4: simulated broadcasting for input 2
|
| 837 |
+
i: the rank of the resulting tensor addition
|
| 838 |
+
|
| 839 |
+
Returns: Constraints for when no broadcasting occurs
|
| 840 |
+
"""
|
| 841 |
+
return Conj([
|
| 842 |
+
Disj([
|
| 843 |
+
Conj([BinConstraintD(d1[i], 1, op_eq),
|
| 844 |
+
BinConstraintD(d2[i], 1, op_eq)]),
|
| 845 |
+
|
| 846 |
+
Conj([BinConstraintD(d1[i], 1, op_neq),
|
| 847 |
+
BinConstraintD(d2[i], 1, op_neq)])]),
|
| 848 |
+
|
| 849 |
+
BinConstraintD(d1[i], d3[i], op_eq),
|
| 850 |
+
BinConstraintD(d2[i], d4[i], op_eq)])
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
|
| 854 |
+
def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int):
|
| 855 |
+
"""
|
| 856 |
+
Generate lists of DVar to represent tensor dimensions
|
| 857 |
+
Args:
|
| 858 |
+
num_tensors: the required number of tensors
|
| 859 |
+
dim_size: the number of dimensions for each tensor
|
| 860 |
+
counter: variable tracking
|
| 861 |
+
|
| 862 |
+
Returns: A list of a list of tensor dimensions
|
| 863 |
+
|
| 864 |
+
"""
|
| 865 |
+
res = []
|
| 866 |
+
|
| 867 |
+
for _ in range(num_tensors):
|
| 868 |
+
dims, counter = gen_tensor_dims(dim_size, counter)
|
| 869 |
+
res.append(dims)
|
| 870 |
+
|
| 871 |
+
return res, counter
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
def create_equality_constraints_for_broadcasting(e1: TVar,
|
| 875 |
+
e2: TVar,
|
| 876 |
+
e11: TVar,
|
| 877 |
+
e12: TVar,
|
| 878 |
+
d1: List[DVar],
|
| 879 |
+
d2: List[DVar],
|
| 880 |
+
d11: List[DVar],
|
| 881 |
+
d12: List[DVar]):
|
| 882 |
+
"""
|
| 883 |
+
Create equality constraints for when no broadcasting occurs
|
| 884 |
+
Args:
|
| 885 |
+
e1: Input 1
|
| 886 |
+
e2: Input 2
|
| 887 |
+
e11: Broadcasted input 1
|
| 888 |
+
e12: Broadcasted input 2
|
| 889 |
+
d1: Variables that store dimensions for e1
|
| 890 |
+
d2: Variables that store dimensions for e2
|
| 891 |
+
d11: Variables that store dimensions for e11
|
| 892 |
+
d12: Variables that store dimensions for e22
|
| 893 |
+
|
| 894 |
+
Returns: Four equality constraints
|
| 895 |
+
|
| 896 |
+
"""
|
| 897 |
+
|
| 898 |
+
e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq)
|
| 899 |
+
e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq)
|
| 900 |
+
e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq)
|
| 901 |
+
e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq)
|
| 902 |
+
return [e1_tensor, e11_tensor, e2_tensor, e12_tensor]
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
def gen_consistency_constraints(constraint: Constraint, counter: int):
|
| 906 |
+
"""
|
| 907 |
+
Args:
|
| 908 |
+
constraint: Consistency constraint on tensors
|
| 909 |
+
counter: for variable tracking
|
| 910 |
+
|
| 911 |
+
Returns: Equality and consistency constraints on dimensions
|
| 912 |
+
|
| 913 |
+
"""
|
| 914 |
+
|
| 915 |
+
all_constraints = []
|
| 916 |
+
|
| 917 |
+
for i in range(1, MAX_TENSOR_RANK + 1):
|
| 918 |
+
new_dims_rhs_1, counter = gen_tensor_dims(i, counter)
|
| 919 |
+
new_dims_rhs_2, counter = gen_tensor_dims(i, counter)
|
| 920 |
+
|
| 921 |
+
nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2)
|
| 922 |
+
|
| 923 |
+
c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq),
|
| 924 |
+
BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] +
|
| 925 |
+
[BinConstraintD(d1, d2, op_consistency) for
|
| 926 |
+
d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints)
|
| 927 |
+
|
| 928 |
+
all_constraints.append(c_tensor_i)
|
| 929 |
+
|
| 930 |
+
return all_constraints, counter
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
|
| 934 |
+
"""
|
| 935 |
+
Args:
|
| 936 |
+
constraint: Greatest upper bound on tensors
|
| 937 |
+
counter: variable tracking
|
| 938 |
+
|
| 939 |
+
Returns: A set of equality constraints and DGreatestUpperBound constraints
|
| 940 |
+
|
| 941 |
+
"""
|
| 942 |
+
|
| 943 |
+
all_constraints = []
|
| 944 |
+
|
| 945 |
+
for i in range(1, MAX_TENSOR_RANK + 1):
|
| 946 |
+
c = []
|
| 947 |
+
dims1, counter = gen_tensor_dims(i, counter)
|
| 948 |
+
c1tensor = TensorType(dims1)
|
| 949 |
+
|
| 950 |
+
dims2, counter = gen_tensor_dims(i, counter)
|
| 951 |
+
c2tensor = TensorType(dims2)
|
| 952 |
+
|
| 953 |
+
dims3, counter = gen_tensor_dims(i, counter)
|
| 954 |
+
c3tensor = TensorType(dims3)
|
| 955 |
+
|
| 956 |
+
c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq),
|
| 957 |
+
BinConstraintT(constraint.rhs2, c2tensor, op_eq),
|
| 958 |
+
BinConstraintT(constraint.res, c3tensor, op_eq)] + \
|
| 959 |
+
gen_nat_constraints(dims1 + dims2 + dims3)
|
| 960 |
+
|
| 961 |
+
assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__)
|
| 962 |
+
for i in range(len(c3tensor.__args__)):
|
| 963 |
+
c.append(DGreatestUpperBound(c3tensor.__args__[i],
|
| 964 |
+
c1tensor.__args__[i],
|
| 965 |
+
c2tensor.__args__[i]))
|
| 966 |
+
|
| 967 |
+
all_constraints.append(Conj(c))
|
| 968 |
+
return all_constraints, counter
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]):
|
| 972 |
+
"""
|
| 973 |
+
Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.
|
| 974 |
+
We look at all combinations for all dimensions in d1 and d2
|
| 975 |
+
Args:
|
| 976 |
+
d1: input1 dimensions
|
| 977 |
+
d2: input2 dimensions
|
| 978 |
+
d11: broadcasted input1 dimensions
|
| 979 |
+
d12: broadcasted input2 dimensions
|
| 980 |
+
|
| 981 |
+
Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions
|
| 982 |
+
|
| 983 |
+
"""
|
| 984 |
+
|
| 985 |
+
size = len(d1)
|
| 986 |
+
|
| 987 |
+
res2 = []
|
| 988 |
+
|
| 989 |
+
for i in range(size):
|
| 990 |
+
t1 = broadcast_dim(d1, d2, d11, d12, i)
|
| 991 |
+
t2 = broadcast_dim(d2, d1, d12, d11, i)
|
| 992 |
+
t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i)
|
| 993 |
+
|
| 994 |
+
res2.append(Disj([t1, t2, t3]))
|
| 995 |
+
|
| 996 |
+
return Conj(res2)
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int):
|
| 1000 |
+
"""
|
| 1001 |
+
Simulates broadcasting on e1 and e2 and returns the results
|
| 1002 |
+
respectively in e11 and e12. Because of gradual types,
|
| 1003 |
+
e1 and e2 may not be equal. Similarly, e11 and e12 may not
|
| 1004 |
+
be equal. e11 and e12 should be guaranteed to be consistent
|
| 1005 |
+
as they represent the shapes of the tensors to be added after
|
| 1006 |
+
broadcasting.
|
| 1007 |
+
Args:
|
| 1008 |
+
e1: TVar representing the type of input 1
|
| 1009 |
+
e2: TVar representing the type of input 2
|
| 1010 |
+
e11: TVar representing the representing broadcasted input 1
|
| 1011 |
+
e12: TVar representing the representing broadcasted input 2
|
| 1012 |
+
i: The rank of the resulting type of addition
|
| 1013 |
+
counter: for variable tracking
|
| 1014 |
+
|
| 1015 |
+
Returns: Simplified broadcasting constraints
|
| 1016 |
+
|
| 1017 |
+
"""
|
| 1018 |
+
dims, counter = gen_lists_of_dims(4, i, counter)
|
| 1019 |
+
[d1, d2, d3, d4] = dims
|
| 1020 |
+
nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims)))
|
| 1021 |
+
|
| 1022 |
+
initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12,
|
| 1023 |
+
d1, d2, d3, d4)
|
| 1024 |
+
|
| 1025 |
+
[e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints
|
| 1026 |
+
|
| 1027 |
+
# without padding, broadcast all possibilities for tensors of size i
|
| 1028 |
+
final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints,
|
| 1029 |
+
generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)])
|
| 1030 |
+
|
| 1031 |
+
# with padding, broadcast all possibilities for tensors of size i
|
| 1032 |
+
final_tensor_constraint_padding_arg1, counter = \
|
| 1033 |
+
apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter)
|
| 1034 |
+
|
| 1035 |
+
final_tensor_constraint_padding_arg2, counter = \
|
| 1036 |
+
apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter)
|
| 1037 |
+
|
| 1038 |
+
return final_tensor_constraint_no_padding, \
|
| 1039 |
+
final_tensor_constraint_padding_arg1, \
|
| 1040 |
+
final_tensor_constraint_padding_arg2, nat_dims_i, counter
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
op_add = '+'
|
| 2 |
+
op_sub = '-'
|
| 3 |
+
op_mul = '*'
|
| 4 |
+
op_div = '/'
|
| 5 |
+
op_eq = '='
|
| 6 |
+
op_neq = '!='
|
| 7 |
+
op_imp = '=>'
|
| 8 |
+
op_matching = '\u22b3' # (contains)
|
| 9 |
+
op_consistency = '~'
|
| 10 |
+
op_precision = '\u2291' # (square image of or equal to)
|
| 11 |
+
op_leq = '\u2264' # less-than or equal to
|
| 12 |
+
op_lt = '<'
|
| 13 |
+
op_gt = '>'
|
| 14 |
+
op_mod = '%'
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr
|
| 3 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar
|
| 4 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim
|
| 5 |
+
from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator
|
| 6 |
+
from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint
|
| 7 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt
|
| 8 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod
|
| 9 |
+
from torch.fx.tensor_type import TensorType, Dyn
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import z3 # type: ignore[import]
|
| 13 |
+
from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D
|
| 14 |
+
HAS_Z3 = True
|
| 15 |
+
|
| 16 |
+
def transform_to_z3(constraint, counter, dimension_dict):
|
| 17 |
+
if isinstance(constraint, Conj):
|
| 18 |
+
conjuncts = []
|
| 19 |
+
for c in constraint.conjucts:
|
| 20 |
+
new_c, counter = transform_to_z3(c, counter, dimension_dict)
|
| 21 |
+
conjuncts.append(new_c)
|
| 22 |
+
return z3.And(conjuncts), counter
|
| 23 |
+
|
| 24 |
+
elif isinstance(constraint, Disj):
|
| 25 |
+
disjuncts = []
|
| 26 |
+
for c in constraint.disjuncts:
|
| 27 |
+
new_c, counter = transform_to_z3(c, counter, dimension_dict)
|
| 28 |
+
disjuncts.append(new_c)
|
| 29 |
+
return z3.Or(disjuncts), counter
|
| 30 |
+
|
| 31 |
+
elif isinstance(constraint, T):
|
| 32 |
+
return True, counter
|
| 33 |
+
|
| 34 |
+
elif isinstance(constraint, F):
|
| 35 |
+
return False, counter
|
| 36 |
+
|
| 37 |
+
elif isinstance(constraint, BinConstraintT):
|
| 38 |
+
if constraint.op == op_eq:
|
| 39 |
+
lhs, counter = transform_var(constraint.lhs, counter, dimension_dict)
|
| 40 |
+
rhs, counter = transform_var(constraint.rhs, counter, dimension_dict)
|
| 41 |
+
return (lhs == rhs), counter
|
| 42 |
+
|
| 43 |
+
else:
|
| 44 |
+
raise NotImplementedError('Method not yet implemented')
|
| 45 |
+
|
| 46 |
+
elif isinstance(constraint, BinConstraintD):
|
| 47 |
+
if constraint.op == op_eq:
|
| 48 |
+
|
| 49 |
+
if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs):
|
| 50 |
+
transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict)
|
| 51 |
+
transformed_lhs = z3.Bool(constraint.lhs.c)
|
| 52 |
+
return transformed_lhs == transformed_rhs, counter
|
| 53 |
+
|
| 54 |
+
elif is_dim(constraint.lhs) and is_dim(constraint.rhs):
|
| 55 |
+
# with dimension transformations we consider the encoding
|
| 56 |
+
lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
|
| 57 |
+
rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
|
| 58 |
+
return lhs == rhs, counter
|
| 59 |
+
|
| 60 |
+
else:
|
| 61 |
+
# then we have an algebraic expression which means that we disregard the
|
| 62 |
+
# first element of the encoding
|
| 63 |
+
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
|
| 64 |
+
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
|
| 65 |
+
return lhs == rhs, counter
|
| 66 |
+
|
| 67 |
+
# The assumption here is that the LHS and RHS must be dimensions
|
| 68 |
+
elif constraint.op == op_neq:
|
| 69 |
+
assert is_dim(constraint.lhs)
|
| 70 |
+
assert is_dim(constraint.rhs)
|
| 71 |
+
lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
|
| 72 |
+
rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
|
| 73 |
+
if constraint.rhs == Dyn or constraint.lhs == Dyn:
|
| 74 |
+
if constraint.rhs == Dyn:
|
| 75 |
+
return lhs.arg(0) == 1, counter
|
| 76 |
+
elif constraint.lhs == Dyn:
|
| 77 |
+
return rhs.arg(0) == 1, counter
|
| 78 |
+
|
| 79 |
+
# if one of the instances is a number
|
| 80 |
+
elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int):
|
| 81 |
+
if isinstance(constraint.lhs, int):
|
| 82 |
+
return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
|
| 83 |
+
|
| 84 |
+
elif isinstance(constraint.rhs, int):
|
| 85 |
+
return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter
|
| 86 |
+
|
| 87 |
+
else:
|
| 88 |
+
return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
|
| 89 |
+
z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
|
| 90 |
+
z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
elif constraint.op == op_leq:
|
| 94 |
+
# if the dimensions are not dyn, this will come into effect
|
| 95 |
+
# there would have been another constraint specifying if a given dimension
|
| 96 |
+
# is dyn or not
|
| 97 |
+
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
|
| 98 |
+
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
|
| 99 |
+
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
|
| 100 |
+
return lhs <= rhs, counter
|
| 101 |
+
|
| 102 |
+
elif constraint.op == op_gt:
|
| 103 |
+
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
|
| 104 |
+
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
|
| 105 |
+
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
|
| 106 |
+
return lhs > rhs, counter
|
| 107 |
+
|
| 108 |
+
elif constraint.op == op_lt:
|
| 109 |
+
assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
|
| 110 |
+
lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
|
| 111 |
+
rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
|
| 112 |
+
return lhs < rhs, counter
|
| 113 |
+
|
| 114 |
+
else:
|
| 115 |
+
raise NotImplementedError('operation not yet implemented')
|
| 116 |
+
|
| 117 |
+
else:
|
| 118 |
+
raise NotImplementedError('Operation not yet implemented')
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def transform_var(tensor, counter, dimension_dict):
|
| 122 |
+
"""
|
| 123 |
+
Transforms tensor variables to a format understood by z3
|
| 124 |
+
Args:
|
| 125 |
+
tensor: Tensor variable or a tensor type potentially with variable dimensions
|
| 126 |
+
Returns: Transformed variable to a z3 format
|
| 127 |
+
|
| 128 |
+
"""
|
| 129 |
+
if isinstance(tensor, TensorType):
|
| 130 |
+
res = []
|
| 131 |
+
for t in tensor.__args__:
|
| 132 |
+
transformed, counter = transform_dimension(t, counter, dimension_dict)
|
| 133 |
+
res.append(transformed)
|
| 134 |
+
|
| 135 |
+
assert len(res) <= 4
|
| 136 |
+
if len(tensor.__args__) == 1:
|
| 137 |
+
return tensor_type.tensor1(res[0]), counter
|
| 138 |
+
elif len(tensor.__args__) == 2:
|
| 139 |
+
return tensor_type.tensor2(res[0], res[1]), counter
|
| 140 |
+
elif len(tensor.__args__) == 3:
|
| 141 |
+
return tensor_type.tensor3(res[0], res[1], res[2]), counter
|
| 142 |
+
elif len(tensor.__args__) == 4:
|
| 143 |
+
return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter
|
| 144 |
+
|
| 145 |
+
elif tensor == Dyn:
|
| 146 |
+
return z3_dyn, counter
|
| 147 |
+
|
| 148 |
+
elif isinstance(tensor, TVar):
|
| 149 |
+
return z3.Const(tensor.tvar, tensor_type), counter
|
| 150 |
+
|
| 151 |
+
def transform_dimension(dimension, counter, dimension_dict):
|
| 152 |
+
"""
|
| 153 |
+
Takes a dimension variable or a number and transforms it to a tuple
|
| 154 |
+
according to our scheme
|
| 155 |
+
Args:
|
| 156 |
+
dimension: The dimension to be transformed
|
| 157 |
+
counter: variable tracking
|
| 158 |
+
|
| 159 |
+
Returns: tuple and the current counter
|
| 160 |
+
|
| 161 |
+
"""
|
| 162 |
+
if dimension == Dyn:
|
| 163 |
+
counter += 1
|
| 164 |
+
return D(0, z3.Int(counter)), counter
|
| 165 |
+
elif isinstance(dimension, int):
|
| 166 |
+
return D(1, dimension), counter
|
| 167 |
+
elif isinstance(dimension, DVar):
|
| 168 |
+
if dimension.c in dimension_dict:
|
| 169 |
+
return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter
|
| 170 |
+
else:
|
| 171 |
+
counter += 1
|
| 172 |
+
dimension_dict[dimension.c] = counter
|
| 173 |
+
return D(z3.Int(counter), z3.Int(dimension.c)), counter
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def transform_algebraic_expression(expr, counter, dimension_dict):
|
| 177 |
+
"""
|
| 178 |
+
Transforms an algebraic expression to z3 format
|
| 179 |
+
Args:
|
| 180 |
+
expr: An expression is either a dimension variable or an algebraic-expression
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
Returns: the transformed expression
|
| 184 |
+
|
| 185 |
+
"""
|
| 186 |
+
assert is_algebraic_expression(expr) or is_dim(expr)
|
| 187 |
+
|
| 188 |
+
if is_dim(expr):
|
| 189 |
+
transformed, counter = transform_dimension(expr, counter, dimension_dict)
|
| 190 |
+
return transformed.arg(1), counter
|
| 191 |
+
|
| 192 |
+
elif isinstance(expr, Prod):
|
| 193 |
+
|
| 194 |
+
dims = []
|
| 195 |
+
for dim in expr.products:
|
| 196 |
+
assert is_dim(dim)
|
| 197 |
+
d, counter = transform_dimension(dim, counter, dimension_dict)
|
| 198 |
+
dims.append(d.arg(1))
|
| 199 |
+
return z3.Product(dims), counter
|
| 200 |
+
|
| 201 |
+
elif is_algebraic_expression(expr):
|
| 202 |
+
|
| 203 |
+
lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict)
|
| 204 |
+
rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict)
|
| 205 |
+
|
| 206 |
+
if expr.op == op_sub:
|
| 207 |
+
c = lhs - rhs
|
| 208 |
+
|
| 209 |
+
elif expr.op == op_add:
|
| 210 |
+
c = lhs + rhs
|
| 211 |
+
|
| 212 |
+
elif expr.op == op_div:
|
| 213 |
+
c = lhs / rhs
|
| 214 |
+
|
| 215 |
+
elif expr.op == op_mul:
|
| 216 |
+
c = lhs * rhs
|
| 217 |
+
|
| 218 |
+
elif expr.op == op_mod:
|
| 219 |
+
c = lhs % rhs
|
| 220 |
+
|
| 221 |
+
else:
|
| 222 |
+
raise NotImplementedError('operation not yet implemented')
|
| 223 |
+
|
| 224 |
+
return c, counter
|
| 225 |
+
|
| 226 |
+
else:
|
| 227 |
+
raise RuntimeError
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def transform_all_constraints(traced, counter=0):
|
| 231 |
+
"""
|
| 232 |
+
Given a trace, generates constraints and transforms them to z3 format
|
| 233 |
+
|
| 234 |
+
"""
|
| 235 |
+
dimension_dict = {} # type: ignore[var-annotated]
|
| 236 |
+
|
| 237 |
+
generator = ConstraintGenerator(traced)
|
| 238 |
+
new_constraints, counter = generator.generate_constraints(counter)
|
| 239 |
+
|
| 240 |
+
# print(new_constraints.conjucts[0])
|
| 241 |
+
# print(*new_constraints.conjucts, sep='\n')
|
| 242 |
+
|
| 243 |
+
# transform precision, matching, consistency till obtaining a fixed point
|
| 244 |
+
new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
|
| 245 |
+
# print(new_constraints)
|
| 246 |
+
# print(new_constraints.conjucts)
|
| 247 |
+
# new_constraints.conjucts = new_constraints.conjucts[:-1]
|
| 248 |
+
# print(*new_constraints.conjucts, sep='\n')
|
| 249 |
+
|
| 250 |
+
transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
|
| 251 |
+
# print(transformed)
|
| 252 |
+
return transformed
|
| 253 |
+
|
| 254 |
+
def iterate_till_fixed_point(constraints, counter):
|
| 255 |
+
"""
|
| 256 |
+
Transform constraints till reaching a fixed point
|
| 257 |
+
"""
|
| 258 |
+
old_c = None
|
| 259 |
+
while old_c != constraints:
|
| 260 |
+
old_c = constraints
|
| 261 |
+
constraints, counter = transform_constraint(constraints, counter)
|
| 262 |
+
return constraints, counter
|
| 263 |
+
|
| 264 |
+
def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0):
|
| 265 |
+
"""
|
| 266 |
+
Takes a node and a graph and generates two sets of constraints.
|
| 267 |
+
One set constraints the node's constraints and another set
|
| 268 |
+
constraints the negation of the node's constraints
|
| 269 |
+
Args:
|
| 270 |
+
tracer_root: the root for getting the module instances
|
| 271 |
+
graph: the graph so far in the tracing process
|
| 272 |
+
node: node that represents a conditional
|
| 273 |
+
counter: variable tracking
|
| 274 |
+
|
| 275 |
+
Returns: Two sets of constraints. One with a conjunction with the
|
| 276 |
+
the conditional constraint and the other with a conjunction with
|
| 277 |
+
its negation.
|
| 278 |
+
|
| 279 |
+
"""
|
| 280 |
+
dimension_dict = {} # type: ignore[var-annotated]
|
| 281 |
+
|
| 282 |
+
generator = ConstraintGenerator(tracer_root, graph)
|
| 283 |
+
new_constraints, counter = generator.generate_constraints(counter)
|
| 284 |
+
|
| 285 |
+
condition_constraint = new_constraints.conjucts[-1]
|
| 286 |
+
|
| 287 |
+
# we know the constraint is a conjunction where the last constraint is about the conditional
|
| 288 |
+
# so remove the last constraint
|
| 289 |
+
new_constraints.conjucts = new_constraints.conjucts[:-1]
|
| 290 |
+
|
| 291 |
+
# transform precision, matching, consistency till obtaining a fixed point
|
| 292 |
+
new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# since the function returns a list of one element, we get the first element
|
| 296 |
+
# we are only interested in the RHS in this case because the LHS just stores
|
| 297 |
+
# the result
|
| 298 |
+
|
| 299 |
+
# we make sure the constraint is of the form:
|
| 300 |
+
# c = b where b is a boolean expression
|
| 301 |
+
# and we consider b (constraint.rhs) for transformation
|
| 302 |
+
assert isinstance(condition_constraint.lhs, BVar)
|
| 303 |
+
assert is_bool_expr(condition_constraint.rhs)
|
| 304 |
+
condition_constraint_rhs = condition_constraint.rhs
|
| 305 |
+
|
| 306 |
+
# transform the condition constraint
|
| 307 |
+
condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter)
|
| 308 |
+
|
| 309 |
+
transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
|
| 310 |
+
|
| 311 |
+
transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict)
|
| 312 |
+
|
| 313 |
+
negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)
|
| 314 |
+
|
| 315 |
+
return z3.And([transformed, transformed_condition_constraint]), \
|
| 316 |
+
z3.And([transformed, negation_transformed_condition_constraint])
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None):
|
| 320 |
+
"""
|
| 321 |
+
Given an IR and a node representing a conditional, evaluate the conditional
|
| 322 |
+
and its negation
|
| 323 |
+
Args:
|
| 324 |
+
tracer_root: Tracer root for module instances
|
| 325 |
+
node: The node to be evaluated
|
| 326 |
+
|
| 327 |
+
Returns: the results of evaluating the condition and the negation with
|
| 328 |
+
the rest of the constraints
|
| 329 |
+
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
transformed_positive, transformed_negative = \
|
| 333 |
+
transform_all_constraints_trace_time(tracer_root, graph, node, counter)
|
| 334 |
+
|
| 335 |
+
s = z3.Solver()
|
| 336 |
+
s.add(transformed_positive)
|
| 337 |
+
if user_constraints is not None:
|
| 338 |
+
s.add(user_constraints)
|
| 339 |
+
condition = s.check()
|
| 340 |
+
|
| 341 |
+
s = z3.Solver()
|
| 342 |
+
s.add(transformed_negative)
|
| 343 |
+
if user_constraints is not None:
|
| 344 |
+
s.add(user_constraints)
|
| 345 |
+
negation = s.check()
|
| 346 |
+
return condition, negation
|
| 347 |
+
|
| 348 |
+
except ImportError:
|
| 349 |
+
HAS_Z3 = False
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \
|
| 3 |
+
BVar
|
| 4 |
+
from torch.fx.experimental.migrate_gradual_types.operation import op_leq
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def gen_tvar(curr):
|
| 8 |
+
"""
|
| 9 |
+
Generate a tensor variable
|
| 10 |
+
:param curr: The current counter
|
| 11 |
+
:return: a tensor variable and the updated counter
|
| 12 |
+
"""
|
| 13 |
+
curr += 1
|
| 14 |
+
return TVar(curr), curr
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def gen_dvar(curr):
|
| 18 |
+
"""
|
| 19 |
+
Generate a dimension variable
|
| 20 |
+
:param curr: the current counter
|
| 21 |
+
:return: a dimension variable and an updated counter
|
| 22 |
+
"""
|
| 23 |
+
curr += 1
|
| 24 |
+
return DVar(curr), curr
|
| 25 |
+
|
| 26 |
+
def gen_bvar(curr):
|
| 27 |
+
"""
|
| 28 |
+
Generate a boolean variable
|
| 29 |
+
:param curr: the current counter
|
| 30 |
+
:return: a boolean variable and an updated counter
|
| 31 |
+
"""
|
| 32 |
+
curr += 1
|
| 33 |
+
return BVar(curr), curr
|
| 34 |
+
|
| 35 |
+
def gen_tensor_dims(n, curr):
|
| 36 |
+
"""
|
| 37 |
+
Generate a list of tensor dimensions
|
| 38 |
+
:param n: the number of dimensions
|
| 39 |
+
:param curr: the current counter
|
| 40 |
+
:return: a list of dimension variables and an updated counter
|
| 41 |
+
"""
|
| 42 |
+
dims = []
|
| 43 |
+
for _ in range(n):
|
| 44 |
+
dvar, curr = gen_dvar(curr)
|
| 45 |
+
dims.append(dvar)
|
| 46 |
+
return dims, curr
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def gen_nat_constraints(list_of_dims):
|
| 50 |
+
"""
|
| 51 |
+
Generate natural number constraints for dimensions
|
| 52 |
+
"""
|
| 53 |
+
return [BinConstraintD(0, d, op_leq) for d in list_of_dims]
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
import z3 # type: ignore[import]
|
| 3 |
+
HAS_Z3 = True
|
| 4 |
+
# dynamic type
|
| 5 |
+
dyn = z3.DeclareSort('Dyn')
|
| 6 |
+
dyn_type = z3.Const('dyn', dyn)
|
| 7 |
+
|
| 8 |
+
# dimension
|
| 9 |
+
dim = z3.Datatype('dim')
|
| 10 |
+
dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort()))
|
| 11 |
+
dim = dim.create()
|
| 12 |
+
|
| 13 |
+
# tensors
|
| 14 |
+
tensor_type = z3.Datatype('TensorType')
|
| 15 |
+
tensor_type.declare('Dyn', ('dyn', dyn))
|
| 16 |
+
tensor_type.declare('tensor1', ('0', dim))
|
| 17 |
+
tensor_type.declare('tensor2', ('0', dim), ('1', dim))
|
| 18 |
+
tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim))
|
| 19 |
+
tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim))
|
| 20 |
+
tensor_type = tensor_type.create()
|
| 21 |
+
|
| 22 |
+
# create dimension
|
| 23 |
+
D = dim.dim
|
| 24 |
+
|
| 25 |
+
z3_dyn = tensor_type.Dyn(dyn_type)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
except ImportError:
|
| 29 |
+
HAS_Z3 = False
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: disable-error-code=attr-defined
|
| 2 |
+
from .core import unify, reify # noqa: F403
|
| 3 |
+
from .more import unifiable # noqa: F403
|
| 4 |
+
from .variable import var, isvar, vars, variables, Var # noqa: F403
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (471 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc
ADDED
|
Binary file (4.18 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-311.pyc
ADDED
|
Binary file (402 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc
ADDED
|
Binary file (7.09 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc
ADDED
|
Binary file (5.25 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (5.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc
ADDED
|
Binary file (4.42 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/core.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from collections.abc import Iterator # type: ignore[import]
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
from .unification_tools import assoc # type: ignore[import]
|
| 6 |
+
from .utils import transitive_get as walk
|
| 7 |
+
from .variable import isvar
|
| 8 |
+
from .dispatch import dispatch
|
| 9 |
+
|
| 10 |
+
__all__ = ["reify", "unify"]
|
| 11 |
+
|
| 12 |
+
###############
|
| 13 |
+
# Reification #
|
| 14 |
+
###############
|
| 15 |
+
|
| 16 |
+
@dispatch(Iterator, dict)
|
| 17 |
+
def _reify(t, s):
|
| 18 |
+
return map(partial(reify, s=s), t)
|
| 19 |
+
# return (reify(arg, s) for arg in t)
|
| 20 |
+
_reify
|
| 21 |
+
|
| 22 |
+
@dispatch(tuple, dict) # type: ignore[no-redef]
|
| 23 |
+
def _reify(t, s):
|
| 24 |
+
return tuple(reify(iter(t), s))
|
| 25 |
+
_reify
|
| 26 |
+
|
| 27 |
+
@dispatch(list, dict) # type: ignore[no-redef]
|
| 28 |
+
def _reify(t, s):
|
| 29 |
+
return list(reify(iter(t), s))
|
| 30 |
+
_reify
|
| 31 |
+
|
| 32 |
+
@dispatch(dict, dict) # type: ignore[no-redef]
|
| 33 |
+
def _reify(d, s):
|
| 34 |
+
return {k: reify(v, s) for k, v in d.items()}
|
| 35 |
+
_reify
|
| 36 |
+
|
| 37 |
+
@dispatch(object, dict) # type: ignore[no-redef]
|
| 38 |
+
def _reify(o, s):
|
| 39 |
+
return o # catch all, just return the object
|
| 40 |
+
|
| 41 |
+
def reify(e, s):
|
| 42 |
+
""" Replace variables of expression with substitution
|
| 43 |
+
>>> # xdoctest: +SKIP
|
| 44 |
+
>>> x, y = var(), var()
|
| 45 |
+
>>> e = (1, x, (3, y))
|
| 46 |
+
>>> s = {x: 2, y: 4}
|
| 47 |
+
>>> reify(e, s)
|
| 48 |
+
(1, 2, (3, 4))
|
| 49 |
+
>>> e = {1: x, 3: (y, 5)}
|
| 50 |
+
>>> reify(e, s)
|
| 51 |
+
{1: 2, 3: (4, 5)}
|
| 52 |
+
"""
|
| 53 |
+
if isvar(e):
|
| 54 |
+
return reify(s[e], s) if e in s else e
|
| 55 |
+
return _reify(e, s)
|
| 56 |
+
|
| 57 |
+
###############
|
| 58 |
+
# Unification #
|
| 59 |
+
###############
|
| 60 |
+
|
| 61 |
+
seq = tuple, list, Iterator
|
| 62 |
+
|
| 63 |
+
@dispatch(seq, seq, dict)
|
| 64 |
+
def _unify(u, v, s):
|
| 65 |
+
if len(u) != len(v):
|
| 66 |
+
return False
|
| 67 |
+
for uu, vv in zip(u, v): # avoiding recursion
|
| 68 |
+
s = unify(uu, vv, s)
|
| 69 |
+
if s is False:
|
| 70 |
+
return False
|
| 71 |
+
return s
|
| 72 |
+
#
|
| 73 |
+
# @dispatch((set, frozenset), (set, frozenset), dict)
|
| 74 |
+
# def _unify(u, v, s):
|
| 75 |
+
# i = u & v
|
| 76 |
+
# u = u - i
|
| 77 |
+
# v = v - i
|
| 78 |
+
# return _unify(sorted(u), sorted(v), s)
|
| 79 |
+
#
|
| 80 |
+
#
|
| 81 |
+
# @dispatch(dict, dict, dict)
|
| 82 |
+
# def _unify(u, v, s):
|
| 83 |
+
# if len(u) != len(v):
|
| 84 |
+
# return False
|
| 85 |
+
# for key, uval in iteritems(u):
|
| 86 |
+
# if key not in v:
|
| 87 |
+
# return False
|
| 88 |
+
# s = unify(uval, v[key], s)
|
| 89 |
+
# if s is False:
|
| 90 |
+
# return False
|
| 91 |
+
# return s
|
| 92 |
+
#
|
| 93 |
+
#
|
| 94 |
+
# @dispatch(object, object, dict)
|
| 95 |
+
# def _unify(u, v, s):
|
| 96 |
+
# return False # catch all
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@dispatch(object, object, dict)
|
| 100 |
+
def unify(u, v, s): # no check at the moment
|
| 101 |
+
""" Find substitution so that u == v while satisfying s
|
| 102 |
+
>>> x = var('x')
|
| 103 |
+
>>> unify((1, x), (1, 2), {})
|
| 104 |
+
{~x: 2}
|
| 105 |
+
"""
|
| 106 |
+
u = walk(u, s)
|
| 107 |
+
v = walk(v, s)
|
| 108 |
+
if u == v:
|
| 109 |
+
return s
|
| 110 |
+
if isvar(u):
|
| 111 |
+
return assoc(s, u, v)
|
| 112 |
+
if isvar(v):
|
| 113 |
+
return assoc(s, v, u)
|
| 114 |
+
return _unify(u, v, s)
|
| 115 |
+
unify
|
| 116 |
+
|
| 117 |
+
@dispatch(object, object) # type: ignore[no-redef]
|
| 118 |
+
def unify(u, v):
|
| 119 |
+
return unify(u, v, {})
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/dispatch.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
from .multipledispatch import dispatch # type: ignore[import]
|
| 3 |
+
|
| 4 |
+
namespace = {} # type: ignore[var-annotated]
|
| 5 |
+
|
| 6 |
+
dispatch = partial(dispatch, namespace=namespace)
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/match.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from .core import unify, reify # type: ignore[attr-defined]
|
| 3 |
+
from .variable import isvar
|
| 4 |
+
from .utils import _toposort, freeze
|
| 5 |
+
from .unification_tools import groupby, first # type: ignore[import]
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Dispatcher:
|
| 9 |
+
def __init__(self, name):
|
| 10 |
+
self.name = name
|
| 11 |
+
self.funcs = {}
|
| 12 |
+
self.ordering = []
|
| 13 |
+
|
| 14 |
+
def add(self, signature, func):
|
| 15 |
+
self.funcs[freeze(signature)] = func
|
| 16 |
+
self.ordering = ordering(self.funcs)
|
| 17 |
+
|
| 18 |
+
def __call__(self, *args, **kwargs):
|
| 19 |
+
func, s = self.resolve(args)
|
| 20 |
+
return func(*args, **kwargs)
|
| 21 |
+
|
| 22 |
+
def resolve(self, args):
|
| 23 |
+
n = len(args)
|
| 24 |
+
for signature in self.ordering:
|
| 25 |
+
if len(signature) != n:
|
| 26 |
+
continue
|
| 27 |
+
s = unify(freeze(args), signature)
|
| 28 |
+
if s is not False:
|
| 29 |
+
result = self.funcs[signature]
|
| 30 |
+
return result, s
|
| 31 |
+
raise NotImplementedError("No match found. \nKnown matches: "
|
| 32 |
+
+ str(self.ordering) + "\nInput: " + str(args))
|
| 33 |
+
|
| 34 |
+
def register(self, *signature):
|
| 35 |
+
def _(func):
|
| 36 |
+
self.add(signature, func)
|
| 37 |
+
return self
|
| 38 |
+
return _
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class VarDispatcher(Dispatcher):
|
| 42 |
+
""" A dispatcher that calls functions with variable names
|
| 43 |
+
>>> # xdoctest: +SKIP
|
| 44 |
+
>>> d = VarDispatcher('d')
|
| 45 |
+
>>> x = var('x')
|
| 46 |
+
>>> @d.register('inc', x)
|
| 47 |
+
... def f(x):
|
| 48 |
+
... return x + 1
|
| 49 |
+
>>> @d.register('double', x)
|
| 50 |
+
... def f(x):
|
| 51 |
+
... return x * 2
|
| 52 |
+
>>> d('inc', 10)
|
| 53 |
+
11
|
| 54 |
+
>>> d('double', 10)
|
| 55 |
+
20
|
| 56 |
+
"""
|
| 57 |
+
def __call__(self, *args, **kwargs):
|
| 58 |
+
func, s = self.resolve(args)
|
| 59 |
+
d = {k.token: v for k, v in s.items()}
|
| 60 |
+
return func(**d)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
global_namespace = {} # type: ignore[var-annotated]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def match(*signature, **kwargs):
|
| 67 |
+
namespace = kwargs.get('namespace', global_namespace)
|
| 68 |
+
dispatcher = kwargs.get('Dispatcher', Dispatcher)
|
| 69 |
+
|
| 70 |
+
def _(func):
|
| 71 |
+
name = func.__name__
|
| 72 |
+
|
| 73 |
+
if name not in namespace:
|
| 74 |
+
namespace[name] = dispatcher(name)
|
| 75 |
+
d = namespace[name]
|
| 76 |
+
|
| 77 |
+
d.add(signature, func)
|
| 78 |
+
|
| 79 |
+
return d
|
| 80 |
+
return _
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def supercedes(a, b):
|
| 84 |
+
""" ``a`` is a more specific match than ``b`` """
|
| 85 |
+
if isvar(b) and not isvar(a):
|
| 86 |
+
return True
|
| 87 |
+
s = unify(a, b)
|
| 88 |
+
if s is False:
|
| 89 |
+
return False
|
| 90 |
+
s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
|
| 91 |
+
if reify(a, s) == a:
|
| 92 |
+
return True
|
| 93 |
+
if reify(b, s) == b:
|
| 94 |
+
return False
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Taken from multipledispatch
|
| 98 |
+
def edge(a, b, tie_breaker=hash):
|
| 99 |
+
""" A should be checked before B
|
| 100 |
+
Tie broken by tie_breaker, defaults to ``hash``
|
| 101 |
+
"""
|
| 102 |
+
if supercedes(a, b):
|
| 103 |
+
if supercedes(b, a):
|
| 104 |
+
return tie_breaker(a) > tie_breaker(b)
|
| 105 |
+
else:
|
| 106 |
+
return True
|
| 107 |
+
return False
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# Taken from multipledispatch
|
| 111 |
+
def ordering(signatures):
|
| 112 |
+
""" A sane ordering of signatures to check, first to last
|
| 113 |
+
Topological sort of edges as given by ``edge`` and ``supercedes``
|
| 114 |
+
"""
|
| 115 |
+
signatures = list(map(tuple, signatures))
|
| 116 |
+
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
|
| 117 |
+
edges = groupby(first, edges)
|
| 118 |
+
for s in signatures:
|
| 119 |
+
if s not in edges:
|
| 120 |
+
edges[s] = []
|
| 121 |
+
edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment]
|
| 122 |
+
return _toposort(edges)
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/more.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from .core import unify, reify # type: ignore[attr-defined]
|
| 3 |
+
from .dispatch import dispatch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def unifiable(cls):
|
| 7 |
+
""" Register standard unify and reify operations on class
|
| 8 |
+
This uses the type and __dict__ or __slots__ attributes to define the
|
| 9 |
+
nature of the term
|
| 10 |
+
See Also:
|
| 11 |
+
>>> # xdoctest: +SKIP
|
| 12 |
+
>>> class A(object):
|
| 13 |
+
... def __init__(self, a, b):
|
| 14 |
+
... self.a = a
|
| 15 |
+
... self.b = b
|
| 16 |
+
>>> unifiable(A)
|
| 17 |
+
<class 'unification.more.A'>
|
| 18 |
+
>>> x = var('x')
|
| 19 |
+
>>> a = A(1, 2)
|
| 20 |
+
>>> b = A(1, x)
|
| 21 |
+
>>> unify(a, b, {})
|
| 22 |
+
{~x: 2}
|
| 23 |
+
"""
|
| 24 |
+
_unify.add((cls, cls, dict), unify_object)
|
| 25 |
+
_reify.add((cls, dict), reify_object)
|
| 26 |
+
|
| 27 |
+
return cls
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
#########
|
| 31 |
+
# Reify #
|
| 32 |
+
#########
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def reify_object(o, s):
|
| 36 |
+
""" Reify a Python object with a substitution
|
| 37 |
+
>>> # xdoctest: +SKIP
|
| 38 |
+
>>> class Foo(object):
|
| 39 |
+
... def __init__(self, a, b):
|
| 40 |
+
... self.a = a
|
| 41 |
+
... self.b = b
|
| 42 |
+
... def __str__(self):
|
| 43 |
+
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
|
| 44 |
+
>>> x = var('x')
|
| 45 |
+
>>> f = Foo(1, x)
|
| 46 |
+
>>> print(f)
|
| 47 |
+
Foo(1, ~x)
|
| 48 |
+
>>> print(reify_object(f, {x: 2}))
|
| 49 |
+
Foo(1, 2)
|
| 50 |
+
"""
|
| 51 |
+
if hasattr(o, '__slots__'):
|
| 52 |
+
return _reify_object_slots(o, s)
|
| 53 |
+
else:
|
| 54 |
+
return _reify_object_dict(o, s)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _reify_object_dict(o, s):
|
| 58 |
+
obj = object.__new__(type(o))
|
| 59 |
+
d = reify(o.__dict__, s)
|
| 60 |
+
if d == o.__dict__:
|
| 61 |
+
return o
|
| 62 |
+
obj.__dict__.update(d)
|
| 63 |
+
return obj
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _reify_object_slots(o, s):
|
| 67 |
+
attrs = [getattr(o, attr) for attr in o.__slots__]
|
| 68 |
+
new_attrs = reify(attrs, s)
|
| 69 |
+
if attrs == new_attrs:
|
| 70 |
+
return o
|
| 71 |
+
else:
|
| 72 |
+
newobj = object.__new__(type(o))
|
| 73 |
+
for slot, attr in zip(o.__slots__, new_attrs):
|
| 74 |
+
setattr(newobj, slot, attr)
|
| 75 |
+
return newobj
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dispatch(slice, dict)
|
| 79 |
+
def _reify(o, s):
|
| 80 |
+
""" Reify a Python ``slice`` object """
|
| 81 |
+
return slice(*reify((o.start, o.stop, o.step), s))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
#########
|
| 85 |
+
# Unify #
|
| 86 |
+
#########
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def unify_object(u, v, s):
|
| 90 |
+
""" Unify two Python objects
|
| 91 |
+
Unifies their type and ``__dict__`` attributes
|
| 92 |
+
>>> # xdoctest: +SKIP
|
| 93 |
+
>>> class Foo(object):
|
| 94 |
+
... def __init__(self, a, b):
|
| 95 |
+
... self.a = a
|
| 96 |
+
... self.b = b
|
| 97 |
+
... def __str__(self):
|
| 98 |
+
... return "Foo(%s, %s)"%(str(self.a), str(self.b))
|
| 99 |
+
>>> x = var('x')
|
| 100 |
+
>>> f = Foo(1, x)
|
| 101 |
+
>>> g = Foo(1, 2)
|
| 102 |
+
>>> unify_object(f, g, {})
|
| 103 |
+
{~x: 2}
|
| 104 |
+
"""
|
| 105 |
+
if type(u) != type(v):
|
| 106 |
+
return False
|
| 107 |
+
if hasattr(u, '__slots__'):
|
| 108 |
+
return unify([getattr(u, slot) for slot in u.__slots__],
|
| 109 |
+
[getattr(v, slot) for slot in v.__slots__],
|
| 110 |
+
s)
|
| 111 |
+
else:
|
| 112 |
+
return unify(u.__dict__, v.__dict__, s)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@dispatch(slice, slice, dict)
|
| 116 |
+
def _unify(u, v, s):
|
| 117 |
+
""" Unify a Python ``slice`` object """
|
| 118 |
+
return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s)
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .core import dispatch
|
| 2 |
+
from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
|
| 3 |
+
MDNotImplementedError)
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (464 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc
ADDED
|
Binary file (8.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-311.pyc
ADDED
|
Binary file (3.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-311.pyc
ADDED
|
Binary file (22.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (6.28 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-311.pyc
ADDED
|
Binary file (4.74 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from .utils import _toposort, groupby
|
| 3 |
+
from .variadic import isvariadic
|
| 4 |
+
import operator
|
| 5 |
+
|
| 6 |
+
__all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature",
|
| 7 |
+
"edge", "ordering"]
|
| 8 |
+
|
| 9 |
+
class AmbiguityWarning(Warning):
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def supercedes(a, b):
|
| 14 |
+
""" A is consistent and strictly more specific than B """
|
| 15 |
+
if len(a) < len(b):
|
| 16 |
+
# only case is if a is empty and b is variadic
|
| 17 |
+
return not a and len(b) == 1 and isvariadic(b[-1])
|
| 18 |
+
elif len(a) == len(b):
|
| 19 |
+
return all(map(issubclass, a, b))
|
| 20 |
+
else:
|
| 21 |
+
# len(a) > len(b)
|
| 22 |
+
p1 = 0
|
| 23 |
+
p2 = 0
|
| 24 |
+
while p1 < len(a) and p2 < len(b):
|
| 25 |
+
cur_a = a[p1]
|
| 26 |
+
cur_b = b[p2]
|
| 27 |
+
if not (isvariadic(cur_a) or isvariadic(cur_b)):
|
| 28 |
+
if not issubclass(cur_a, cur_b):
|
| 29 |
+
return False
|
| 30 |
+
p1 += 1
|
| 31 |
+
p2 += 1
|
| 32 |
+
elif isvariadic(cur_a):
|
| 33 |
+
assert p1 == len(a) - 1
|
| 34 |
+
return p2 == len(b) - 1 and issubclass(cur_a, cur_b)
|
| 35 |
+
elif isvariadic(cur_b):
|
| 36 |
+
assert p2 == len(b) - 1
|
| 37 |
+
if not issubclass(cur_a, cur_b):
|
| 38 |
+
return False
|
| 39 |
+
p1 += 1
|
| 40 |
+
return p2 == len(b) - 1 and p1 == len(a)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def consistent(a, b):
|
| 44 |
+
""" It is possible for an argument list to satisfy both A and B """
|
| 45 |
+
|
| 46 |
+
# Need to check for empty args
|
| 47 |
+
if not a:
|
| 48 |
+
return not b or isvariadic(b[0])
|
| 49 |
+
if not b:
|
| 50 |
+
return not a or isvariadic(a[0])
|
| 51 |
+
|
| 52 |
+
# Non-empty args check for mutual subclasses
|
| 53 |
+
if len(a) == len(b):
|
| 54 |
+
return all(issubclass(aa, bb) or issubclass(bb, aa)
|
| 55 |
+
for aa, bb in zip(a, b))
|
| 56 |
+
else:
|
| 57 |
+
p1 = 0
|
| 58 |
+
p2 = 0
|
| 59 |
+
while p1 < len(a) and p2 < len(b):
|
| 60 |
+
cur_a = a[p1]
|
| 61 |
+
cur_b = b[p2]
|
| 62 |
+
if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b):
|
| 63 |
+
return False
|
| 64 |
+
if not (isvariadic(cur_a) or isvariadic(cur_b)):
|
| 65 |
+
p1 += 1
|
| 66 |
+
p2 += 1
|
| 67 |
+
elif isvariadic(cur_a):
|
| 68 |
+
p2 += 1
|
| 69 |
+
elif isvariadic(cur_b):
|
| 70 |
+
p1 += 1
|
| 71 |
+
# We only need to check for variadic ends
|
| 72 |
+
# Variadic types are guaranteed to be the last element
|
| 73 |
+
return (isvariadic(cur_a) and p2 == len(b) or # type: ignore[possibly-undefined]
|
| 74 |
+
isvariadic(cur_b) and p1 == len(a)) # type: ignore[possibly-undefined]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def ambiguous(a, b):
|
| 78 |
+
""" A is consistent with B but neither is strictly more specific """
|
| 79 |
+
return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def ambiguities(signatures):
|
| 83 |
+
""" All signature pairs such that A is ambiguous with B """
|
| 84 |
+
signatures = list(map(tuple, signatures))
|
| 85 |
+
return {(a, b) for a in signatures for b in signatures
|
| 86 |
+
if hash(a) < hash(b)
|
| 87 |
+
and ambiguous(a, b)
|
| 88 |
+
and not any(supercedes(c, a) and supercedes(c, b)
|
| 89 |
+
for c in signatures)}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def super_signature(signatures):
|
| 93 |
+
""" A signature that would break ambiguities """
|
| 94 |
+
n = len(signatures[0])
|
| 95 |
+
assert all(len(s) == n for s in signatures)
|
| 96 |
+
|
| 97 |
+
return [max((type.mro(sig[i]) for sig in signatures), key=len)[0]
|
| 98 |
+
for i in range(n)]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def edge(a, b, tie_breaker=hash):
|
| 102 |
+
""" A should be checked before B
|
| 103 |
+
Tie broken by tie_breaker, defaults to ``hash``
|
| 104 |
+
"""
|
| 105 |
+
# A either supercedes B and B does not supercede A or if B does then call
|
| 106 |
+
# tie_breaker
|
| 107 |
+
return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b))
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def ordering(signatures):
|
| 111 |
+
""" A sane ordering of signatures to check, first to last
|
| 112 |
+
Topological sort of edges as given by ``edge`` and ``supercedes``
|
| 113 |
+
"""
|
| 114 |
+
signatures = list(map(tuple, signatures))
|
| 115 |
+
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
|
| 116 |
+
edges = groupby(operator.itemgetter(0), edges)
|
| 117 |
+
for s in signatures:
|
| 118 |
+
if s not in edges:
|
| 119 |
+
edges[s] = []
|
| 120 |
+
edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[assignment, attr-defined]
|
| 121 |
+
return _toposort(edges)
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/core.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import inspect
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
from .dispatcher import Dispatcher, MethodDispatcher
|
| 6 |
+
|
| 7 |
+
global_namespace = {} # type: ignore[var-annotated]
|
| 8 |
+
|
| 9 |
+
__all__ = ["dispatch", "ismethod"]
|
| 10 |
+
|
| 11 |
+
def dispatch(*types, **kwargs):
|
| 12 |
+
""" Dispatch function on the types of the inputs
|
| 13 |
+
Supports dispatch on all non-keyword arguments.
|
| 14 |
+
Collects implementations based on the function name. Ignores namespaces.
|
| 15 |
+
If ambiguous type signatures occur a warning is raised when the function is
|
| 16 |
+
defined suggesting the additional method to break the ambiguity.
|
| 17 |
+
|
| 18 |
+
Example:
|
| 19 |
+
>>> # xdoctest: +SKIP
|
| 20 |
+
>>> @dispatch(int)
|
| 21 |
+
... def f(x):
|
| 22 |
+
... return x + 1
|
| 23 |
+
>>> @dispatch(float)
|
| 24 |
+
... def f(x):
|
| 25 |
+
... return x - 1
|
| 26 |
+
>>> # xdoctest: +SKIP
|
| 27 |
+
>>> f(3)
|
| 28 |
+
4
|
| 29 |
+
>>> f(3.0)
|
| 30 |
+
2.0
|
| 31 |
+
>>> # Specify an isolated namespace with the namespace keyword argument
|
| 32 |
+
>>> my_namespace = {}
|
| 33 |
+
>>> @dispatch(int, namespace=my_namespace)
|
| 34 |
+
... def foo(x):
|
| 35 |
+
... return x + 1
|
| 36 |
+
>>> # Dispatch on instance methods within classes
|
| 37 |
+
>>> class MyClass(object):
|
| 38 |
+
... @dispatch(list)
|
| 39 |
+
... def __init__(self, data):
|
| 40 |
+
... self.data = data
|
| 41 |
+
... @dispatch(int)
|
| 42 |
+
... def __init__(self, datum):
|
| 43 |
+
... self.data = [datum]
|
| 44 |
+
>>> MyClass([1, 2, 3]).data
|
| 45 |
+
[1, 2, 3]
|
| 46 |
+
>>> MyClass(3).data
|
| 47 |
+
[3]
|
| 48 |
+
"""
|
| 49 |
+
namespace = kwargs.get('namespace', global_namespace)
|
| 50 |
+
|
| 51 |
+
types = tuple(types)
|
| 52 |
+
|
| 53 |
+
def _df(func):
|
| 54 |
+
name = func.__name__
|
| 55 |
+
|
| 56 |
+
if ismethod(func):
|
| 57 |
+
dispatcher = inspect.currentframe().f_back.f_locals.get( # type: ignore[union-attr]
|
| 58 |
+
name, # type: ignore[union-attr]
|
| 59 |
+
MethodDispatcher(name),
|
| 60 |
+
)
|
| 61 |
+
else:
|
| 62 |
+
if name not in namespace:
|
| 63 |
+
namespace[name] = Dispatcher(name)
|
| 64 |
+
dispatcher = namespace[name]
|
| 65 |
+
|
| 66 |
+
dispatcher.add(types, func)
|
| 67 |
+
return dispatcher
|
| 68 |
+
return _df
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def ismethod(func):
|
| 72 |
+
""" Is func a method?
|
| 73 |
+
Note that this has to work as the method is defined but before the class is
|
| 74 |
+
defined. At this stage methods look like functions.
|
| 75 |
+
"""
|
| 76 |
+
if hasattr(inspect, "signature"):
|
| 77 |
+
signature = inspect.signature(func)
|
| 78 |
+
return signature.parameters.get('self', None) is not None
|
| 79 |
+
else:
|
| 80 |
+
if sys.version_info.major < 3:
|
| 81 |
+
spec = inspect.getargspec(func) # type: ignore[attr-defined]
|
| 82 |
+
else:
|
| 83 |
+
spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment]
|
| 84 |
+
return spec and spec.args and spec.args[0] == 'self'
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from warnings import warn
|
| 3 |
+
import inspect
|
| 4 |
+
from typing_extensions import deprecated
|
| 5 |
+
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
|
| 6 |
+
from .utils import expand_tuples
|
| 7 |
+
from .variadic import Variadic, isvariadic
|
| 8 |
+
import itertools as itl
|
| 9 |
+
|
| 10 |
+
__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter",
|
| 11 |
+
"variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"]
|
| 12 |
+
|
| 13 |
+
class MDNotImplementedError(NotImplementedError):
|
| 14 |
+
""" A NotImplementedError for multiple dispatch """
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def ambiguity_warn(dispatcher, ambiguities):
|
| 18 |
+
""" Raise warning when ambiguity is detected
|
| 19 |
+
Parameters
|
| 20 |
+
----------
|
| 21 |
+
dispatcher : Dispatcher
|
| 22 |
+
The dispatcher on which the ambiguity was detected
|
| 23 |
+
ambiguities : set
|
| 24 |
+
Set of type signature pairs that are ambiguous within this dispatcher
|
| 25 |
+
See Also:
|
| 26 |
+
Dispatcher.add
|
| 27 |
+
warning_text
|
| 28 |
+
"""
|
| 29 |
+
warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@deprecated(
|
| 33 |
+
"`halt_ordering` is deprecated, you can safely remove this call.",
|
| 34 |
+
category=FutureWarning,
|
| 35 |
+
)
|
| 36 |
+
def halt_ordering():
|
| 37 |
+
"""Deprecated interface to temporarily disable ordering."""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@deprecated(
|
| 41 |
+
"`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, "
|
| 42 |
+
"you should call the `reorder()` method on each dispatcher.",
|
| 43 |
+
category=FutureWarning,
|
| 44 |
+
)
|
| 45 |
+
def restart_ordering(on_ambiguity=ambiguity_warn):
|
| 46 |
+
"""Deprecated interface to temporarily resume ordering."""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def variadic_signature_matches_iter(types, full_signature):
|
| 50 |
+
"""Check if a set of input types matches a variadic signature.
|
| 51 |
+
Notes
|
| 52 |
+
-----
|
| 53 |
+
The algorithm is as follows:
|
| 54 |
+
Initialize the current signature to the first in the sequence
|
| 55 |
+
For each type in `types`:
|
| 56 |
+
If the current signature is variadic
|
| 57 |
+
If the type matches the signature
|
| 58 |
+
yield True
|
| 59 |
+
Else
|
| 60 |
+
Try to get the next signature
|
| 61 |
+
If no signatures are left we can't possibly have a match
|
| 62 |
+
so yield False
|
| 63 |
+
Else
|
| 64 |
+
yield True if the type matches the current signature
|
| 65 |
+
Get the next signature
|
| 66 |
+
"""
|
| 67 |
+
sigiter = iter(full_signature)
|
| 68 |
+
sig = next(sigiter)
|
| 69 |
+
for typ in types:
|
| 70 |
+
matches = issubclass(typ, sig)
|
| 71 |
+
yield matches
|
| 72 |
+
if not isvariadic(sig):
|
| 73 |
+
# we're not matching a variadic argument, so move to the next
|
| 74 |
+
# element in the signature
|
| 75 |
+
sig = next(sigiter)
|
| 76 |
+
else:
|
| 77 |
+
try:
|
| 78 |
+
sig = next(sigiter)
|
| 79 |
+
except StopIteration:
|
| 80 |
+
assert isvariadic(sig)
|
| 81 |
+
yield True
|
| 82 |
+
else:
|
| 83 |
+
# We have signature items left over, so all of our arguments
|
| 84 |
+
# haven't matched
|
| 85 |
+
yield False
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def variadic_signature_matches(types, full_signature):
|
| 89 |
+
# No arguments always matches a variadic signature
|
| 90 |
+
assert full_signature
|
| 91 |
+
return all(variadic_signature_matches_iter(types, full_signature))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Dispatcher:
|
| 95 |
+
""" Dispatch methods based on type signature
|
| 96 |
+
Use ``dispatch`` to add implementations
|
| 97 |
+
Examples
|
| 98 |
+
--------
|
| 99 |
+
>>> # xdoctest: +SKIP("bad import name")
|
| 100 |
+
>>> from multipledispatch import dispatch
|
| 101 |
+
>>> @dispatch(int)
|
| 102 |
+
... def f(x):
|
| 103 |
+
... return x + 1
|
| 104 |
+
>>> @dispatch(float)
|
| 105 |
+
... def f(x):
|
| 106 |
+
... return x - 1
|
| 107 |
+
>>> f(3)
|
| 108 |
+
4
|
| 109 |
+
>>> f(3.0)
|
| 110 |
+
2.0
|
| 111 |
+
"""
|
| 112 |
+
__slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'
|
| 113 |
+
|
| 114 |
+
def __init__(self, name, doc=None):
|
| 115 |
+
self.name = self.__name__ = name
|
| 116 |
+
self.funcs = {}
|
| 117 |
+
self.doc = doc
|
| 118 |
+
|
| 119 |
+
self._cache = {}
|
| 120 |
+
|
| 121 |
+
def register(self, *types, **kwargs):
|
| 122 |
+
""" register dispatcher with new implementation
|
| 123 |
+
>>> # xdoctest: +SKIP
|
| 124 |
+
>>> f = Dispatcher('f')
|
| 125 |
+
>>> @f.register(int)
|
| 126 |
+
... def inc(x):
|
| 127 |
+
... return x + 1
|
| 128 |
+
>>> @f.register(float)
|
| 129 |
+
... def dec(x):
|
| 130 |
+
... return x - 1
|
| 131 |
+
>>> @f.register(list)
|
| 132 |
+
... @f.register(tuple)
|
| 133 |
+
... def reverse(x):
|
| 134 |
+
... return x[::-1]
|
| 135 |
+
>>> f(1)
|
| 136 |
+
2
|
| 137 |
+
>>> f(1.0)
|
| 138 |
+
0.0
|
| 139 |
+
>>> f([1, 2, 3])
|
| 140 |
+
[3, 2, 1]
|
| 141 |
+
"""
|
| 142 |
+
def _df(func):
|
| 143 |
+
self.add(types, func, **kwargs) # type: ignore[call-arg]
|
| 144 |
+
return func
|
| 145 |
+
return _df
|
| 146 |
+
|
| 147 |
+
@classmethod
|
| 148 |
+
def get_func_params(cls, func):
|
| 149 |
+
if hasattr(inspect, "signature"):
|
| 150 |
+
sig = inspect.signature(func)
|
| 151 |
+
return sig.parameters.values()
|
| 152 |
+
|
| 153 |
+
@classmethod
|
| 154 |
+
def get_func_annotations(cls, func):
|
| 155 |
+
""" get annotations of function positional parameters
|
| 156 |
+
"""
|
| 157 |
+
params = cls.get_func_params(func)
|
| 158 |
+
if params:
|
| 159 |
+
Parameter = inspect.Parameter
|
| 160 |
+
|
| 161 |
+
params = (param for param in params
|
| 162 |
+
if param.kind in
|
| 163 |
+
(Parameter.POSITIONAL_ONLY,
|
| 164 |
+
Parameter.POSITIONAL_OR_KEYWORD))
|
| 165 |
+
|
| 166 |
+
annotations = tuple(
|
| 167 |
+
param.annotation
|
| 168 |
+
for param in params)
|
| 169 |
+
|
| 170 |
+
if all(ann is not Parameter.empty for ann in annotations):
|
| 171 |
+
return annotations
|
| 172 |
+
|
| 173 |
+
def add(self, signature, func):
|
| 174 |
+
""" Add new types/method pair to dispatcher
|
| 175 |
+
>>> # xdoctest: +SKIP
|
| 176 |
+
>>> D = Dispatcher('add')
|
| 177 |
+
>>> D.add((int, int), lambda x, y: x + y)
|
| 178 |
+
>>> D.add((float, float), lambda x, y: x + y)
|
| 179 |
+
>>> D(1, 2)
|
| 180 |
+
3
|
| 181 |
+
>>> D(1, 2.0)
|
| 182 |
+
Traceback (most recent call last):
|
| 183 |
+
...
|
| 184 |
+
NotImplementedError: Could not find signature for add: <int, float>
|
| 185 |
+
>>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback
|
| 186 |
+
>>> # with a dispatcher/itself, and a set of ambiguous type signature pairs
|
| 187 |
+
>>> # as inputs. See ``ambiguity_warn`` for an example.
|
| 188 |
+
"""
|
| 189 |
+
# Handle annotations
|
| 190 |
+
if not signature:
|
| 191 |
+
annotations = self.get_func_annotations(func)
|
| 192 |
+
if annotations:
|
| 193 |
+
signature = annotations
|
| 194 |
+
|
| 195 |
+
# Handle union types
|
| 196 |
+
if any(isinstance(typ, tuple) for typ in signature):
|
| 197 |
+
for typs in expand_tuples(signature):
|
| 198 |
+
self.add(typs, func)
|
| 199 |
+
return
|
| 200 |
+
|
| 201 |
+
new_signature = []
|
| 202 |
+
|
| 203 |
+
for index, typ in enumerate(signature, start=1):
|
| 204 |
+
if not isinstance(typ, (type, list)):
|
| 205 |
+
str_sig = ', '.join(c.__name__ if isinstance(c, type)
|
| 206 |
+
else str(c) for c in signature)
|
| 207 |
+
raise TypeError(f"Tried to dispatch on non-type: {typ}\n"
|
| 208 |
+
f"In signature: <{str_sig}>\n"
|
| 209 |
+
f"In function: {self.name}")
|
| 210 |
+
|
| 211 |
+
# handle variadic signatures
|
| 212 |
+
if isinstance(typ, list):
|
| 213 |
+
if index != len(signature):
|
| 214 |
+
raise TypeError(
|
| 215 |
+
'Variadic signature must be the last element'
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
if len(typ) != 1:
|
| 219 |
+
raise TypeError(
|
| 220 |
+
'Variadic signature must contain exactly one element. '
|
| 221 |
+
'To use a variadic union type place the desired types '
|
| 222 |
+
'inside of a tuple, e.g., [(int, str)]'
|
| 223 |
+
)
|
| 224 |
+
new_signature.append(Variadic[typ[0]])
|
| 225 |
+
else:
|
| 226 |
+
new_signature.append(typ)
|
| 227 |
+
|
| 228 |
+
self.funcs[tuple(new_signature)] = func
|
| 229 |
+
self._cache.clear()
|
| 230 |
+
|
| 231 |
+
try:
|
| 232 |
+
del self._ordering
|
| 233 |
+
except AttributeError:
|
| 234 |
+
pass
|
| 235 |
+
|
| 236 |
+
@property
|
| 237 |
+
def ordering(self):
|
| 238 |
+
try:
|
| 239 |
+
return self._ordering
|
| 240 |
+
except AttributeError:
|
| 241 |
+
return self.reorder()
|
| 242 |
+
|
| 243 |
+
def reorder(self, on_ambiguity=ambiguity_warn):
|
| 244 |
+
self._ordering = od = ordering(self.funcs)
|
| 245 |
+
amb = ambiguities(self.funcs)
|
| 246 |
+
if amb:
|
| 247 |
+
on_ambiguity(self, amb)
|
| 248 |
+
return od
|
| 249 |
+
|
| 250 |
+
def __call__(self, *args, **kwargs):
|
| 251 |
+
types = tuple([type(arg) for arg in args])
|
| 252 |
+
try:
|
| 253 |
+
func = self._cache[types]
|
| 254 |
+
except KeyError as e:
|
| 255 |
+
func = self.dispatch(*types)
|
| 256 |
+
if not func:
|
| 257 |
+
raise NotImplementedError(
|
| 258 |
+
f'Could not find signature for {self.name}: <{str_signature(types)}>') from e
|
| 259 |
+
self._cache[types] = func
|
| 260 |
+
try:
|
| 261 |
+
return func(*args, **kwargs)
|
| 262 |
+
|
| 263 |
+
except MDNotImplementedError as e:
|
| 264 |
+
funcs = self.dispatch_iter(*types)
|
| 265 |
+
next(funcs) # burn first
|
| 266 |
+
for func in funcs:
|
| 267 |
+
try:
|
| 268 |
+
return func(*args, **kwargs)
|
| 269 |
+
except MDNotImplementedError:
|
| 270 |
+
pass
|
| 271 |
+
|
| 272 |
+
raise NotImplementedError(
|
| 273 |
+
"Matching functions for "
|
| 274 |
+
f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e
|
| 275 |
+
|
| 276 |
+
def __str__(self):
|
| 277 |
+
return f"<dispatched {self.name}>"
|
| 278 |
+
__repr__ = __str__
|
| 279 |
+
|
| 280 |
+
def dispatch(self, *types):
|
| 281 |
+
"""Determine appropriate implementation for this type signature
|
| 282 |
+
This method is internal. Users should call this object as a function.
|
| 283 |
+
Implementation resolution occurs within the ``__call__`` method.
|
| 284 |
+
>>> # xdoctest: +SKIP
|
| 285 |
+
>>> from multipledispatch import dispatch
|
| 286 |
+
>>> @dispatch(int)
|
| 287 |
+
... def inc(x):
|
| 288 |
+
... return x + 1
|
| 289 |
+
>>> implementation = inc.dispatch(int)
|
| 290 |
+
>>> implementation(3)
|
| 291 |
+
4
|
| 292 |
+
>>> print(inc.dispatch(float))
|
| 293 |
+
None
|
| 294 |
+
See Also:
|
| 295 |
+
``multipledispatch.conflict`` - module to determine resolution order
|
| 296 |
+
"""
|
| 297 |
+
|
| 298 |
+
if types in self.funcs:
|
| 299 |
+
return self.funcs[types]
|
| 300 |
+
|
| 301 |
+
try:
|
| 302 |
+
return next(self.dispatch_iter(*types))
|
| 303 |
+
except StopIteration:
|
| 304 |
+
return None
|
| 305 |
+
|
| 306 |
+
def dispatch_iter(self, *types):
|
| 307 |
+
|
| 308 |
+
n = len(types)
|
| 309 |
+
for signature in self.ordering:
|
| 310 |
+
if len(signature) == n and all(map(issubclass, types, signature)):
|
| 311 |
+
result = self.funcs[signature]
|
| 312 |
+
yield result
|
| 313 |
+
elif len(signature) and isvariadic(signature[-1]):
|
| 314 |
+
if variadic_signature_matches(types, signature):
|
| 315 |
+
result = self.funcs[signature]
|
| 316 |
+
yield result
|
| 317 |
+
|
| 318 |
+
@deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning)
|
| 319 |
+
def resolve(self, types):
|
| 320 |
+
""" Determine appropriate implementation for this type signature
|
| 321 |
+
.. deprecated:: 0.4.4
|
| 322 |
+
Use ``dispatch(*types)`` instead
|
| 323 |
+
"""
|
| 324 |
+
return self.dispatch(*types)
|
| 325 |
+
|
| 326 |
+
def __getstate__(self):
|
| 327 |
+
return {'name': self.name,
|
| 328 |
+
'funcs': self.funcs}
|
| 329 |
+
|
| 330 |
+
def __setstate__(self, d):
|
| 331 |
+
self.name = d['name']
|
| 332 |
+
self.funcs = d['funcs']
|
| 333 |
+
self._ordering = ordering(self.funcs)
|
| 334 |
+
self._cache = {}
|
| 335 |
+
|
| 336 |
+
@property
|
| 337 |
+
def __doc__(self):
|
| 338 |
+
docs = [f"Multiply dispatched method: {self.name}"]
|
| 339 |
+
|
| 340 |
+
if self.doc:
|
| 341 |
+
docs.append(self.doc)
|
| 342 |
+
|
| 343 |
+
other = []
|
| 344 |
+
for sig in self.ordering[::-1]:
|
| 345 |
+
func = self.funcs[sig]
|
| 346 |
+
if func.__doc__:
|
| 347 |
+
s = f'Inputs: <{str_signature(sig)}>\n'
|
| 348 |
+
s += '-' * len(s) + '\n'
|
| 349 |
+
s += func.__doc__.strip()
|
| 350 |
+
docs.append(s)
|
| 351 |
+
else:
|
| 352 |
+
other.append(str_signature(sig))
|
| 353 |
+
|
| 354 |
+
if other:
|
| 355 |
+
docs.append('Other signatures:\n ' + '\n '.join(other))
|
| 356 |
+
|
| 357 |
+
return '\n\n'.join(docs)
|
| 358 |
+
|
| 359 |
+
def _help(self, *args):
|
| 360 |
+
return self.dispatch(*map(type, args)).__doc__
|
| 361 |
+
|
| 362 |
+
def help(self, *args, **kwargs):
|
| 363 |
+
""" Print docstring for the function corresponding to inputs """
|
| 364 |
+
print(self._help(*args))
|
| 365 |
+
|
| 366 |
+
def _source(self, *args):
|
| 367 |
+
func = self.dispatch(*map(type, args))
|
| 368 |
+
if not func:
|
| 369 |
+
raise TypeError("No function found")
|
| 370 |
+
return source(func)
|
| 371 |
+
|
| 372 |
+
def source(self, *args, **kwargs):
|
| 373 |
+
""" Print source code for the function corresponding to inputs """
|
| 374 |
+
print(self._source(*args))
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def source(func):
|
| 378 |
+
s = f'File: {inspect.getsourcefile(func)}\n\n'
|
| 379 |
+
s = s + inspect.getsource(func)
|
| 380 |
+
return s
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
class MethodDispatcher(Dispatcher):
|
| 384 |
+
""" Dispatch methods based on type signature
|
| 385 |
+
See Also:
|
| 386 |
+
Dispatcher
|
| 387 |
+
"""
|
| 388 |
+
__slots__ = ('obj', 'cls')
|
| 389 |
+
|
| 390 |
+
@classmethod
|
| 391 |
+
def get_func_params(cls, func):
|
| 392 |
+
if hasattr(inspect, "signature"):
|
| 393 |
+
sig = inspect.signature(func)
|
| 394 |
+
return itl.islice(sig.parameters.values(), 1, None)
|
| 395 |
+
|
| 396 |
+
def __get__(self, instance, owner):
|
| 397 |
+
self.obj = instance
|
| 398 |
+
self.cls = owner
|
| 399 |
+
return self
|
| 400 |
+
|
| 401 |
+
def __call__(self, *args, **kwargs):
|
| 402 |
+
types = tuple([type(arg) for arg in args])
|
| 403 |
+
func = self.dispatch(*types)
|
| 404 |
+
if not func:
|
| 405 |
+
raise NotImplementedError(f'Could not find signature for {self.name}: <{str_signature(types)}>')
|
| 406 |
+
return func(self.obj, *args, **kwargs)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def str_signature(sig):
|
| 410 |
+
""" String representation of type signature
|
| 411 |
+
>>> str_signature((int, float))
|
| 412 |
+
'int, float'
|
| 413 |
+
"""
|
| 414 |
+
return ', '.join(cls.__name__ for cls in sig)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def warning_text(name, amb):
|
| 418 |
+
""" The text for ambiguity warnings """
|
| 419 |
+
text = f"\nAmbiguities exist in dispatched function {name}\n\n"
|
| 420 |
+
text += "The following signatures may result in ambiguous behavior:\n"
|
| 421 |
+
for pair in amb:
|
| 422 |
+
text += "\t" + \
|
| 423 |
+
', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
|
| 424 |
+
text += "\n\nConsider making the following additions:\n\n"
|
| 425 |
+
text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
|
| 426 |
+
+ f')\ndef {name}(...)' for s in amb])
|
| 427 |
+
return text
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
|
| 4 |
+
__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]
|
| 5 |
+
|
| 6 |
+
def raises(err, lamda):
|
| 7 |
+
try:
|
| 8 |
+
lamda()
|
| 9 |
+
return False
|
| 10 |
+
except err:
|
| 11 |
+
return True
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def expand_tuples(L):
|
| 15 |
+
"""
|
| 16 |
+
>>> expand_tuples([1, (2, 3)])
|
| 17 |
+
[(1, 2), (1, 3)]
|
| 18 |
+
>>> expand_tuples([1, 2])
|
| 19 |
+
[(1, 2)]
|
| 20 |
+
"""
|
| 21 |
+
if not L:
|
| 22 |
+
return [()]
|
| 23 |
+
elif not isinstance(L[0], tuple):
|
| 24 |
+
rest = expand_tuples(L[1:])
|
| 25 |
+
return [(L[0],) + t for t in rest]
|
| 26 |
+
else:
|
| 27 |
+
rest = expand_tuples(L[1:])
|
| 28 |
+
return [(item,) + t for t in rest for item in L[0]]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Taken from theano/theano/gof/sched.py
|
| 32 |
+
# Avoids licensing issues because this was written by Matthew Rocklin
|
| 33 |
+
def _toposort(edges):
|
| 34 |
+
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
|
| 35 |
+
inputs:
|
| 36 |
+
edges - a dict of the form {a: {b, c}} where b and c depend on a
|
| 37 |
+
outputs:
|
| 38 |
+
L - an ordered list of nodes that satisfy the dependencies of edges
|
| 39 |
+
>>> _toposort({1: (2, 3), 2: (3, )})
|
| 40 |
+
[1, 2, 3]
|
| 41 |
+
>>> # Closely follows the wikipedia page [2]
|
| 42 |
+
>>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
|
| 43 |
+
>>> # Communications of the ACM
|
| 44 |
+
>>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
|
| 45 |
+
"""
|
| 46 |
+
incoming_edges = reverse_dict(edges)
|
| 47 |
+
incoming_edges = OrderedDict((k, set(val))
|
| 48 |
+
for k, val in incoming_edges.items())
|
| 49 |
+
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
|
| 50 |
+
L = []
|
| 51 |
+
|
| 52 |
+
while S:
|
| 53 |
+
n, _ = S.popitem()
|
| 54 |
+
L.append(n)
|
| 55 |
+
for m in edges.get(n, ()):
|
| 56 |
+
assert n in incoming_edges[m]
|
| 57 |
+
incoming_edges[m].remove(n)
|
| 58 |
+
if not incoming_edges[m]:
|
| 59 |
+
S[m] = None
|
| 60 |
+
if any(incoming_edges.get(v, None) for v in edges):
|
| 61 |
+
raise ValueError("Input has cycles")
|
| 62 |
+
return L
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def reverse_dict(d):
|
| 66 |
+
"""Reverses direction of dependence dict
|
| 67 |
+
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
|
| 68 |
+
>>> reverse_dict(d) # doctest: +SKIP
|
| 69 |
+
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
|
| 70 |
+
:note: dict order are not deterministic. As we iterate on the
|
| 71 |
+
input dict, it make the output of this function depend on the
|
| 72 |
+
dict order. So this function output order should be considered
|
| 73 |
+
as undeterministic.
|
| 74 |
+
"""
|
| 75 |
+
result = OrderedDict() # type: ignore[var-annotated]
|
| 76 |
+
for key in d:
|
| 77 |
+
for val in d[key]:
|
| 78 |
+
result[val] = result.get(val, ()) + (key,)
|
| 79 |
+
return result
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Taken from toolz
|
| 83 |
+
# Avoids licensing issues because this version was authored by Matthew Rocklin
|
| 84 |
+
def groupby(func, seq):
|
| 85 |
+
""" Group a collection by a key function
|
| 86 |
+
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
|
| 87 |
+
>>> groupby(len, names) # doctest: +SKIP
|
| 88 |
+
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
|
| 89 |
+
>>> iseven = lambda x: x % 2 == 0
|
| 90 |
+
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
|
| 91 |
+
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
|
| 92 |
+
See Also:
|
| 93 |
+
``countby``
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
d = OrderedDict() # type: ignore[var-annotated]
|
| 97 |
+
for item in seq:
|
| 98 |
+
key = func(item)
|
| 99 |
+
if key not in d:
|
| 100 |
+
d[key] = []
|
| 101 |
+
d[key].append(item)
|
| 102 |
+
return d
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def typename(type):
|
| 106 |
+
"""Get the name of `type`.
|
| 107 |
+
Parameters
|
| 108 |
+
----------
|
| 109 |
+
type : Union[Type, Tuple[Type]]
|
| 110 |
+
Returns
|
| 111 |
+
-------
|
| 112 |
+
str
|
| 113 |
+
The name of `type` or a tuple of the names of the types in `type`.
|
| 114 |
+
Examples
|
| 115 |
+
--------
|
| 116 |
+
>>> typename(int)
|
| 117 |
+
'int'
|
| 118 |
+
>>> typename((int, float))
|
| 119 |
+
'(int, float)'
|
| 120 |
+
"""
|
| 121 |
+
try:
|
| 122 |
+
return type.__name__
|
| 123 |
+
except AttributeError:
|
| 124 |
+
if len(type) == 1:
|
| 125 |
+
return typename(*type)
|
| 126 |
+
return f"({', '.join(map(typename, type))})"
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from .utils import typename
|
| 3 |
+
|
| 4 |
+
__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
|
| 5 |
+
|
| 6 |
+
class VariadicSignatureType(type):
|
| 7 |
+
# checking if subclass is a subclass of self
|
| 8 |
+
def __subclasscheck__(cls, subclass):
|
| 9 |
+
other_type = (subclass.variadic_type if isvariadic(subclass)
|
| 10 |
+
else (subclass,))
|
| 11 |
+
return subclass is cls or all(
|
| 12 |
+
issubclass(other, cls.variadic_type) for other in other_type # type: ignore[attr-defined]
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
def __eq__(cls, other):
|
| 16 |
+
"""
|
| 17 |
+
Return True if other has the same variadic type
|
| 18 |
+
Parameters
|
| 19 |
+
----------
|
| 20 |
+
other : object (type)
|
| 21 |
+
The object (type) to check
|
| 22 |
+
Returns
|
| 23 |
+
-------
|
| 24 |
+
bool
|
| 25 |
+
Whether or not `other` is equal to `self`
|
| 26 |
+
"""
|
| 27 |
+
return (isvariadic(other) and
|
| 28 |
+
set(cls.variadic_type) == set(other.variadic_type)) # type: ignore[attr-defined]
|
| 29 |
+
|
| 30 |
+
def __hash__(cls):
|
| 31 |
+
return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def isvariadic(obj):
|
| 35 |
+
"""Check whether the type `obj` is variadic.
|
| 36 |
+
Parameters
|
| 37 |
+
----------
|
| 38 |
+
obj : type
|
| 39 |
+
The type to check
|
| 40 |
+
Returns
|
| 41 |
+
-------
|
| 42 |
+
bool
|
| 43 |
+
Whether or not `obj` is variadic
|
| 44 |
+
Examples
|
| 45 |
+
--------
|
| 46 |
+
>>> # xdoctest: +SKIP
|
| 47 |
+
>>> isvariadic(int)
|
| 48 |
+
False
|
| 49 |
+
>>> isvariadic(Variadic[int])
|
| 50 |
+
True
|
| 51 |
+
"""
|
| 52 |
+
return isinstance(obj, VariadicSignatureType)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class VariadicSignatureMeta(type):
|
| 56 |
+
"""A metaclass that overrides ``__getitem__`` on the class. This is used to
|
| 57 |
+
generate a new type for Variadic signatures. See the Variadic class for
|
| 58 |
+
examples of how this behaves.
|
| 59 |
+
"""
|
| 60 |
+
def __getitem__(cls, variadic_type):
|
| 61 |
+
if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
|
| 62 |
+
raise ValueError("Variadic types must be type or tuple of types"
|
| 63 |
+
" (Variadic[int] or Variadic[(int, float)]")
|
| 64 |
+
|
| 65 |
+
if not isinstance(variadic_type, tuple):
|
| 66 |
+
variadic_type = variadic_type,
|
| 67 |
+
return VariadicSignatureType(
|
| 68 |
+
f'Variadic[{typename(variadic_type)}]',
|
| 69 |
+
(),
|
| 70 |
+
dict(variadic_type=variadic_type, __slots__=())
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Variadic(metaclass=VariadicSignatureMeta):
|
| 75 |
+
"""A class whose getitem method can be used to generate a new type
|
| 76 |
+
representing a specific variadic signature.
|
| 77 |
+
Examples
|
| 78 |
+
--------
|
| 79 |
+
>>> # xdoctest: +SKIP
|
| 80 |
+
>>> Variadic[int] # any number of int arguments
|
| 81 |
+
<class 'multipledispatch.variadic.Variadic[int]'>
|
| 82 |
+
>>> Variadic[(int, str)] # any number of one of int or str arguments
|
| 83 |
+
<class 'multipledispatch.variadic.Variadic[(int, str)]'>
|
| 84 |
+
>>> issubclass(int, Variadic[int])
|
| 85 |
+
True
|
| 86 |
+
>>> issubclass(int, Variadic[(int, str)])
|
| 87 |
+
True
|
| 88 |
+
>>> issubclass(str, Variadic[(int, str)])
|
| 89 |
+
True
|
| 90 |
+
>>> issubclass(float, Variadic[(int, str)])
|
| 91 |
+
False
|
| 92 |
+
"""
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/unification_tools.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
import collections
|
| 3 |
+
import operator
|
| 4 |
+
from functools import reduce
|
| 5 |
+
from collections.abc import Mapping
|
| 6 |
+
|
| 7 |
+
__all__ = ['merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
|
| 8 |
+
'valfilter', 'keyfilter', 'itemfilter',
|
| 9 |
+
'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in']
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _get_factory(f, kwargs):
|
| 13 |
+
factory = kwargs.pop('factory', dict)
|
| 14 |
+
if kwargs:
|
| 15 |
+
raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'")
|
| 16 |
+
return factory
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def merge(*dicts, **kwargs):
|
| 20 |
+
""" Merge a collection of dictionaries
|
| 21 |
+
|
| 22 |
+
>>> merge({1: 'one'}, {2: 'two'})
|
| 23 |
+
{1: 'one', 2: 'two'}
|
| 24 |
+
|
| 25 |
+
Later dictionaries have precedence
|
| 26 |
+
|
| 27 |
+
>>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
|
| 28 |
+
{1: 2, 3: 3, 4: 4}
|
| 29 |
+
|
| 30 |
+
See Also:
|
| 31 |
+
merge_with
|
| 32 |
+
"""
|
| 33 |
+
if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
|
| 34 |
+
dicts = dicts[0]
|
| 35 |
+
factory = _get_factory(merge, kwargs)
|
| 36 |
+
|
| 37 |
+
rv = factory()
|
| 38 |
+
for d in dicts:
|
| 39 |
+
rv.update(d)
|
| 40 |
+
return rv
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def merge_with(func, *dicts, **kwargs):
|
| 44 |
+
""" Merge dictionaries and apply function to combined values
|
| 45 |
+
|
| 46 |
+
A key may occur in more than one dict, and all values mapped from the key
|
| 47 |
+
will be passed to the function as a list, such as func([val1, val2, ...]).
|
| 48 |
+
|
| 49 |
+
>>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
|
| 50 |
+
{1: 11, 2: 22}
|
| 51 |
+
|
| 52 |
+
>>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP
|
| 53 |
+
{1: 1, 2: 2, 3: 30}
|
| 54 |
+
|
| 55 |
+
See Also:
|
| 56 |
+
merge
|
| 57 |
+
"""
|
| 58 |
+
if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
|
| 59 |
+
dicts = dicts[0]
|
| 60 |
+
factory = _get_factory(merge_with, kwargs)
|
| 61 |
+
|
| 62 |
+
result = factory()
|
| 63 |
+
for d in dicts:
|
| 64 |
+
for k, v in d.items():
|
| 65 |
+
if k not in result:
|
| 66 |
+
result[k] = [v]
|
| 67 |
+
else:
|
| 68 |
+
result[k].append(v)
|
| 69 |
+
return valmap(func, result, factory)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def valmap(func, d, factory=dict):
|
| 73 |
+
""" Apply function to values of dictionary
|
| 74 |
+
|
| 75 |
+
>>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
|
| 76 |
+
>>> valmap(sum, bills) # doctest: +SKIP
|
| 77 |
+
{'Alice': 65, 'Bob': 45}
|
| 78 |
+
|
| 79 |
+
See Also:
|
| 80 |
+
keymap
|
| 81 |
+
itemmap
|
| 82 |
+
"""
|
| 83 |
+
rv = factory()
|
| 84 |
+
rv.update(zip(d.keys(), map(func, d.values())))
|
| 85 |
+
return rv
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def keymap(func, d, factory=dict):
|
| 89 |
+
""" Apply function to keys of dictionary
|
| 90 |
+
|
| 91 |
+
>>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
|
| 92 |
+
>>> keymap(str.lower, bills) # doctest: +SKIP
|
| 93 |
+
{'alice': [20, 15, 30], 'bob': [10, 35]}
|
| 94 |
+
|
| 95 |
+
See Also:
|
| 96 |
+
valmap
|
| 97 |
+
itemmap
|
| 98 |
+
"""
|
| 99 |
+
rv = factory()
|
| 100 |
+
rv.update(zip(map(func, d.keys()), d.values()))
|
| 101 |
+
return rv
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def itemmap(func, d, factory=dict):
|
| 105 |
+
""" Apply function to items of dictionary
|
| 106 |
+
|
| 107 |
+
>>> accountids = {"Alice": 10, "Bob": 20}
|
| 108 |
+
>>> itemmap(reversed, accountids) # doctest: +SKIP
|
| 109 |
+
{10: "Alice", 20: "Bob"}
|
| 110 |
+
|
| 111 |
+
See Also:
|
| 112 |
+
keymap
|
| 113 |
+
valmap
|
| 114 |
+
"""
|
| 115 |
+
rv = factory()
|
| 116 |
+
rv.update(map(func, d.items()))
|
| 117 |
+
return rv
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def valfilter(predicate, d, factory=dict):
|
| 121 |
+
""" Filter items in dictionary by value
|
| 122 |
+
|
| 123 |
+
>>> iseven = lambda x: x % 2 == 0
|
| 124 |
+
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
|
| 125 |
+
>>> valfilter(iseven, d)
|
| 126 |
+
{1: 2, 3: 4}
|
| 127 |
+
|
| 128 |
+
See Also:
|
| 129 |
+
keyfilter
|
| 130 |
+
itemfilter
|
| 131 |
+
valmap
|
| 132 |
+
"""
|
| 133 |
+
rv = factory()
|
| 134 |
+
for k, v in d.items():
|
| 135 |
+
if predicate(v):
|
| 136 |
+
rv[k] = v
|
| 137 |
+
return rv
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def keyfilter(predicate, d, factory=dict):
|
| 141 |
+
""" Filter items in dictionary by key
|
| 142 |
+
|
| 143 |
+
>>> iseven = lambda x: x % 2 == 0
|
| 144 |
+
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
|
| 145 |
+
>>> keyfilter(iseven, d)
|
| 146 |
+
{2: 3, 4: 5}
|
| 147 |
+
|
| 148 |
+
See Also:
|
| 149 |
+
valfilter
|
| 150 |
+
itemfilter
|
| 151 |
+
keymap
|
| 152 |
+
"""
|
| 153 |
+
rv = factory()
|
| 154 |
+
for k, v in d.items():
|
| 155 |
+
if predicate(k):
|
| 156 |
+
rv[k] = v
|
| 157 |
+
return rv
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def itemfilter(predicate, d, factory=dict):
|
| 161 |
+
""" Filter items in dictionary by item
|
| 162 |
+
|
| 163 |
+
>>> def isvalid(item):
|
| 164 |
+
... k, v = item
|
| 165 |
+
... return k % 2 == 0 and v < 4
|
| 166 |
+
|
| 167 |
+
>>> d = {1: 2, 2: 3, 3: 4, 4: 5}
|
| 168 |
+
>>> itemfilter(isvalid, d)
|
| 169 |
+
{2: 3}
|
| 170 |
+
|
| 171 |
+
See Also:
|
| 172 |
+
keyfilter
|
| 173 |
+
valfilter
|
| 174 |
+
itemmap
|
| 175 |
+
"""
|
| 176 |
+
rv = factory()
|
| 177 |
+
for item in d.items():
|
| 178 |
+
if predicate(item):
|
| 179 |
+
k, v = item
|
| 180 |
+
rv[k] = v
|
| 181 |
+
return rv
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def assoc(d, key, value, factory=dict):
|
| 185 |
+
""" Return a new dict with new key value pair
|
| 186 |
+
|
| 187 |
+
New dict has d[key] set to value. Does not modify the initial dictionary.
|
| 188 |
+
|
| 189 |
+
>>> assoc({'x': 1}, 'x', 2)
|
| 190 |
+
{'x': 2}
|
| 191 |
+
>>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP
|
| 192 |
+
{'x': 1, 'y': 3}
|
| 193 |
+
"""
|
| 194 |
+
d2 = factory()
|
| 195 |
+
d2.update(d)
|
| 196 |
+
d2[key] = value
|
| 197 |
+
return d2
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def dissoc(d, *keys, **kwargs):
|
| 201 |
+
""" Return a new dict with the given key(s) removed.
|
| 202 |
+
|
| 203 |
+
New dict has d[key] deleted for each supplied key.
|
| 204 |
+
Does not modify the initial dictionary.
|
| 205 |
+
|
| 206 |
+
>>> dissoc({'x': 1, 'y': 2}, 'y')
|
| 207 |
+
{'x': 1}
|
| 208 |
+
>>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
|
| 209 |
+
{}
|
| 210 |
+
>>> dissoc({'x': 1}, 'y') # Ignores missing keys
|
| 211 |
+
{'x': 1}
|
| 212 |
+
"""
|
| 213 |
+
factory = _get_factory(dissoc, kwargs)
|
| 214 |
+
d2 = factory()
|
| 215 |
+
|
| 216 |
+
if len(keys) < len(d) * .6:
|
| 217 |
+
d2.update(d)
|
| 218 |
+
for key in keys:
|
| 219 |
+
if key in d2:
|
| 220 |
+
del d2[key]
|
| 221 |
+
else:
|
| 222 |
+
remaining = set(d)
|
| 223 |
+
remaining.difference_update(keys)
|
| 224 |
+
for k in remaining:
|
| 225 |
+
d2[k] = d[k]
|
| 226 |
+
return d2
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def assoc_in(d, keys, value, factory=dict):
|
| 230 |
+
""" Return a new dict with new, potentially nested, key value pair
|
| 231 |
+
|
| 232 |
+
>>> purchase = {'name': 'Alice',
|
| 233 |
+
... 'order': {'items': ['Apple', 'Orange'],
|
| 234 |
+
... 'costs': [0.50, 1.25]},
|
| 235 |
+
... 'credit card': '5555-1234-1234-1234'}
|
| 236 |
+
>>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP
|
| 237 |
+
{'credit card': '5555-1234-1234-1234',
|
| 238 |
+
'name': 'Alice',
|
| 239 |
+
'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
|
| 240 |
+
"""
|
| 241 |
+
return update_in(d, keys, lambda x: value, value, factory)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def update_in(d, keys, func, default=None, factory=dict):
|
| 245 |
+
""" Update value in a (potentially) nested dictionary
|
| 246 |
+
|
| 247 |
+
inputs:
|
| 248 |
+
d - dictionary on which to operate
|
| 249 |
+
keys - list or tuple giving the location of the value to be changed in d
|
| 250 |
+
func - function to operate on that value
|
| 251 |
+
|
| 252 |
+
If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the
|
| 253 |
+
original dictionary with v replaced by func(v), but does not mutate the
|
| 254 |
+
original dictionary.
|
| 255 |
+
|
| 256 |
+
If k0 is not a key in d, update_in creates nested dictionaries to the depth
|
| 257 |
+
specified by the keys, with the innermost value set to func(default).
|
| 258 |
+
|
| 259 |
+
>>> inc = lambda x: x + 1
|
| 260 |
+
>>> update_in({'a': 0}, ['a'], inc)
|
| 261 |
+
{'a': 1}
|
| 262 |
+
|
| 263 |
+
>>> transaction = {'name': 'Alice',
|
| 264 |
+
... 'purchase': {'items': ['Apple', 'Orange'],
|
| 265 |
+
... 'costs': [0.50, 1.25]},
|
| 266 |
+
... 'credit card': '5555-1234-1234-1234'}
|
| 267 |
+
>>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP
|
| 268 |
+
{'credit card': '5555-1234-1234-1234',
|
| 269 |
+
'name': 'Alice',
|
| 270 |
+
'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}
|
| 271 |
+
|
| 272 |
+
>>> # updating a value when k0 is not in d
|
| 273 |
+
>>> update_in({}, [1, 2, 3], str, default="bar")
|
| 274 |
+
{1: {2: {3: 'bar'}}}
|
| 275 |
+
>>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
|
| 276 |
+
{1: 'foo', 2: {3: {4: 1}}}
|
| 277 |
+
"""
|
| 278 |
+
ks = iter(keys)
|
| 279 |
+
k = next(ks)
|
| 280 |
+
|
| 281 |
+
rv = inner = factory()
|
| 282 |
+
rv.update(d)
|
| 283 |
+
|
| 284 |
+
for key in ks:
|
| 285 |
+
if k in d:
|
| 286 |
+
d = d[k]
|
| 287 |
+
dtemp = factory()
|
| 288 |
+
dtemp.update(d)
|
| 289 |
+
else:
|
| 290 |
+
d = dtemp = factory()
|
| 291 |
+
|
| 292 |
+
inner[k] = inner = dtemp
|
| 293 |
+
k = key
|
| 294 |
+
|
| 295 |
+
if k in d:
|
| 296 |
+
inner[k] = func(d[k])
|
| 297 |
+
else:
|
| 298 |
+
inner[k] = func(default)
|
| 299 |
+
return rv
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def get_in(keys, coll, default=None, no_default=False):
|
| 303 |
+
""" Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
|
| 304 |
+
|
| 305 |
+
If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
|
| 306 |
+
``no_default`` is specified, then it raises KeyError or IndexError.
|
| 307 |
+
|
| 308 |
+
``get_in`` is a generalization of ``operator.getitem`` for nested data
|
| 309 |
+
structures such as dictionaries and lists.
|
| 310 |
+
|
| 311 |
+
>>> transaction = {'name': 'Alice',
|
| 312 |
+
... 'purchase': {'items': ['Apple', 'Orange'],
|
| 313 |
+
... 'costs': [0.50, 1.25]},
|
| 314 |
+
... 'credit card': '5555-1234-1234-1234'}
|
| 315 |
+
>>> get_in(['purchase', 'items', 0], transaction)
|
| 316 |
+
'Apple'
|
| 317 |
+
>>> get_in(['name'], transaction)
|
| 318 |
+
'Alice'
|
| 319 |
+
>>> get_in(['purchase', 'total'], transaction)
|
| 320 |
+
>>> get_in(['purchase', 'items', 'apple'], transaction)
|
| 321 |
+
>>> get_in(['purchase', 'items', 10], transaction)
|
| 322 |
+
>>> get_in(['purchase', 'total'], transaction, 0)
|
| 323 |
+
0
|
| 324 |
+
>>> get_in(['y'], {}, no_default=True)
|
| 325 |
+
Traceback (most recent call last):
|
| 326 |
+
...
|
| 327 |
+
KeyError: 'y'
|
| 328 |
+
|
| 329 |
+
See Also:
|
| 330 |
+
itertoolz.get
|
| 331 |
+
operator.getitem
|
| 332 |
+
"""
|
| 333 |
+
try:
|
| 334 |
+
return reduce(operator.getitem, keys, coll)
|
| 335 |
+
except (KeyError, IndexError, TypeError):
|
| 336 |
+
if no_default:
|
| 337 |
+
raise
|
| 338 |
+
return default
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def getter(index):
|
| 342 |
+
if isinstance(index, list):
|
| 343 |
+
if len(index) == 1:
|
| 344 |
+
index = index[0]
|
| 345 |
+
return lambda x: (x[index],)
|
| 346 |
+
elif index:
|
| 347 |
+
return operator.itemgetter(*index)
|
| 348 |
+
else:
|
| 349 |
+
return lambda x: ()
|
| 350 |
+
else:
|
| 351 |
+
return operator.itemgetter(index)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def groupby(key, seq):
|
| 355 |
+
""" Group a collection by a key function
|
| 356 |
+
|
| 357 |
+
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
|
| 358 |
+
>>> groupby(len, names) # doctest: +SKIP
|
| 359 |
+
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
|
| 360 |
+
|
| 361 |
+
>>> iseven = lambda x: x % 2 == 0
|
| 362 |
+
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
|
| 363 |
+
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
|
| 364 |
+
|
| 365 |
+
Non-callable keys imply grouping on a member.
|
| 366 |
+
|
| 367 |
+
>>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
|
| 368 |
+
... {'name': 'Bob', 'gender': 'M'},
|
| 369 |
+
... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP
|
| 370 |
+
{'F': [{'gender': 'F', 'name': 'Alice'}],
|
| 371 |
+
'M': [{'gender': 'M', 'name': 'Bob'},
|
| 372 |
+
{'gender': 'M', 'name': 'Charlie'}]}
|
| 373 |
+
|
| 374 |
+
Not to be confused with ``itertools.groupby``
|
| 375 |
+
|
| 376 |
+
See Also:
|
| 377 |
+
countby
|
| 378 |
+
"""
|
| 379 |
+
if not callable(key):
|
| 380 |
+
key = getter(key)
|
| 381 |
+
d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated]
|
| 382 |
+
for item in seq:
|
| 383 |
+
d[key(item)](item)
|
| 384 |
+
rv = {}
|
| 385 |
+
for k, v in d.items():
|
| 386 |
+
rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined]
|
| 387 |
+
return rv
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def first(seq):
|
| 391 |
+
""" The first element in a sequence
|
| 392 |
+
|
| 393 |
+
>>> first('ABC')
|
| 394 |
+
'A'
|
| 395 |
+
"""
|
| 396 |
+
return next(iter(seq))
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/utils.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
|
| 3 |
+
def hashable(x):
|
| 4 |
+
try:
|
| 5 |
+
hash(x)
|
| 6 |
+
return True
|
| 7 |
+
except TypeError:
|
| 8 |
+
return False
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def transitive_get(key, d):
|
| 12 |
+
""" Transitive dict.get
|
| 13 |
+
>>> d = {1: 2, 2: 3, 3: 4}
|
| 14 |
+
>>> d.get(1)
|
| 15 |
+
2
|
| 16 |
+
>>> transitive_get(1, d)
|
| 17 |
+
4
|
| 18 |
+
"""
|
| 19 |
+
while hashable(key) and key in d:
|
| 20 |
+
key = d[key]
|
| 21 |
+
return key
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def raises(err, lamda):
|
| 25 |
+
try:
|
| 26 |
+
lamda()
|
| 27 |
+
return False
|
| 28 |
+
except err:
|
| 29 |
+
return True
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Taken from theano/theano/gof/sched.py
|
| 33 |
+
# Avoids licensing issues because this was written by Matthew Rocklin
|
| 34 |
+
def _toposort(edges):
|
| 35 |
+
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
|
| 36 |
+
inputs:
|
| 37 |
+
edges - a dict of the form {a: {b, c}} where b and c depend on a
|
| 38 |
+
outputs:
|
| 39 |
+
L - an ordered list of nodes that satisfy the dependencies of edges
|
| 40 |
+
>>> # xdoctest: +SKIP
|
| 41 |
+
>>> _toposort({1: (2, 3), 2: (3, )})
|
| 42 |
+
[1, 2, 3]
|
| 43 |
+
Closely follows the wikipedia page [2]
|
| 44 |
+
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
|
| 45 |
+
Communications of the ACM
|
| 46 |
+
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
|
| 47 |
+
"""
|
| 48 |
+
incoming_edges = reverse_dict(edges)
|
| 49 |
+
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
|
| 50 |
+
S = ({v for v in edges if v not in incoming_edges})
|
| 51 |
+
L = []
|
| 52 |
+
|
| 53 |
+
while S:
|
| 54 |
+
n = S.pop()
|
| 55 |
+
L.append(n)
|
| 56 |
+
for m in edges.get(n, ()):
|
| 57 |
+
assert n in incoming_edges[m]
|
| 58 |
+
incoming_edges[m].remove(n)
|
| 59 |
+
if not incoming_edges[m]:
|
| 60 |
+
S.add(m)
|
| 61 |
+
if any(incoming_edges.get(v, None) for v in edges):
|
| 62 |
+
raise ValueError("Input has cycles")
|
| 63 |
+
return L
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def reverse_dict(d):
|
| 67 |
+
"""Reverses direction of dependence dict
|
| 68 |
+
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
|
| 69 |
+
>>> reverse_dict(d) # doctest: +SKIP
|
| 70 |
+
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
|
| 71 |
+
:note: dict order are not deterministic. As we iterate on the
|
| 72 |
+
input dict, it make the output of this function depend on the
|
| 73 |
+
dict order. So this function output order should be considered
|
| 74 |
+
as undeterministic.
|
| 75 |
+
"""
|
| 76 |
+
result = {} # type: ignore[var-annotated]
|
| 77 |
+
for key in d:
|
| 78 |
+
for val in d[key]:
|
| 79 |
+
result[val] = result.get(val, ()) + (key,)
|
| 80 |
+
return result
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def xfail(func):
|
| 84 |
+
try:
|
| 85 |
+
func()
|
| 86 |
+
raise Exception("XFailed test passed") # pragma:nocover # noqa: TRY002
|
| 87 |
+
except Exception:
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def freeze(d):
|
| 92 |
+
""" Freeze container to hashable form
|
| 93 |
+
>>> freeze(1)
|
| 94 |
+
1
|
| 95 |
+
>>> freeze([1, 2])
|
| 96 |
+
(1, 2)
|
| 97 |
+
>>> freeze({1: 2}) # doctest: +SKIP
|
| 98 |
+
frozenset([(1, 2)])
|
| 99 |
+
"""
|
| 100 |
+
if isinstance(d, dict):
|
| 101 |
+
return frozenset(map(freeze, d.items()))
|
| 102 |
+
if isinstance(d, set):
|
| 103 |
+
return frozenset(map(freeze, d))
|
| 104 |
+
if isinstance(d, (tuple, list)):
|
| 105 |
+
return tuple(map(freeze, d))
|
| 106 |
+
return d
|
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/variable.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# mypy: allow-untyped-defs
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
+
from .utils import hashable
|
| 4 |
+
from .dispatch import dispatch
|
| 5 |
+
|
| 6 |
+
_global_logic_variables = set() # type: ignore[var-annotated]
|
| 7 |
+
_glv = _global_logic_variables
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Var:
|
| 11 |
+
""" Logic Variable """
|
| 12 |
+
|
| 13 |
+
_id = 1
|
| 14 |
+
|
| 15 |
+
def __new__(cls, *token):
|
| 16 |
+
if len(token) == 0:
|
| 17 |
+
token = f"_{Var._id}" # type: ignore[assignment]
|
| 18 |
+
Var._id += 1
|
| 19 |
+
elif len(token) == 1:
|
| 20 |
+
token = token[0]
|
| 21 |
+
|
| 22 |
+
obj = object.__new__(cls)
|
| 23 |
+
obj.token = token # type: ignore[attr-defined]
|
| 24 |
+
return obj
|
| 25 |
+
|
| 26 |
+
def __str__(self):
|
| 27 |
+
return "~" + str(self.token) # type: ignore[attr-defined]
|
| 28 |
+
__repr__ = __str__
|
| 29 |
+
|
| 30 |
+
def __eq__(self, other):
|
| 31 |
+
return type(self) == type(other) and self.token == other.token # type: ignore[attr-defined]
|
| 32 |
+
|
| 33 |
+
def __hash__(self):
|
| 34 |
+
return hash((type(self), self.token)) # type: ignore[attr-defined]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def var():
|
| 38 |
+
return lambda *args: Var(*args)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def vars():
|
| 42 |
+
return lambda n: [var() for i in range(n)]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dispatch(Var)
|
| 46 |
+
def isvar(v):
|
| 47 |
+
return True
|
| 48 |
+
|
| 49 |
+
isvar
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dispatch(object) # type: ignore[no-redef]
|
| 53 |
+
def isvar(o):
|
| 54 |
+
return not not _glv and hashable(o) and o in _glv
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@contextmanager
|
| 58 |
+
def variables(*variables):
|
| 59 |
+
"""
|
| 60 |
+
Context manager for logic variables
|
| 61 |
+
|
| 62 |
+
Example:
|
| 63 |
+
>>> # xdoctest: +SKIP("undefined vars")
|
| 64 |
+
>>> from __future__ import with_statement
|
| 65 |
+
>>> with variables(1):
|
| 66 |
+
... print(isvar(1))
|
| 67 |
+
True
|
| 68 |
+
>>> print(isvar(1))
|
| 69 |
+
False
|
| 70 |
+
>>> # Normal approach
|
| 71 |
+
>>> from unification import unify
|
| 72 |
+
>>> x = var('x')
|
| 73 |
+
>>> unify(x, 1)
|
| 74 |
+
{~x: 1}
|
| 75 |
+
>>> # Context Manager approach
|
| 76 |
+
>>> with variables('x'):
|
| 77 |
+
... print(unify('x', 1))
|
| 78 |
+
{'x': 1}
|
| 79 |
+
"""
|
| 80 |
+
old_global_logic_variables = _global_logic_variables.copy()
|
| 81 |
+
_global_logic_variables.update(set(variables))
|
| 82 |
+
try:
|
| 83 |
+
yield
|
| 84 |
+
finally:
|
| 85 |
+
_global_logic_variables.clear()
|
| 86 |
+
_global_logic_variables.update(old_global_logic_variables)
|
.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import graph_drawer
|
| 2 |
+
from . import graph_manipulation
|
| 3 |
+
from . import net_min_base
|
| 4 |
+
from . import operator_support
|
| 5 |
+
from . import param_fetch
|
| 6 |
+
from . import reinplace
|
| 7 |
+
from . import runtime_assert
|
| 8 |
+
from . import shape_prop
|
| 9 |
+
from . import split_module
|
| 10 |
+
from . import split_utils
|
| 11 |
+
from . import splitter_base
|
| 12 |
+
from . import tools_common
|
.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import operator
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
|
| 7 |
+
"""
|
| 8 |
+
Annotate the type of getitem nodes, inferred from the type of sequence node.
|
| 9 |
+
If sequence node is not annotated with a type, do nothing.
|
| 10 |
+
Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.
|
| 11 |
+
|
| 12 |
+
This is helpful since annotations on local names within function are lost during FX transforms.
|
| 13 |
+
Adding back known type annotation for getitem nodes to improve jit scriptability.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
graph (Graph): The graph to be annotated
|
| 17 |
+
"""
|
| 18 |
+
for node in graph.nodes:
|
| 19 |
+
if node.target == operator.getitem:
|
| 20 |
+
sequence_node, index_node = node.args
|
| 21 |
+
if not sequence_node.type:
|
| 22 |
+
continue
|
| 23 |
+
# container types
|
| 24 |
+
if hasattr(sequence_node.type, "_name"):
|
| 25 |
+
parameterized_types = sequence_node.type.__args__
|
| 26 |
+
if sequence_node.type._name == "Tuple":
|
| 27 |
+
if len(parameterized_types) == 2 and isinstance(
|
| 28 |
+
parameterized_types[1], type(...)
|
| 29 |
+
):
|
| 30 |
+
node.type = parameterized_types[0]
|
| 31 |
+
else:
|
| 32 |
+
assert len(parameterized_types) > index_node
|
| 33 |
+
node_type = parameterized_types[index_node]
|
| 34 |
+
node.type = node_type
|
| 35 |
+
elif sequence_node.type._name == "List":
|
| 36 |
+
assert len(parameterized_types) == 1
|
| 37 |
+
node.type = parameterized_types[0]
|
| 38 |
+
# NamedTuple type
|
| 39 |
+
elif hasattr(sequence_node.type, "__annotations__"):
|
| 40 |
+
if sequence_node.type == torch.Tensor:
|
| 41 |
+
continue
|
| 42 |
+
sequence_node_field_types = sequence_node.type.__annotations__
|
| 43 |
+
field_name = sequence_node.type._fields[index_node]
|
| 44 |
+
node.type = sequence_node_field_types[field_name]
|
.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py
ADDED
|
File without changes
|