koichi12 commited on
Commit
d716663
·
verified ·
1 Parent(s): 1bc1bad

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-311.pyc +0 -0
  2. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py +0 -0
  3. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc +0 -0
  4. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc +0 -0
  8. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc +0 -0
  9. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc +0 -0
  10. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py +558 -0
  12. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +1281 -0
  13. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py +1040 -0
  14. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py +14 -0
  15. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py +349 -0
  16. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py +53 -0
  17. .venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py +29 -0
  18. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py +4 -0
  19. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc +0 -0
  24. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc +0 -0
  25. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc +0 -0
  26. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc +0 -0
  27. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/core.py +119 -0
  28. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/dispatch.py +6 -0
  29. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/match.py +122 -0
  30. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/more.py +118 -0
  31. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py +3 -0
  32. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc +0 -0
  33. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py +121 -0
  39. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/core.py +84 -0
  40. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py +427 -0
  41. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py +126 -0
  42. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py +92 -0
  43. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/unification_tools.py +396 -0
  44. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/utils.py +106 -0
  45. .venv/lib/python3.11/site-packages/torch/fx/experimental/unification/variable.py +86 -0
  46. .venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py +12 -0
  47. .venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py +44 -0
  48. .venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py +0 -0
  49. .venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py +0 -0
.venv/lib/python3.11/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (194 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (216 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-311.pyc ADDED
Binary file (28.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-311.pyc ADDED
Binary file (72.3 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-311.pyc ADDED
Binary file (52.1 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-311.pyc ADDED
Binary file (521 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-311.pyc ADDED
Binary file (16.6 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-311.pyc ADDED
Binary file (2.43 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-311.pyc ADDED
Binary file (1.53 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \
3
+ op_mod, op_gt, op_lt, op_neq, op_eq
4
+ from torch.fx.tensor_type import TensorType, Dyn
5
+
6
+
7
class Constraint:
    """Abstract marker base class for all constraints in this module's
    gradual-typing constraint language; concrete constraints subclass it."""
    pass
9
+
10
+
11
class Conj(Constraint):
    """Conjunction (logical AND) of a list of constraints."""

    def __init__(self, conjuncts):
        """
        :param conjuncts: Conjunction of constraints
        """
        # NOTE: the attribute keeps the historical misspelling "conjucts"
        # because sibling modules access it by that exact name.
        self.conjucts = conjuncts

    def __eq__(self, other):
        # Fix: the original compared self.conjucts to other.conjucts twice
        # ("a and a"); a single comparison is sufficient and equivalent.
        if isinstance(other, Conj):
            return self.conjucts == other.conjucts
        else:
            return False

    def __repr__(self):
        return f'And({self.conjucts})'
26
+
27
+
28
class Disj(Constraint):
    """Disjunction (logical OR) of a list of constraints."""

    def __init__(self, disjuncts):
        """
        :param disjuncts: Disjunction of constraints
        """
        self.disjuncts = disjuncts

    def __eq__(self, other):
        # Fix: the original compared self.disjuncts to other.disjuncts twice
        # ("a and a"); one comparison is sufficient and equivalent.
        if isinstance(other, Disj):
            return self.disjuncts == other.disjuncts
        else:
            return False

    def __repr__(self):
        return f'Or({self.disjuncts})'
43
+
44
+
45
class Prod(Constraint):
    """Symbolic product of a list of dimensions."""

    def __init__(self, products):
        """
        :param products: lists of dimensions to multiply
        """
        self.products = products

    def __eq__(self, other):
        # Fix: the original compared self.products to other.products twice
        # ("a and a"); one comparison is sufficient and equivalent.
        if isinstance(other, Prod):
            return self.products == other.products
        else:
            return False

    def __repr__(self):
        return f'Product({self.products})'
60
+
61
+
62
class T(Constraint):
    """
    True
    """
    # The trivially-satisfied constraint; carries no state, so any two
    # T instances compare equal.
    def __init__(self) -> None:
        pass

    def __eq__(self, other):
        return isinstance(other, T)

    def __repr__(self):
        return 'True'
74
+
75
class F(Constraint):
    """
    False
    """
    # The unsatisfiable constraint; carries no state, so any two
    # F instances compare equal.
    def __init__(self) -> None:
        pass

    def __eq__(self, other):
        return isinstance(other, F)

    def __repr__(self):
        return 'False'
87
+
88
+
89
class BinaryConstraint(Constraint):
    """Base for all binary constraints: an operator applied to two operands."""

    def __init__(self, lhs, rhs, op):
        """
        :param lhs: lhs of the constraint
        :param rhs: rhs of the constraint
        :param op: string representing the operation
        """
        self.lhs = lhs
        self.rhs = rhs
        self.op = op

    def __eq__(self, other):
        if not isinstance(other, BinaryConstraint):
            return False
        return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op

    def __repr__(self):
        return f'({self.lhs} {self.op} {self.rhs})'
111
+
112
+
113
class BinConstraintT(BinaryConstraint):
    """Binary constraint whose operands are tensor-like."""

    def __init__(self, lhs, rhs, op):
        # Each side must be a tensor variable, a TensorType, an int, or Dyn.
        assert isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn
        assert isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn
        super().__init__(lhs, rhs, op)

    def __eq__(self, other):
        return super().__eq__(other)
124
+
125
+
126
class BinConstraintD(BinaryConstraint):
    """Binary constraint whose operands are dimension-valued expressions."""

    def __init__(self, lhs, rhs, op):
        # Each side must be a dimension, an arithmetic expression over
        # dimensions, or a boolean expression.
        assert is_dim(lhs) or is_algebraic_expression(lhs) or is_bool_expr(lhs)
        assert is_dim(rhs) or is_algebraic_expression(rhs) or is_bool_expr(rhs)
        super().__init__(lhs, rhs, op)

    def __eq__(self, other):
        return super().__eq__(other)
138
+
139
+
140
+
141
class TGreatestUpperBound(Constraint):
    """Greatest-upper-bound constraint over tensors with dynamic type."""

    def __init__(self, res, rhs1, rhs2):
        """
        :param res: tensor variable that stores the result of the output
        :param rhs1: tensor or tensor variable
        :param rhs2: tensor or tensor variable
        """
        self.res = res
        self.rhs1 = rhs1
        self.rhs2 = rhs2

    def __repr__(self):
        return f'{self.res} = {self.rhs1}\u2294*{self.rhs2}'

    def __eq__(self, other):
        if not isinstance(other, TGreatestUpperBound):
            return False
        return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
163
+
164
+
165
class DGreatestUpperBound(Constraint):
    """Greatest-upper-bound constraint over dimensions."""

    def __init__(self, res, rhs1, rhs2):
        """
        :param res: Dimension variable to store the result
        :param rhs1: dimension variable 1
        :param rhs2: dimension variable 2
        """
        # All three operands must be dimension-valued (DVar, int, or Dyn).
        assert is_dim(res)
        assert is_dim(rhs1)
        assert is_dim(rhs2)

        self.res = res
        self.rhs1 = rhs1
        self.rhs2 = rhs2

    def __repr__(self):
        return f'{self.res} = {self.rhs1}\u2294{self.rhs2}'

    def __eq__(self, other):
        if not isinstance(other, DGreatestUpperBound):
            return False
        return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2
191
+
192
+
193
class CanReshape(Constraint):
    """Constraint asserting that `src` can be reshaped into `target`."""

    def __init__(self, src, target):
        """
        :param src: tensor variable
        :param target: tensor
        """
        self.src = src
        self.target = target

    def __repr__(self):
        return f'can-reshape({self.src}, {self.target})'

    def __eq__(self, other):
        return isinstance(other, CanReshape) and \
            self.src == other.src and \
            self.target == other.target
213
+
214
+
215
class IndexSelect(Constraint):
    """Constraint modelling torch.index_select at a fixed tensor rank."""

    def __init__(self, tensor_size, input_var, dim_replace, index, output):
        """
        Args:
            input_var: input to index_select
            tensor_size: tensor size we are considering
            dim_replace: the dimension of the output at "index"
            index: location of the dimensions to replace in the input
            output: variable to store the result
        """
        assert isinstance(input_var, TVar)
        assert isinstance(output, TVar)
        assert isinstance(dim_replace, DVar) or dim_replace == Dyn
        assert isinstance(index, int)

        self.input_var = input_var
        self.tensor_size = tensor_size
        self.dim_replace = dim_replace
        self.index = index
        self.output = output

    def __repr__(self):
        return f' {self.output} = IndexSelect({self.input_var}, tensor_size: {self.tensor_size}, {self.dim_replace}, {self.index})'

    def __eq__(self, other):
        if not isinstance(other, IndexSelect):
            return False
        return (self.tensor_size == other.tensor_size
                and self.dim_replace == other.dim_replace
                and self.index == other.index
                and self.output == other.output
                and self.input_var == other.input_var)
254
+
255
+
256
class Transpose(Constraint):
    """Constraint modelling a transpose of two dimensions at a fixed rank."""

    def __init__(self, tensor_size, input_var, index1, index2, output):
        """
        Args:
            tensor_size: current tensor size
            input_var: variable to hold input
            index1: dimension 1
            index2: dimension 2
            output: output that stores result
        """
        assert isinstance(input_var, TVar)
        assert isinstance(output, TVar)
        assert isinstance(index1, int)
        assert isinstance(index2, int)

        self.input_var = input_var
        self.tensor_size = tensor_size
        self.index1 = index1
        self.index2 = index2
        self.output = output

    def __repr__(self):
        return f' {self.output} = Transpose({self.input_var}, tensor_size: {self.tensor_size}, {self.index1}, {self.index2})'

    def __eq__(self, other):
        if not isinstance(other, Transpose):
            return False
        return (self.tensor_size == other.tensor_size
                and self.index1 == other.index1
                and self.index2 == other.index2
                and self.output == other.output
                and self.input_var == other.input_var)
295
+
296
+
297
class GetItem(Constraint):
    """Constraint for extracting a single dimension from a tensor of a
    given rank."""

    def __init__(self, tensor_size, index, res, input_var):
        """
        :param tensor_size: actual number
        :param index: actual number representing the index
        :param res: dimension variable to carry the item we get
        :param input_var: a tensor variable from which we will get item
        """
        assert isinstance(res, DVar)

        self.res = res
        self.tensor_size = tensor_size
        self.index = index
        self.input_var = input_var

    def __repr__(self):
        return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})'

    def __eq__(self, other):
        if not isinstance(other, GetItem):
            return False
        return (self.res == other.res
                and self.tensor_size == other.tensor_size
                and self.index == other.index
                and self.input_var == other.input_var)
325
+
326
class GetItemTensor(Constraint):
    """Constraint for indexing with a tuple, which yields a tensor rather
    than a single dimension."""

    def __init__(self, tensor_size, index_tuple, res, input_var):
        """
        :param tensor_size: actual number representing the rank
        :param index_tuple: tuple for indexing
        :param res: tensor variable to carry the item we get
        :param input_var: a tensor variable from which we will get item
        """
        assert isinstance(res, TVar)

        self.res = res
        self.tensor_size = tensor_size
        self.index_tuple = index_tuple
        self.input_var = input_var

    def __repr__(self):
        return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})'

    def __eq__(self, other):
        if not isinstance(other, GetItemTensor):
            return False
        return (self.res == other.res
                and self.tensor_size == other.tensor_size
                and self.index_tuple == other.index_tuple
                and self.input_var == other.input_var)
356
+
357
class CalcConv(Constraint):
    """Constraint computing the output type of a 2-D convolution."""

    def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars):
        """
        :param conv_result: the convolution result
        :param input_var: input to convolution
        :param c_out: output channel type
        :param kernel: kernel tuple
        """
        self.conv_result = conv_result
        self.input_var = input_var
        self.c_out = c_out
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.matching_constraint = matching_constraint_vars

    def __repr__(self):
        return f'{self.conv_result} = calc-conv({self.input_var}, {self.c_out}, {self.kernel}, {self.padding}, {self.stride}, {self.dilation})'

    def __eq__(self, other):
        if not isinstance(other, CalcConv):
            return False
        return (self.conv_result == other.conv_result
                and self.input_var == other.input_var
                and self.c_out == other.c_out
                and self.kernel == other.kernel
                and self.padding == other.padding
                and self.stride == other.stride
                and self.dilation == other.dilation
                and self.matching_constraint == other.matching_constraint)
390
+
391
+
392
class CalcMaxPool(Constraint):
    """Constraint computing the output type of a max-pooling operation."""

    def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars):
        """
        :param maxpool_result: the result of maxpool
        :param input_var: input to convolution
        :param kernel: kernel tuple
        """
        self.maxpool_result = maxpool_result
        self.input_var = input_var
        self.kernel = kernel
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.matching_constraint = matching_constraint_vars

    def __repr__(self):
        return f'{self.maxpool_result} = calc-maxpool({self.input_var}, {self.kernel}, {self.padding}, {self.stride}, {self.dilation})'

    def __eq__(self, other):
        if not isinstance(other, CalcMaxPool):
            return False
        return (self.maxpool_result == other.maxpool_result
                and self.input_var == other.input_var
                and self.kernel == other.kernel
                and self.padding == other.padding
                and self.stride == other.stride
                and self.dilation == other.dilation
                and self.matching_constraint == other.matching_constraint)
423
+
424
+
425
class ApplyBroadcasting(Constraint):
    """Constraint relating two inputs to their broadcast result types."""

    def __init__(self, res1, res2, input1, input2):
        """
        :param res1: resulting tensor 1
        :param res2: resulting tensor 2
        :param input1: tensor variable 1
        :param input2: tensor variable 2
        """
        self.res1 = res1
        self.res2 = res2
        self.input1 = input1
        self.input2 = input2

    def __eq__(self, other):
        if not isinstance(other, ApplyBroadcasting):
            return False
        return (self.res1 == other.res1
                and self.res2 == other.res2
                and self.input1 == other.input1
                and self.input2 == other.input2)

    def __repr__(self):
        return f'{self.res1}, {self.res2} = apply-broadcasting({self.input1}, {self.input2})'
449
+
450
+
451
class CalcProduct(Constraint):
    """
    Given correct dimensions, calculate the product for flatten accounting for Dyn
    """

    def __init__(self, start, end, flattened, dims_to_flatten):
        """
        :param start: start index
        :param end: end index
        :param flattened: variable to store the product
        :param dims_to_flatten: the type which we will flatten
        """
        assert isinstance(dims_to_flatten, list)
        assert isinstance(flattened, TVar)
        assert isinstance(start, int)
        assert isinstance(end, int)

        self.start = start
        self.end = end
        self.dims_to_flatten = dims_to_flatten
        self.flattened = flattened

    def __eq__(self, other):
        if not isinstance(other, CalcProduct):
            return False
        return (self.start == other.start
                and self.end == other.end
                and self.dims_to_flatten == other.dims_to_flatten
                and self.flattened == other.flattened)

    def __repr__(self):
        return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})'
482
+
483
+
484
class TVar:
    """
    Tensor variable with no tensor constructor
    """

    def __init__(self, tvar):
        """
        :param tvar: tensor variable
        """
        self.tvar = tvar

    def __repr__(self):
        return f'TV({self.tvar})'

    def __eq__(self, other):
        return isinstance(other, TVar) and self.tvar == other.tvar
502
+
503
+
504
class DVar:
    """
    Dimension variable
    """

    def __init__(self, c):
        """
        :param c: character or number
        """
        self.c = c

    def __repr__(self):
        return f'DV({self.c})'

    def __eq__(self, other):
        return isinstance(other, DVar) and self.c == other.c
522
+
523
+
524
class BVar:
    """
    Boolean variable
    """

    def __init__(self, c):
        """
        :param c: character or number
        """
        self.c = c

    def __repr__(self):
        return f'BV({self.c})'

    def __eq__(self, other):
        return isinstance(other, BVar) and self.c == other.c
542
+
543
+
544
def is_algebraic_expression(constraint):
    """Return True if *constraint* is arithmetic over dimensions: a
    BinConstraintD with an arithmetic operator, or a Prod node."""
    if isinstance(constraint, BinConstraintD):
        return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod]
    return isinstance(constraint, Prod)
549
+
550
+
551
def is_bool_expr(constraint):
    """Return True if *constraint* is boolean-valued: a comparison
    BinConstraintD, a boolean variable, or a conjunction/disjunction."""
    if isinstance(constraint, BinConstraintD):
        return constraint.op in [op_gt, op_lt, op_neq, op_eq]
    return isinstance(constraint, (BVar, Conj, Disj))
556
+
557
def is_dim(d):
    """Return True if *d* can denote a dimension: a DVar, an int, or Dyn."""
    if isinstance(d, (DVar, int)):
        return True
    return d == Dyn
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py ADDED
@@ -0,0 +1,1281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-decorators
2
+ # mypy: allow-untyped-defs
3
+ import torch
4
+ import operator
5
+ import warnings
6
+ from typing import Callable, Dict, Iterable
7
+
8
+ from torch.fx._symbolic_trace import _assert_is_none
9
+ from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \
10
+ Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \
11
+ TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound
12
+ from torch.fx.experimental.migrate_gradual_types.operation import \
13
+ op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul
14
+ from torch.fx.node import Target, Node
15
+ from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \
16
+ gen_bvar
17
+
18
+ from torch.fx.tensor_type import Dyn, TensorType
19
+ from torch.nn.modules.conv import Conv2d
20
+ from torch.nn.modules.batchnorm import BatchNorm2d
21
+
22
# Registry mapping an FX call target (e.g. a function, a method name string,
# or a module type) to the function that generates typing constraints for it.
# Populated via the @register_inference_rule decorator below.
_INFERENCE_RULES: Dict[Target, Callable] = {}

# Maximum tensor rank enumerated when a rank is unknown (rules such as
# index_select build one disjunct per candidate rank from 1 to this bound).
MAX_TENSOR_RANK = 4
25
+
26
def register_inference_rule(call_target):
    """Return a decorator that registers a function as the constraint
    inference rule for *call_target*.

    Raises RuntimeError if a rule is already registered for that target,
    so duplicate registrations fail loudly at import time.
    """
    def register(fn):
        if call_target in _INFERENCE_RULES:
            raise RuntimeError(f'Inference rule already registered for {call_target}!')
        _INFERENCE_RULES[call_target] = fn
        return fn
    return register
33
+
34
+
35
def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter):
    """Build the constraints describing a flatten of a rank-n input between
    start_dim and end_dim, storing the result type in *flattened*."""
    dims, counter = gen_tensor_dims(n, counter)
    # The input must be a fully dimensioned rank-n tensor.
    input_is_rank_n = BinConstraintT(input, TensorType(dims), op_eq)
    # Normalize negative indices into [0, n] slot positions (end exclusive).
    start = n if start_dim == -1 else abs(start_dim)
    end = n + end_dim + 1 if end_dim < 0 else end_dim + 1
    product = CalcProduct(start, end, flattened, dims)
    # Every generated dimension must be a natural number.
    nat = gen_nat_constraints(dims)
    return Conj([input_is_rank_n, product, *nat]), counter
43
+
44
+
45
@register_inference_rule(getattr)
def get_attr_inference_rule(n: Node, symbols, constraints, counter):
    """Propagate the input type through getattr; only the 'device'
    attribute is supported, which preserves the tensor shape."""
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], str)

    output, counter = gen_tvar(counter)
    symbols[n] = output

    input = symbols[n.args[0]]
    attr = n.args[1]

    if attr == 'device':
        # Moving a tensor between devices leaves its type unchanged.
        return [BinConstraintT(input, output, op_eq)], counter
    raise NotImplementedError('Not yet implemented')
62
+
63
@register_inference_rule(torch.bmm)
def bmm_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints that match the input to a size 3 tensor
    and switch the dimensions according to the rules
    of batch multiplication
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    # Fresh tensor variable for the result of the bmm call.
    bmm_output, counter = gen_tvar(counter)
    symbols[n] = bmm_output

    bmm_input1 = symbols[n.args[0]]
    bmm_input2 = symbols[n.args[1]]

    # Rank-3 dimension variables for each operand (bmm requires 3-D inputs).
    dims_input1, counter = gen_tensor_dims(3, counter)
    dims_input2, counter = gen_tensor_dims(3, counter)

    # Case 1: both inputs are dynamic, so the output is dynamic too.
    inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
                       BinConstraintT(bmm_input2, Dyn, op_eq),
                       BinConstraintT(bmm_output, Dyn, op_eq)])

    # Case 2: only input 1 is dynamic; the middle output dimension is unknown.
    input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq),
                       BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
                       BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)])

    # Case 3: only input 2 is dynamic; the last output dimension is unknown.
    input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq),
                       BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
                       BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)])

    # The batch dimensions of the two inputs must be consistent.
    consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)]

    batch_size, counter = gen_dvar(counter)

    # Case 4: both inputs are rank-3 tensors; the output is
    # [batch, dims_input1[1], dims_input2[2]] where batch is the
    # DGreatestUpperBound of the two (possibly dynamic) batch dimensions.
    inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
                               BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
                               BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq),
                               *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])])

    # Exactly one of the four cases must hold.
    return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter
104
+
105
+
106
@register_inference_rule("index_select")
def index_select_inference_rule(n: Node, symbols, constraints, counter):
    """
    We constrain the second argument to a vector or Dyn.
    The output replaces the input with the shape of the vector
    at the position given by the index (first argument)
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], Node)

    index_select, counter = gen_tvar(counter)
    symbols[n] = index_select

    dims, counter = gen_tensor_dims(1, counter)

    input_arg = symbols[n.args[0]]
    index_arg = symbols[n.args[2]]

    # The index argument is either a rank-1 tensor or fully dynamic.
    is_size_1 = BinConstraintT(index_arg, TensorType(dims), op_eq)
    is_dyn = BinConstraintT(index_arg, Dyn, op_eq)

    # Enumerate each candidate input rank; the replaced dimension is either
    # the vector's length (vector case) or unknown (dynamic case).
    vector_case = Conj([is_size_1,
                        Disj([IndexSelect(rank + 1, input_arg, dims[0], n.args[1], index_select)
                              for rank in range(MAX_TENSOR_RANK)])])
    dyn_case = Conj([is_dyn,
                     Disj([IndexSelect(rank + 1, input_arg, Dyn, n.args[1], index_select)
                           for rank in range(MAX_TENSOR_RANK)])])

    return [Disj([vector_case, dyn_case])], counter
135
+
136
+
137
@register_inference_rule("expand")
def expand_inference_rule(n: Node, symbols, constraints, counter):
    """
    We generate the exact constraints as we do for tensor additions but we constraint
    the rank of this expression to be equal to len(n.args[1:]) so that only
    those cases get considered for the output

    Args:
        n: the fx node for the expand call; n.args[1:] is the target size,
           given as a mix of ints and nodes bound to dimension variables
        symbols: map from fx nodes to type/dimension variables
        constraints: constraints accumulated so far (not read here)
        counter: fresh-variable counter

    Returns: (list of new constraints, updated counter)
    """
    assert isinstance(n.args[0], Node)

    # define the output for expand
    expand, counter = gen_tvar(counter)
    symbols[n] = expand

    # since we do not have two nodes here, we will construct an argument variable
    # e2 plays the role of the "second operand" in the broadcasting constraints
    e1 = symbols[n.args[0]]
    e2, counter = gen_tvar(counter)

    # every node-valued size entry must denote a natural number
    e2_nat_constraints = []
    for arg in n.args[1:]:
        assert isinstance(arg, (Node, int))
        if isinstance(arg, Node):
            assert isinstance(symbols[arg], DVar)
            e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq))

    # e2 is pinned to the requested target shape
    e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq)

    constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand)

    # constraint the output size: rank is exactly len(n.args[1:])
    dims, counter = gen_tensor_dims(len(n.args[1:]), counter)
    nat_constraints = gen_nat_constraints(dims)
    c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints]
    constraints += c

    return constraints, counter
172
+
173
+
174
@register_inference_rule(torch.nn.functional.gelu)
@register_inference_rule(torch.nn.functional.dropout)
@register_inference_rule(torch.nn.functional.softmax)
@register_inference_rule("detach")
@register_inference_rule("to")
@register_inference_rule("int")
@register_inference_rule("long")
@register_inference_rule("contiguous")
@register_inference_rule(torch.ones)
@register_inference_rule(torch.zeros)
def equality_inference_rule(n: Node, symbols, constraints, counter):
    """
    We generate the constraint: input = output

    Three argument shapes are handled:
      * first arg is a node bound to a tensor variable -> output equals it
        (shape-preserving ops such as softmax/dropout/detach)
      * first arg is a node bound to a dimension variable -> all args are
        taken as the dimensions of the output (torch.ones/zeros called
        with individual sizes)
      * first arg is a tuple -> the tuple itself is the output size
    """
    output, counter = gen_tvar(counter)
    symbols[n] = output

    if isinstance(n.args[0], Node):
        input = symbols[n.args[0]]
        if isinstance(input, TVar):
            return [BinConstraintT(input, output, op_eq)], counter

        # then we have dimension variables
        else:
            for arg in n.args:
                assert isinstance(symbols[arg], DVar)
            my_size = [symbols[arg] for arg in n.args]
            return [BinConstraintT(output, TensorType(my_size), op_eq)], counter

    elif isinstance(n.args[0], tuple):
        # then the tuple is the size
        # NOTE(review): only sizes up to rank 4 are accepted here —
        # presumably tied to MAX_TENSOR_RANK; confirm
        assert len(n.args[0]) <= 4
        my_size = [symbols[arg] for arg in n.args[0]]
        return [BinConstraintT(output, TensorType(my_size), op_eq)], counter
    else:
        raise NotImplementedError('Method not yet implemented')
210
+
211
+
212
@register_inference_rule("transpose")
def transpose_inference_rule(n: Node, symbols, constraints, counter):
    """
    Can be considered as a sequence of two index selects, so we generate constraints accordingly

    Args:
        n: the fx node; n.args[1] and n.args[2] are the two dimensions swapped
        symbols: map from fx nodes to type variables
        constraints: constraints accumulated so far (not read here)
        counter: fresh-variable counter

    Returns: (list of new constraints, updated counter)
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], int)

    output, counter = gen_tvar(counter)
    symbols[n] = output

    from_arg = symbols[n.args[0]]
    assert isinstance(from_arg, TVar)

    # input and output are dyn
    is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)])

    # or input is a tensor and we actually do the replacement,
    # one Transpose constraint per candidate rank
    c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)])

    return [Disj([is_dyn, c3])], counter
234
+
235
+
236
@register_inference_rule("type_as")
def type_inference_rule(n: Node, symbols, constraints, counter):
    """
    `x.type_as(y)`: the result takes the tensor type of the second
    argument, and the two argument types must be consistent.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    result, counter = gen_tvar(counter)
    symbols[n] = result

    source = symbols[n.args[0]]
    target = symbols[n.args[1]]

    assert isinstance(source, TVar)
    assert isinstance(target, TVar)

    consistency = BinConstraintT(source, target, op_consistency)
    equality = BinConstraintT(result, target, op_eq)
    return [consistency, equality], counter
255
+
256
@register_inference_rule("masked_fill_")
def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
    """
    Similar to addition. For now we implement the constraints when
    the argument is a boolean tensor. There is also a case for when
    it is a condition. We will leave this out for now.

    Returns: (list of broadcasting constraints, updated counter)
    """

    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    # We will retrieve the type variables from the symbol table
    # and confirm they are tensor variables

    e1 = symbols[n.args[0]]
    e2 = symbols[n.args[1]]

    if isinstance(e1, TVar) and isinstance(e2, TVar):
        # the mask broadcasts against the input, like elementwise ops
        masked_fill_tensor, counter = gen_tvar(counter)
        symbols[n] = masked_fill_tensor
        return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor)
    else:
        raise NotImplementedError('Not yet implemented')
279
+
280
+
281
@register_inference_rule(torch.nn.functional.embedding)
def embedding_inference_rule_functional(n: Node, symbols, constraints, counter):
    """
    Functional form of embedding: the weight matrix (second argument) is
    constrained to a rank-2 tensor and its second dimension supplies the
    embedding dimension appended to the input shape.
    """
    assert isinstance(n.args[0], Node)

    embedding_dim_weights = symbols[n.args[1]]

    # will treat this as a static shape. So we will not use matching.
    weight_dims, counter = gen_tensor_dims(2, counter)
    equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq)
    embedding_dim = weight_dims[1]
    constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter)
    return [equality_constraint] + constraints, counter
293
+
294
+
295
@register_inference_rule(torch.nn.modules.sparse.Embedding)
def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    The output shape differs from the input shape in the last dimension:
    the module's embedding_dim is appended to the input shape.
    """
    assert isinstance(n.args[0], Node)
    return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter)
302
+
303
+
304
def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
    """
    Shared constraint generation for module and functional embedding.

    Either input and output are both Dyn, or for each candidate input rank
    1..MAX_TENSOR_RANK-1 the output is the input shape with embedding_dim
    appended as a final dimension.

    Args:
        n: the fx node for the embedding call
        symbols: map from fx nodes to type variables
        embedding_dim: embedding size (int or dimension variable)
        counter: fresh-variable counter

    Returns: (list of new constraints, updated counter)
    """

    embedding_output, counter = gen_tvar(counter)
    symbols[n] = embedding_output
    embedding_input = symbols[n.args[0]]

    input_dyn = BinConstraintT(embedding_input, Dyn, op_eq)
    output_dyn = BinConstraintT(embedding_output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])
    c2 = []

    # NOTE: range stops at MAX_TENSOR_RANK - 1 so that the output rank
    # (input rank + 1) never exceeds MAX_TENSOR_RANK
    for i in range(1, MAX_TENSOR_RANK):
        new_dims, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(new_dims)

        # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases
        c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq),
                           BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] +
                          nat_constraints)
        c2.append(c_tensor_i)

    return [Disj([c1, Disj(c2)])], counter
327
+
328
+
329
@register_inference_rule(torch.tensor)
def tensor_inference_rule(n: Node, symbols, constraints, counter):
    """
    If the tensor is a scalar, we will skip it since we
    do not support scalars yet. We will add support in the future
    if it's needed. For our examples so far, scalars are not needed.

    Returns: ([], counter) — intentionally generates no constraints.
    """
    return [], counter
337
+
338
+
339
@register_inference_rule("reshape")
@register_inference_rule("view")
def view_inference_rule(n: Node, symbols, constraints, counter):
    """
    Similar to reshape but with an extra condition on the strides

    n.args[1:] is the target shape; a -1 entry is replaced by a fresh
    dimension variable constrained to be non-Dyn (it must be inferable).
    """
    assert isinstance(n.args[0], Node)

    # generate the new variable
    my_view, counter = gen_tvar(counter)
    symbols[n] = my_view

    src_var = symbols[n.args[0]]
    t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]]  # target shape
    t2_type = []
    num_constraints = []

    for t in t2:
        if t == -1:
            # -1 means "infer this dimension": use a fresh, non-Dyn variable
            var, counter = gen_dvar(counter)
            t2_type.append(var)
            num_constraints.append(BinConstraintD(var, Dyn, op_neq))

        else:
            num_constraints.append(BinConstraintD(t, Dyn, op_neq))
            t2_type.append(t)

    t2_type = TensorType(t2_type)  # type: ignore[assignment]

    # the output has the target type, and the source must be reshapable to it
    c1 = BinConstraintT(my_view, t2_type, op_eq)
    c2 = CanReshape(src_var, t2_type)

    # TODO: add the extra check mentioned here:
    # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view

    return [c1, c2] + num_constraints, counter  # type: ignore[operator]
376
+
377
+
378
@register_inference_rule("size")
def size_inference_rule(n: Node, symbols, constraints, counter):
    """
    The constraint is just lhs = rhs.
    Ex: size = input_ids.size()

    Two call shapes:
      * size() with no index: the result is a tensor variable equal to
        the input's type
      * size(i) with an int index: the result is a dimension variable
        extracted via GetItem constraints over all candidate ranks
    """

    if len(n.args) == 1:
        # generate the new variable
        size, counter = gen_tvar(counter)
        symbols[n] = size
        input = symbols[n.args[0]]
        c = BinConstraintT(input, size, op_eq)
        return [c], counter

    elif len(n.args) == 2:
        # TODO: review this rule; should input = dyn; output = dyn be included here?
        if isinstance(n.args[1], int):
            # generate the new variable
            size_index, counter = gen_dvar(counter)
            symbols[n] = size_index
            input = symbols[n.args[0]]
            # one GetItem per candidate rank 1..MAX_TENSOR_RANK
            c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)]
            # the extracted dimension is a natural number
            c3 = BinConstraintD(0, size_index, op_leq)

            input_dyn = BinConstraintT(input, Dyn, op_eq)
            output_dyn = BinConstraintD(size_index, Dyn, op_eq)
            c1 = Conj([input_dyn, output_dyn])

            return [Disj([c1, Conj([Disj(c2), c3])])], counter

        else:
            raise NotImplementedError

    else:
        raise NotImplementedError
415
+
416
+
417
def range_check(i, n):
    """
    Checks if an index i is within range of a size n list,
    following Python's indexing convention (negative indices
    count from the end).

    Args:
        i: index
        n: list size

    Returns: T() if the index is in range, F() otherwise
    """
    if i >= 0:
        # non-negative indices are valid in [0, n)
        return T() if i < n else F()
    else:
        # negative indices are valid in [-n, -1].
        # BUG FIX: this previously compared `i >= n`, which is False for
        # every negative i (n >= 0), so all negative indices — e.g.
        # dim=-1 in cumsum — were wrongly reported out of range.
        return T() if i >= -n else F()
430
+
431
+
432
@register_inference_rule(torch.cumsum)
def cumsum_inference_rule(n: Node, symbols, constraints, counter):
    """
    Input and output shapes should be equal
    We should verify that the index is valid
    """
    assert isinstance(n.args[0], Node)
    # the dimension may be positional or passed as the "dim" keyword
    arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"]
    assert isinstance(arg_1, int)

    output, counter = gen_tvar(counter)
    symbols[n] = output
    input = symbols[n.args[0]]

    input_dyn = BinConstraintT(input, Dyn, op_eq)
    output_dyn = BinConstraintT(output, Dyn, op_eq)
    c1 = Conj([input_dyn, output_dyn])
    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(new_dims)

        # input and output share the same shape of rank i, and the
        # requested dim must be a valid index for that rank
        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq),
                           BinConstraintT(output, TensorType(new_dims), op_eq)] +
                          [range_check(arg_1, i)] + nat_constraints)

        c2.append(c_tensor_i)
    dyn_or_tensor = Disj([c1, Disj(c2)])
    return [dyn_or_tensor], counter
462
+
463
+
464
@register_inference_rule(_assert_is_none)
def assert_inference_rule(n: Node, symbols, constraints, counter):
    """
    `_assert_is_none` yields no tensor, so it contributes no constraints;
    we only require that nothing consumes its result.
    """
    assert not n.users
    return [], counter
468
+
469
+
470
@register_inference_rule(operator.getitem)
def getitem_inference_rule(n: Node, symbols, constraints, counter):
    """
    Indexing into a tensor type.

    Two cases:
      * int index: the result is a dimension variable (e.g. x.shape[i])
      * tuple index: the result is a tensor variable (slicing)
    """
    assert isinstance(n.args[0], Node)

    # dimension output case
    if isinstance(n.args[1], int):
        # create and store the new dimension variable
        get_item_output, counter = gen_dvar(counter)
        symbols[n] = get_item_output

        # retrieve arg variables
        get_item_arg = symbols[n.args[0]]
        assert isinstance(get_item_arg, TVar)

        # if the input is dynamic, we accept any index and return
        # a dynamic dimension as output
        input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
        output_dyn = BinConstraintD(get_item_output, Dyn, op_eq)
        c1 = Conj([input_dyn, output_dyn])

        # if the input is a tensor,
        # generate a getItem constraint which will be expanded based on the
        # tensor dimension.

        c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)]

        # since the output is a dimension, we make sure it's a natural number
        # added as a conjunction to the disjunction of c2
        c3 = BinConstraintD(0, get_item_output, op_leq)
        return [Disj([c1, Conj([Disj(c2), c3])])], counter

    # tensor output case
    elif isinstance(n.args[1], tuple):
        # create and store the new tensor variable
        get_item_output, counter = gen_tvar(counter)
        symbols[n] = get_item_output

        # retrieve arg variables
        if n.args[0] in symbols:
            get_item_arg = symbols[n.args[0]]
            assert isinstance(get_item_arg, TVar)

            input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq)
            output_dyn = BinConstraintT(get_item_output, Dyn, op_eq)  # type: ignore[assignment]
            c1 = Conj([input_dyn, output_dyn])

            c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg)  # type: ignore[misc]
                  for i in range(MAX_TENSOR_RANK)]
        else:
            # TODO: we should figure out why there is a key-error here.
            return [], counter

        return [Disj([c1, *c2])], counter

    else:
        raise RuntimeError('Method not yet implemented')
528
+
529
+
530
@register_inference_rule(operator.gt)
def gt_inference_rule(n: Node, symbols, constraints, counter):
    """
    Greater-than comparison.

    tensor > tensor broadcasts like an elementwise op; dim > dim and
    dim > int produce a fresh boolean variable equated with the gt
    constraint (used for flow analysis only).
    """
    assert isinstance(n.args[0], (Node, int))
    assert isinstance(n.args[1], (Node, int))

    # We make sure this node will not be used again. We do not
    # generate a constraint about that node. Only about the operands.

    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]

    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        if isinstance(e1, TVar) and isinstance(e2, TVar):
            gt_tensor, counter = gen_tvar(counter)
            symbols[n] = gt_tensor
            return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor)

        elif isinstance(e1, DVar) and isinstance(e2, DVar):
            # This is meant to be used for flow analysis only
            gt_constraint = BinConstraintD(e1, e2, op_gt)

            my_gt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
            return [equality_constraint], counter

        else:
            raise RuntimeError('Sort Mismatch')

    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
        if isinstance(e1, DVar):
            # This is meant to be used for flow analysis only
            gt_constraint = BinConstraintD(e1, e2, op_gt)

            my_gt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
            return [equality_constraint], counter

        elif isinstance(e1, TVar) and isinstance(e2, int):
            # then we made the wrong assumption about the argument being a tensor
            # so we should fix the assumption
            warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.')

            # rebind the argument to a fresh dimension variable
            # (a dead bare `symbols[n.args[0]]` expression statement that
            # followed this assignment has been removed)
            new_e1, counter = gen_dvar(counter)
            symbols[n.args[0]] = new_e1

            gt_constraint = BinConstraintD(new_e1, e2, op_gt)

            my_gt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
            return [equality_constraint], counter

        else:
            raise NotImplementedError('Method not yet implemented')

    else:
        raise NotImplementedError('Method not yet implemented')
587
+
588
+
589
@register_inference_rule(operator.eq)
def eq_inference_rule(n: Node, symbols, constraints, counter):
    """
    Equality comparison, mirroring gt_inference_rule:
    tensor == tensor broadcasts; dim == dim and dim == int produce a
    fresh boolean variable equated with the eq constraint (flow analysis).
    """
    assert isinstance(n.args[0], (Node, int))
    assert isinstance(n.args[1], (Node, int))

    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]

    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        if isinstance(e1, TVar) and isinstance(e2, TVar):
            eq_tensor, counter = gen_tvar(counter)
            symbols[n] = eq_tensor
            return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor)

        elif isinstance(e1, DVar) and isinstance(e2, DVar):
            # This is meant to be used for flow analysis only
            eq_constraint = BinConstraintD(e1, e2, op_eq)

            my_eq, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
            return [equality_constraint], counter

        else:
            raise RuntimeError('Sort Mismatch')

    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
        if isinstance(e1, DVar):
            # This is meant to be used for flow analysis only
            eq_constraint = BinConstraintD(e1, e2, op_eq)

            my_eq, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq)
            return [equality_constraint], counter
        else:
            raise NotImplementedError('Method not yet implemented')
    else:
        raise NotImplementedError('Method not yet implemented')
626
+
627
@register_inference_rule(operator.ne)
def neq_inference_rule(n: Node, symbols, constraints, counter):
    """
    Translates to inconsistent in gradual types.
    To prove inequality, we should prove that
    tensors are either different sizes or
    disagree on at least one dimension

    This is a WIP (works when the condition
    is false. We are working on making this operation work
    when the condition is true as well)
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], tuple)

    # implementing for size 3 and 4
    if len(n.args[1]) == 3:

        assert isinstance(n.args[1][0], (Node, int))
        assert isinstance(n.args[1][1], (Node, int))
        assert isinstance(n.args[1][2], (Node, int))

        lhs = symbols[n.args[0]]

        # NOTE: 4 fresh dims are generated but only the first 3 are used
        b, counter = gen_tensor_dims(4, counter)
        input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq)

        d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
        d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
        d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]

        # dimensions not equal
        neq_1 = BinConstraintD(d1, b[0], op_neq)
        neq_2 = BinConstraintD(d2, b[1], op_neq)
        neq_3 = BinConstraintD(d3, b[2], op_neq)

        # dimensions inconsistent: a pair can only disagree when neither
        # side is Dyn
        dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1])
        dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2])
        dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3])

        dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3])

        # we are covering size 3 and 4 only for now
        ne_constraint = Conj([input_is_size3, dims_inconsistent])

        # (a duplicated, unused gen_bvar call was removed here)
        my_ne, counter = gen_bvar(counter)
        equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)

    elif len(n.args[1]) == 4:

        assert isinstance(n.args[1][0], (Node, int))
        assert isinstance(n.args[1][1], (Node, int))
        assert isinstance(n.args[1][2], (Node, int))
        assert isinstance(n.args[1][3], (Node, int))

        lhs = symbols[n.args[0]]

        b1, counter = gen_dvar(counter)
        b2, counter = gen_dvar(counter)
        b3, counter = gen_dvar(counter)
        b4, counter = gen_dvar(counter)

        input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq)

        d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]]
        d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]]
        d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]]
        d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]]

        # dimensions not equal
        neq_1 = BinConstraintD(d1, b1, op_neq)
        neq_2 = BinConstraintD(d2, b2, op_neq)
        neq_3 = BinConstraintD(d3, b3, op_neq)
        neq_4 = BinConstraintD(d4, b4, op_neq)

        # dimensions inconsistent: a pair can only disagree when neither
        # side is Dyn
        dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1])
        dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2])
        dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3])
        # BUG FIX: the fourth clause previously guarded on b3 instead of
        # b4, so d4 could be declared "inconsistent" with a Dyn b4.
        dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq), neq_4])

        dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4])

        ne_constraint = Conj([input_is_size4, dims_inconsistent])

        my_ne, counter = gen_bvar(counter)
        equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)

    else:
        raise NotImplementedError('Method not yet implemented')

    return [equality_constraint], counter
723
+
724
+
725
@register_inference_rule(operator.lt)
def lt_inference_rule(n: Node, symbols, constraints, counter):
    """
    Less-than comparison, mirroring gt_inference_rule:
    tensor < tensor broadcasts; dim < dim and dim < int produce a fresh
    boolean variable equated with the lt constraint (flow analysis).

    NOTE(review): unlike gt, there is no TVar-vs-int fallback branch
    here — confirm whether lt needs the same assumption-repair path.
    """
    assert isinstance(n.args[0], (Node, int))
    assert isinstance(n.args[1], (Node, int))

    # We make sure this node will not be used again. We do not
    # generate a constraint about that node. Only about the operands.

    e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0]
    e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1]

    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        if isinstance(e1, TVar) and isinstance(e2, TVar):
            lt_tensor, counter = gen_tvar(counter)
            symbols[n] = lt_tensor
            return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor)

        elif isinstance(e1, DVar) and isinstance(e2, DVar):
            # This is meant to be used for flow analysis only
            lt_constraint = BinConstraintD(e1, e2, op_lt)

            my_lt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
            return [equality_constraint], counter

        else:
            raise RuntimeError('Sort Mismatch')

    elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node):
        if isinstance(e1, DVar):
            # This is meant to be used for flow analysis only
            lt_constraint = BinConstraintD(e1, e2, op_lt)

            my_lt, counter = gen_bvar(counter)
            equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq)
            return [equality_constraint], counter
        else:
            raise NotImplementedError('Method not yet implemented')

    else:
        raise NotImplementedError('Method not yet implemented')
766
+
767
+
768
@register_inference_rule(torch.full)
def full_inference_rule(n: Node, symbols, constraints, counter):
    """
    The output type of torch.full is the requested size (first argument);
    int entries are used directly, Node entries are looked up in the
    symbol table.
    """
    result, counter = gen_tvar(counter)
    symbols[n] = result

    assert isinstance(n.args[0], Iterable)
    sizes = [entry if isinstance(entry, int) else symbols[entry] for entry in n.args[0]]
    constraint = BinConstraintT(result, TensorType(sizes), op_eq)  # type: ignore[arg-type]
    return [constraint], counter
780
+
781
+
782
# TODO normalize index
@register_inference_rule(torch.arange)
def arange_inference_rule(n: Node, symbols, constraints, counter):
    """
    torch.arange(end): the output is a rank-1 tensor whose length d1 is
    int((end - start) / step), or Dyn if any of start/end/step is Dyn.

    NOTE(review): start and step are the Python ints 0 and 1 here, so
    their Dyn comparisons are trivially decided; only `end` (a symbol)
    can actually be Dyn. Only the single-argument form is supported.
    """
    start = 0
    step = 1

    if len(n.args) == 1:
        end = symbols[n.args[0]]
    else:
        raise NotImplementedError('Not yet implemented')

    # int((end - start) / step)
    d1, counter = gen_dvar(counter)
    size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq)
    arange, counter = gen_tvar(counter)
    symbols[n] = arange

    # either the a parameter is a number or it is Dyn
    c1 = Disj([BinConstraintD(end, Dyn, op_eq),
               BinConstraintD(start, Dyn, op_eq),
               BinConstraintD(step, Dyn, op_eq)])
    c2 = BinConstraintD(d1, Dyn, op_eq)
    both_dyn = Conj([c1, c2])

    c11 = Conj([BinConstraintD(end, Dyn, op_neq),
                BinConstraintD(start, Dyn, op_neq),
                BinConstraintD(step, Dyn, op_neq)])
    c22 = BinConstraintD(d1, Dyn, op_neq)
    both_numbers = Conj([c11, c22, size_constraint])

    return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter
813
+
814
def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var):
    """
    Build the standard broadcasting constraint set relating two operand
    tensor variables to an output variable.

    Two auxiliary tensor variables stand for the operands after
    broadcasting is applied; the output is their greatest upper bound
    and the two must be consistent with each other.
    """
    # auxiliary vars that don't correspond to expressions in the graph
    aux_lhs, counter = gen_tvar(counter)
    aux_rhs, counter = gen_tvar(counter)

    upper_bound = TGreatestUpperBound(output_var, aux_lhs, aux_rhs)
    broadcast = ApplyBroadcasting(aux_lhs, aux_rhs, e1, e2)
    consistent = BinConstraintT(aux_lhs, aux_rhs, op_consistency)
    return [upper_bound, broadcast, consistent], counter
824
+
825
+
826
@register_inference_rule(operator.mul)
@register_inference_rule(torch.ne)
@register_inference_rule("ne")
@register_inference_rule(torch.add)
@register_inference_rule(operator.add)
def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
    """
    Elementwise binary op: tensor/tensor operands broadcast; a tensor and
    a scalar keep the tensor's type; a dimension variable and a scalar
    propagate the runtime value.

    NOTE(review): op_code stays None for torch.ne / "ne"; the scalar DVar
    branches below would then build BinConstraintD(..., None) — confirm
    that ne with a scalar operand never reaches those branches.
    """

    op_code = None
    if n.target == operator.add or n.target == torch.add:
        op_code = op_add
    elif n.target == operator.mul:
        op_code = op_mul

    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar):
            my_output, counter = gen_tvar(counter)
            symbols[n] = my_output
            e1 = symbols[n.args[0]]
            e2 = symbols[n.args[1]]

            return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output)
        else:
            raise NotImplementedError('Method not yet implemented')

    elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)):
        if isinstance(symbols[n.args[0]], TVar):
            # scalar operand does not change the tensor's type
            my_output, counter = gen_tvar(counter)
            symbols[n] = my_output
            e1 = symbols[n.args[0]]
            return [BinConstraintT(my_output, e1, op_eq)], counter
        elif isinstance(symbols[n.args[0]], DVar):
            my_output, counter = gen_dvar(counter)
            symbols[n] = my_output
            e1 = symbols[n.args[0]]

            # we will propagate the runtime value here since this is regular addition
            c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq),
                      BinConstraintD(0, my_output, op_leq)])
            return [c], counter

    elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)):
        if isinstance(symbols[n.args[1]], TVar):
            my_output, counter = gen_tvar(counter)
            symbols[n] = my_output
            e2 = symbols[n.args[1]]
            return [BinConstraintT(my_output, e2, op_eq)], counter
        elif isinstance(symbols[n.args[1]], DVar):
            my_output, counter = gen_dvar(counter)
            symbols[n] = my_output
            e2 = symbols[n.args[1]]

            # we will propagate the runtime value here since this is regular addition
            c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq),
                      BinConstraintD(0, my_output, op_leq)])
            return [c], counter

        else:
            raise NotImplementedError('Method not yet implemented')

    else:
        # TODO generate add constraints for scalar addition
        raise NotImplementedError('Addition not yet implemented')
888
+
889
+
890
@register_inference_rule(torch.flatten)
def flatten_inference_rule(n: Node, symbols, constraints, counter):
    """
    torch.flatten(input, start_dim=0, end_dim=-1): either input and
    output are both Dyn, or flatten constraints are generated for every
    candidate input rank.
    """
    assert isinstance(n.args[0], Node)

    # generate the new variable
    flattened, counter = gen_tvar(counter)
    symbols[n] = flattened

    input = symbols[n.args[0]]

    # set the default start and end dims.
    # BUG FIX: start_dim previously defaulted to 1, but torch.flatten's
    # documented default is 0 (1 is the default for the nn.Flatten
    # module, which is a different rule).
    start_dim = 0
    end_dim = -1

    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]

    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    c1 = BinConstraintT(input, Dyn, op_eq)
    c2 = BinConstraintT(flattened, Dyn, op_eq)
    both_dyn = Conj([c1, c2])

    const = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter)
        const.append(c)

    return [Disj([both_dyn, *const])], counter
922
+
923
+
924
@register_inference_rule(torch.nn.functional.layer_norm)
def layer_norm_functional(n: Node, symbols, constraints, counter):
    """
    Functional layer_norm: shape-preserving, with the input constrained
    against the normalized_shape argument (n.args[1]).
    """
    assert isinstance(n.args[0], Node)
    return gen_layer_norm_constraints(n, n.args[1], symbols, counter)
931
+
932
+
933
@register_inference_rule(torch.nn.LayerNorm)
def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Input and output shapes should be equal.
    Input should be consistent with the normalized_shape
    (taken from the module instance).
    """
    assert isinstance(n.args[0], Node)
    return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter)
941
+
942
+
943
def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
    """
    Shared constraint generation for module and functional layer_norm.

    Either input and output are both Dyn, or for each candidate rank the
    input and output share the same shape and that shape satisfies the
    layer-norm constraints against normalized_shape.
    """
    output, counter = gen_tvar(counter)
    symbols[n] = output
    input = symbols[n.args[0]]

    input_dyn = BinConstraintT(input, Dyn, op_eq)
    output_dyn = BinConstraintT(output, Dyn, op_eq)

    c1 = Conj([input_dyn, output_dyn])

    c2 = []
    for i in range(1, MAX_TENSOR_RANK + 1):
        new_dims_rhs, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(new_dims_rhs)

        c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
                           BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] +
                          add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) +
                          nat_constraints)
        c2.append(c_tensor_i)
    return [Disj([c1, Disj(c2)])], counter
964
+
965
@register_inference_rule(torch.nn.Dropout)
@register_inference_rule(torch.nn.ReLU)
def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Shape-preserving modules (ReLU, Dropout): the output type is equal to
    the input type.
    """
    assert isinstance(n.args[0], Node)
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]
    assert isinstance(in_var, TVar)
    return [BinConstraintT(in_var, out_var, op_eq)], counter
977
+
978
+
979
@register_inference_rule(torch.nn.Linear)
def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    nn.Linear: input and output agree on every dimension except the last;
    a Dyn input yields a Dyn output (see linear_constraints).
    """
    assert isinstance(n.args[0], Node)
    in_f = module_instance.in_features
    out_f = module_instance.out_features
    return linear_constraints(n, in_f, out_f, symbols, counter)
987
+
988
+
989
@register_inference_rule("dim")  # type: ignore[attr-defined]
def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
    """
    Tensor.dim(): the result is the rank of the input tensor, or Dyn when
    the input type is fully dynamic.
    """
    assert isinstance(n.args[0], Node)
    my_dim, counter = gen_dvar(counter)
    symbols[n] = my_dim
    in_var = symbols[n.args[0]]

    # fully dynamic input -> dynamic dim result
    dyn_case = Conj([BinConstraintT(in_var, Dyn, op_eq),
                     BinConstraintD(my_dim, Dyn, op_eq)])

    # concrete input of rank i -> dim() == i, for every candidate rank
    rank_cases = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        rank_cases.append(Conj([BinConstraintT(in_var, TensorType(dims), op_eq),
                                BinConstraintD(my_dim, rank, op_eq)]))

    return [Disj([dyn_case, Disj(rank_cases)])], counter
1009
+
1010
+
1011
@register_inference_rule(torch._C._nn.linear)  # type: ignore[attr-defined]
def torch_linear_inference_rule(n: Node, symbols, constraints, counter):
    """
    Functional linear: the weight (args[1]) is constrained to rank 2;
    its second dimension plays the role of in_features and its first that
    of out_features.
    """
    assert isinstance(n.args[0], Node)
    weight_dims, counter = gen_tensor_dims(2, counter)
    weight_is_rank2 = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq)
    linear_cs, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter)
    return [weight_is_rank2] + linear_cs, counter
1018
+
1019
+
1020
def linear_constraints(n: Node, in_features, out_features, symbols, counter):
    """
    Shared constraint generation for linear layers.

    Either both input and output are Dyn, or they share a rank, with the
    leading dimensions equal and the last dimensions related to
    in_features / out_features (see add_linear_constraints).
    """
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    # case 1: fully dynamic input and output
    both_dyn = Conj([BinConstraintT(in_var, Dyn, op_eq),
                     BinConstraintT(out_var, Dyn, op_eq)])

    # case 2: one disjunct per candidate rank
    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        in_dims, counter = gen_tensor_dims(rank, counter)
        out_dims, counter = gen_tensor_dims(rank, counter)
        per_rank.append(Conj([BinConstraintT(in_var, TensorType(in_dims), op_eq),
                              BinConstraintT(out_var, TensorType(out_dims), op_eq)] +
                             add_linear_constraints(in_dims, out_dims, in_features, out_features) +
                             gen_nat_constraints(in_dims + out_dims)))
    return [Disj([both_dyn, Disj(per_rank)])], counter
1043
+
1044
def add_layer_norm_constraints(input_dim, normalized_dim):
    """
    Pairwise consistency between the trailing input dimensions and the
    normalized shape, e.g. an input [*, 1024, 1024] against a
    normalized_dim of [1024, 1024].

    Args:
        input_dim: input shape of layer norm
        normalized_dim: normalized_shape parameter of the module instance

    Returns a list of constraints; [F()] when normalized_dim is longer
    than input_dim (a pattern mismatch).
    """
    if len(normalized_dim) > len(input_dim):
        return [F()]
    # align from the right and require dimension-wise consistency
    return [BinConstraintD(i, n, op_consistency)
            for i, n in zip(reversed(input_dim), reversed(normalized_dim))]
1063
+
1064
+
1065
def add_linear_constraints(dims1, dims2, in_features, out_features):
    """
    Relate input dims (dims1) to output dims (dims2) for a linear layer:
    all leading dimensions are equal; the last input dimension must be
    consistent with in_features and the last output dimension equals
    out_features.
    """
    assert len(dims1) == len(dims2)
    if not dims1:
        return []
    constraints = [BinConstraintD(d1, d2, op_eq)
                   for d1, d2 in zip(dims1[:-1], dims2[:-1])]
    constraints.append(BinConstraintD(dims1[-1], in_features, op_consistency))
    constraints.append(BinConstraintD(dims2[-1], out_features, op_eq))
    return constraints
1076
+
1077
+
1078
@register_inference_rule(torch.reshape)
def reshape_inference_rule(n: Node, symbols, constraints, counter):
    """
    torch.reshape: the result has the literal target shape (with -1
    mapped to Dyn), and the source must be reshapeable into it.
    """
    assert isinstance(n.args[0], Node)

    reshape_var, counter = gen_tvar(counter)
    symbols[n] = reshape_var

    src_var = symbols[n.args[0]]
    target_shape = n.args[1]
    target_type = TensorType([Dyn if elem == -1 else elem for elem in target_shape])  # type: ignore[union-attr]
    equals_target = BinConstraintT(reshape_var, target_type, op_eq)  # type: ignore[union-attr]
    reshapeable = CanReshape(src_var, target_type)

    return [equals_target, reshapeable], counter
1093
+
1094
+
1095
@register_inference_rule(BatchNorm2d)
def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    BatchNorm2d: the input must match a rank-4 tensor, and the output
    type equals the input type.
    """
    assert isinstance(n.args[0], Node)

    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    # four fresh dimension variables for the NCHW input
    dims = []
    for _ in range(4):
        d, counter = gen_dvar(counter)
        dims.append(d)

    nat_constraints = gen_nat_constraints(dims)

    matches_rank4 = BinConstraintT(in_var, TensorType(dims), op_matching)
    output_eq_input = BinConstraintT(in_var, out_var, op_eq)
    return [matches_rank4, output_eq_input, *nat_constraints], counter
1115
+
1116
+
1117
@register_inference_rule(torch.nn.AdaptiveAvgPool2d)
def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    AdaptiveAvgPool2d: rank-4 input; the output keeps the first two
    dimensions and takes the module's output_size for the last two.
    """
    assert isinstance(n.args[0], Node)

    pool_var, counter = gen_tvar(counter)
    symbols[n] = pool_var
    in_var = symbols[n.args[0]]

    # four fresh dimension variables for the NCHW input
    dims = []
    for _ in range(4):
        d, counter = gen_dvar(counter)
        dims.append(d)
    nat_constraints = gen_nat_constraints(dims)

    matches_rank4 = BinConstraintT(in_var, TensorType(dims), op_matching)
    out_h = module_instance.output_size[0]
    out_w = module_instance.output_size[1]
    output_shape = BinConstraintT(pool_var, TensorType([dims[0], dims[1], out_h, out_w]), op_eq)

    return [matches_rank4, output_shape, *nat_constraints], counter
1136
+
1137
+
1138
@register_inference_rule(Conv2d)
def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Conv2d: rank-4 input whose channel dimension is consistent with
    in_channels; the output type is derived by CalcConv from the module's
    hyper-parameters.
    """
    assert isinstance(n.args[0], Node)

    conv_var, counter = gen_tvar(counter)
    symbols[n] = conv_var
    in_var = symbols[n.args[0]]

    # fresh dimension variables for the NCHW input
    [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)

    # the input must match a rank-4 tensor
    matches_rank4 = BinConstraintT(in_var, TensorType([d1, d2, d3, d4]), op_matching)

    # channel dimension consistent with in_channels
    channels_ok = BinConstraintD(module_instance.in_channels, d2, op_consistency)

    # output shape from the convolution arithmetic
    conv_calc = CalcConv(conv_var, in_var,
                         module_instance.out_channels,
                         module_instance.kernel_size,
                         module_instance.padding,
                         module_instance.stride,
                         module_instance.dilation, [d1, d2, d3, d4])

    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    return [matches_rank4, channels_ok, conv_calc, *nat_constraints], counter
1165
+
1166
+
1167
@register_inference_rule(torch.nn.MaxPool2d)
def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    MaxPool2d: rank-4 input; the output type is derived by CalcMaxPool
    from the module's hyper-parameters.
    """
    assert isinstance(n.args[0], Node)
    pool_var, counter = gen_tvar(counter)
    symbols[n] = pool_var
    in_var = symbols[n.args[0]]

    # fresh dimension variables for the NCHW input
    [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)

    matches_rank4 = BinConstraintT(in_var, TensorType([d1, d2, d3, d4]), op_matching)

    pool_calc = CalcMaxPool(pool_var, in_var, module_instance.kernel_size,
                            module_instance.padding, module_instance.stride,
                            module_instance.dilation, [d1, d2, d3, d4])

    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    return [matches_rank4, pool_calc, *nat_constraints], counter
1185
+
1186
+
1187
class ConstraintGenerator:
    """
    Walks an fx graph and produces a conjunction of gradual-typing
    constraints by dispatching each node to the rule registered for its
    target in _INFERENCE_RULES.
    """

    def __init__(self, traced, graph=None):
        self.traced = traced  # traced or tracer.root
        # snapshot of named parameters, used to type 'get_attr' nodes
        self.traced_params = dict(self.traced.named_parameters())
        self.constraints = []     # accumulated constraints (passed to rules)
        self.symbol_dict = {}     # Node -> type/dim variable
        # fall back to the explicit graph when `traced` is a tracer root
        self.graph = traced.graph if hasattr(traced, 'graph') else graph


    def generate_constraints(self, counter=0):
        """
        Iterate through every node and generate constraints
        Effect: self.constraints will be populated with the final constraints
        """
        graph = self.graph

        all_constraints = []

        for n in graph.nodes:
            (constraints, counter) = self.generate_constraints_node(n, counter)
            all_constraints += constraints

        return Conj(all_constraints), counter

    def generate_constraints_node(self, n: Node, counter):
        """
        Generate constraints the given node:
        Currently supported operations:
        - Reshape
        - Add
        - conv2d
        """

        if n.op == 'placeholder':
            # fresh type variable for the graph input
            x, counter = gen_tvar(counter)
            self.symbol_dict[n] = x

            my_type = n.type

            # anything that is not Dyn or a TensorType annotation is
            # normalized: Parameters get their static shape, the rest Dyn
            if n.type != Dyn and (not isinstance(n.type, TensorType)):
                if n.type == torch.nn.parameter.Parameter:
                    # since we have a parameter, the shape must be static
                    assert 'example_value' in n.meta
                    my_type = TensorType(n.meta['example_value'].size())
                else:
                    my_type = Dyn

            # the declared type is a precision bound on the variable,
            # whose rank is bounded by MAX_TENSOR_RANK
            c1 = BinConstraintT(my_type, x, op_precision)
            c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq)
            return [c1, c2], counter

        elif n.op == 'call_function':
            if n.target in _INFERENCE_RULES:
                return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
            else:
                raise RuntimeError(f'No inference rule registered for target {n.target}!')

        elif n.op == 'call_module':

            # module rules dispatch on the submodule's class
            module_instance = self.traced.get_submodule(n.target)
            if type(module_instance) in _INFERENCE_RULES:
                return _INFERENCE_RULES[type(module_instance)](n,
                                                               module_instance,
                                                               self.symbol_dict,
                                                               self.constraints, counter)
            else:
                raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!')

        elif n.op == 'call_method':
            if n.target in _INFERENCE_RULES:
                return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
            else:
                raise RuntimeError(f'No inference rule registered for target {n.target}!')

        elif n.op == 'get_attr':
            t = self.traced_params.get(n.target, None)

            if isinstance(t, torch.Tensor):
                if len(t.shape) > 0:
                    # equate the attribute's variable to its static shape
                    res = list(t.shape)
                    attr_type = TensorType(res)
                    output, counter = gen_tvar(counter)
                    self.symbol_dict[n] = output
                    return [BinConstraintT(output, attr_type, op_eq)], counter
                else:
                    # scalar?
                    return [], counter
            else:
                # non-tensor attribute: no constraints generated
                return [], counter

        elif n.op == 'output':
            return [], counter

        else:
            raise NotImplementedError(f"Method {n.op} not yet implemented")
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py ADDED
@@ -0,0 +1,1040 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: ignore-errors
2
+ import copy
3
+ import itertools
4
+ from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK
5
+ from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \
6
+ Transpose
7
+ from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound
8
+ from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound
9
+ from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool
10
+ from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape
11
+ from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect
12
+ from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching
13
+ from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq
14
+ from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod
15
+ from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar
16
+ from torch.fx.tensor_type import TensorType, Dyn
17
+ from typing import Callable, Dict, List
18
+
19
+ _TRANSFORMATION_RULES: Dict[Constraint, Callable] = {}
20
+
21
+
22
def register_transformation_rule(call_target):
    """
    Decorator factory: register the decorated function as the
    transformation rule for the given constraint class. A duplicate
    registration for the same class is a programming error.
    """
    def register(fn):
        if call_target in _TRANSFORMATION_RULES:
            raise RuntimeError(f'Transformation rule already registered for {call_target}!')
        _TRANSFORMATION_RULES[call_target] = fn
        return fn
    return register
29
+
30
+
31
def valid_index(index, dims):
    """
    Return T() when *index* is a valid position in *dims*, F() when the
    lookup raises IndexError.
    """
    try:
        dims[index]
    except IndexError:
        return F()
    return T()
40
+
41
+
42
@register_transformation_rule(Transpose)
def transform_transpose(constraint, counter):
    """
    Expand a Transpose constraint (akin to two index-selects): the output
    type is the input type with index1 and index2 swapped, provided both
    indices are valid; otherwise the F() clauses make it unsatisfiable.
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    idx1_ok = valid_index(constraint.index1, dims)
    idx2_ok = valid_index(constraint.index2, dims)
    swapped = copy.deepcopy(dims)
    nat_constraints = gen_nat_constraints(dims)

    if idx1_ok == T() and idx2_ok == T():
        swapped[constraint.index1] = dims[constraint.index2]
        swapped[constraint.index2] = dims[constraint.index1]

    return Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                 *nat_constraints,
                 idx1_ok, idx2_ok,
                 BinConstraintT(constraint.output, TensorType(swapped), op_eq)]), counter
62
+
63
+
64
@register_transformation_rule(IndexSelect)
def transform_index_select(constraint, counter):
    """
    Expand an IndexSelect constraint for the given tensor size.

    When the index is valid, the output type is the input type with the
    indexed dimension replaced by ``constraint.dim_replace``. When it is
    not, ``valid_index`` contributes F() and the clause is unsatisfiable.
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    is_valid_index = valid_index(constraint.index, dims)
    nat_constraints = gen_nat_constraints(dims)

    # Always start from a copy so new_dims is defined on every path.
    # Previously it was bound only when the index was valid, so an
    # out-of-range index raised NameError below instead of yielding an
    # UNSAT clause via is_valid_index == F().
    new_dims = copy.deepcopy(dims)
    if is_valid_index == T():
        new_dims[constraint.index] = constraint.dim_replace

    transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                                   *nat_constraints,
                                   is_valid_index,
                                   BinConstraintT(constraint.output, TensorType(new_dims), op_eq)])
    return transformed_constraint, counter
88
+
89
+
90
@register_transformation_rule(GetItem)
def transform_get_item(constraint, counter):
    """
    Expand a GetItem constraint for a concrete tensor size.

    Equate the input with a fresh tensor of rank ``tensor_size``, require
    the index to be valid, and — when it is — equate the result with the
    selected dimension. The Dyn-input case is handled before this rule.

    Args:
        constraint: GetItem which assumes we are getting an item from a
            tensor (not Dyn)
        counter: variable tracking
    Returns: simplified constraints for GetItem
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    index_ok = valid_index(constraint.index, dims)

    clauses = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
               *nat_constraints,
               index_ok]

    # an invalid index leaves the F() above to make the clause UNSAT
    if index_ok == T():
        clauses.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq))

    return Conj(clauses), counter
122
+
123
def valid_index_tensor(index, dims):
    """
    Return F() when the number of slice instances in *index* exceeds the
    number of dimensions (a type error), otherwise T().
    """
    slice_count = sum(1 for s in index if isinstance(s, slice))
    return F() if slice_count > len(dims) else T()
136
+
137
@register_transformation_rule(GetItemTensor)
def transform_get_item_tensor(constraint, counter):
    """
    Expand GetItemTensor: indexing with a tuple yields a tensor result.
    TODO: we have to check if this is the case for all HF models

    The cases covered here are a tuple whose elements are one of:
    - a slice with default arguments (keeps the corresponding dimension)
    - None (inserts a size-1 dimension at that position, raising the
      rank by 1 per occurrence)

    Any other element raises NotImplementedError. If the resulting rank
    exceeds MAX_TENSOR_RANK the constraint is unsatisfiable.
    """
    assert isinstance(constraint.index_tuple, tuple)

    # fresh dimensions for the input tensor of the expected size
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    # place-holder list of the output rank: each "None" adds a dimension,
    # a full slice does not change the rank
    none_count = constraint.index_tuple.count(None)
    resulting_tensor_dims = (none_count + len(dims)) * [None]

    # mark the inserted size-1 dimensions
    # (removed a dead `dim_index = 0` that was assigned here and never read)
    for i, idx in enumerate(constraint.index_tuple):
        if idx is None:
            resulting_tensor_dims[i] = 1
        elif idx == slice(None, None, None):
            pass
        else:
            raise NotImplementedError('Method not yet implemented')

    # fill the remaining slots with the input dimensions, in order
    dim_index = 0
    for i in range(len(resulting_tensor_dims)):
        if resulting_tensor_dims[i] is None:
            resulting_tensor_dims[i] = dims[dim_index]
            dim_index += 1

    # check if the index is valid
    is_valid_index = valid_index_tensor(constraint.index_tuple, dims)

    # resulting rank out of bounds -> unsatisfiable
    # (was a hard-coded 4; MAX_TENSOR_RANK == 4, spelled via the constant
    # for consistency with the rest of the module)
    if len(resulting_tensor_dims) > MAX_TENSOR_RANK:
        return F(), counter

    constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
                   BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
                   *nat_constraints,
                   is_valid_index]
    return Conj(constraints), counter
197
+
198
+
199
@register_transformation_rule(BinConstraintT)
def generate_binconstraint_t(constraint, counter):
    """
    Transform binary constraints for tensors
    """

    # precision constraints
    if constraint.op == op_precision:
        # Dyn is at least as precise as anything
        if constraint.lhs == Dyn:
            return T(), counter
        elif isinstance(constraint.lhs, TensorType):
            is_fully_static = all(d != Dyn for d in constraint.lhs.__args__)
            if is_fully_static:
                # a fully static type can only be refined to itself
                return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter
            else:
                # replace each lhs dimension with a fresh dim variable and
                # relate them dimension-wise by precision; each fresh dim
                # is at least 1
                new_dims = []

                for _ in range(len(constraint.lhs.__args__)):
                    dim, counter = gen_dvar(counter)
                    new_dims.append(dim)

                new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for
                                       new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \
                    [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \
                    [BinConstraintD(1, new_dim, op_leq) for
                     new_dim in new_dims]
                return Conj(new_dim_constraints), counter
        # NOTE(review): an op_precision constraint whose lhs is neither Dyn
        # nor a TensorType falls through and implicitly returns None —
        # presumably unreachable, but worth confirming.

    # matching
    elif constraint.op == op_matching:
        # rhs must be a rank-4 tensor; lhs either is Dyn with all four
        # dims Dyn, or equals the rank-4 tensor exactly
        assert isinstance(constraint.rhs, TensorType)
        d1 = constraint.rhs.__args__[0]
        d2 = constraint.rhs.__args__[1]
        d3 = constraint.rhs.__args__[2]
        d4 = constraint.rhs.__args__[3]

        conj = [BinConstraintT(constraint.lhs, Dyn, op_eq),
                BinConstraintD(d1, Dyn, op_eq),
                BinConstraintD(d2, Dyn, op_eq),
                BinConstraintD(d3, Dyn, op_eq),
                BinConstraintD(d4, Dyn, op_eq)]
        return Disj([Conj(conj),
                     BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter

    elif constraint.op == op_consistency:
        # either side may be Dyn, or both sides agree for some rank 1..4
        c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)])
        [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter)

        return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter

    elif constraint.op == op_leq:
        # rank bound: lhs is Dyn or a tensor of each rank up to rhs
        assert isinstance(constraint.rhs, int)
        disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)]
        for i in range(1, constraint.rhs + 1):
            dims = []
            for j in range(1, i + 1):
                dim_var, counter = gen_dvar(counter)
                dims.append(dim_var)
            disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))
        return Disj(disj), counter
    else:
        # any other tensor constraint is already in solvable form
        return constraint, counter
261
+
262
+
263
@register_transformation_rule(BinConstraintD)
def generate_binconstraint_d(constraint, counter):
    """
    Transform binary constraints for dimensions
    """
    if constraint.op == op_precision:
        # a concrete int can only be refined to itself; Dyn refines to anything
        if isinstance(constraint.lhs, int):
            return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter
        elif constraint.lhs == Dyn:
            return T(), counter
        # NOTE(review): a precision lhs that is neither int nor Dyn falls
        # through and implicitly returns None — presumably unreachable,
        # but verify against the constraint generator.

    elif constraint.op == op_consistency:
        # consistent when equal or when either side is Dyn
        return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq),
                     BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter

    else:
        # any other dimension constraint is already in solvable form
        return constraint, counter
280
+
281
+
282
@register_transformation_rule(Conj)
def generate_conj(constraint, counter):
    """
    Transform a conjunction by transforming each conjunct in turn.
    """
    transformed = []
    for conjunct in constraint.conjucts:
        simplified, counter = transform_constraint(conjunct, counter)
        transformed.append(simplified)
    return Conj(transformed), counter
292
+
293
+
294
@register_transformation_rule(Disj)
def generate_disj(constraint, counter):
    """
    Transform a disjunction by transforming each disjunct in turn.
    """
    transformed = []
    for disjunct in constraint.disjuncts:
        simplified, counter = transform_constraint(disjunct, counter)
        transformed.append(simplified)
    return Disj(transformed), counter
304
+
305
+
306
@register_transformation_rule(TGreatestUpperBound)
def generate_gub(constraint, counter):
    """
    Transform a tensor greatest-upper-bound constraint: either one of the
    operands is Dyn (and the result is Dyn), or the result is the
    dimension-wise GUB for one of the candidate ranks.
    """
    dyn_case = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq),
                           BinConstraintT(constraint.rhs2, Dyn, op_eq)]),
                     BinConstraintT(constraint.res, Dyn, op_eq)])

    [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter)

    return Disj([dyn_case, c2, c3, c4, c5]), counter
318
+
319
+
320
@register_transformation_rule(DGreatestUpperBound)
def generate_d_gub(constraint, counter):
    """
    Transform a dimension greatest-upper-bound into equalities: if either
    side is Dyn the result is the other side; if both sides are equal the
    result is that value.
    """
    rhs1, rhs2, res = constraint.rhs1, constraint.rhs2, constraint.res
    lhs_is_dyn = Conj([BinConstraintD(rhs1, Dyn, op_eq), BinConstraintD(res, rhs2, op_eq)])
    rhs_is_dyn = Conj([BinConstraintD(rhs2, Dyn, op_eq), BinConstraintD(res, rhs1, op_eq)])
    sides_equal = Conj([BinConstraintD(rhs2, rhs1, op_eq), BinConstraintD(res, rhs1, op_eq)])
    return Disj([lhs_is_dyn, rhs_is_dyn, sides_equal]), counter
329
+
330
+
331
@register_transformation_rule(CalcConv)
def generate_calc_conv(constraint, counter):
    """
    Expand a convolution-output constraint: the result is a rank-4 tensor
    whose channel dimension equals c_out (and is static), whose batch
    dimension matches the input, and whose last two dimensions follow the
    convolution arithmetic (calc_last_two_dims). All dims are >= 0.
    """
    d, counter = gen_tensor_dims(4, counter)
    conv_result = TensorType([d[0], d[1], d[2], d[3]])

    # the convolution result is a tensor of size 4
    c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq)

    # output channels: equal to c_out and not dynamic
    c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)])

    # the batch dimension is carried over from the input
    c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)

    # spatial dimensions from the convolution arithmetic
    c4, c5 = calc_last_two_dims(constraint, d)

    leq_constraints = Conj([BinConstraintD(0, dim, op_leq) for dim in d])

    return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
353
+
354
+
355
@register_transformation_rule(CalcMaxPool)
def generate_calc_maxpool(constraint, counter):
    """
    Expand a maxpool-output constraint: the result is a rank-4 tensor
    whose first two dimensions match the input and whose last two follow
    the pooling arithmetic (calc_last_two_dims). All dims are >= 0.
    """
    d, counter = gen_tensor_dims(4, counter)
    maxpool_result = TensorType([d[0], d[1], d[2], d[3]])

    # the maxpool result is a tensor of size 4
    c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq)

    # batch and channel dimensions are carried over from the input
    c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq)
    c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)

    # spatial dimensions from the pooling arithmetic
    c4, c5 = calc_last_two_dims(constraint, d)

    leq_constraints = Conj([BinConstraintD(0, dim, op_leq) for dim in d])

    return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter
377
+
378
+
379
@register_transformation_rule(CalcProduct)
def generate_calc_product(constraint, counter):
    """
    Transform flatten constraints

    The dimensions in [start, end) are collapsed into a single dimension:
    Dyn if any collapsed dimension may be Dyn, otherwise the product of
    the collapsed dimensions. One disjunct is emitted per int/Dyn
    assignment of the collapsed dimensions.
    """
    start = constraint.start
    end = constraint.end
    dims = constraint.dims_to_flatten
    flattened = constraint.flattened
    n = len(constraint.dims_to_flatten)

    # this will be evaluated right here
    boundary_check = (0 <= start and start < end and end <= n)

    c_boundary = T() if boundary_check else F()

    # untouched prefix/suffix and the dimensions to collapse
    lhs = dims[0:start]
    rhs = dims[end:]
    mid = dims[start:end]

    all_possibilities = generate_all_int_dyn_dim_possibilities(mid)

    all_constraints = []

    for p in all_possibilities:
        p = list(p)
        # this tells us there is a dynamic variable
        # NOTE(review): the generator variable below shadows the outer
        # `constraint` parameter (scoped to the genexp only)
        contains_dyn = not all(constraint.op == op_neq for constraint in p)
        if contains_dyn:
            # a possibly-Dyn collapsed dimension makes the result Dyn
            mid_var = [Dyn]
            total_constraints = lhs + mid_var + rhs
            if len(total_constraints) > 4:
                all_constraints.append(F())
            else:
                all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq)] + p))
        else:
            # all collapsed dims are static: the result is their product
            new_var, counter = gen_dvar(counter)
            mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)])
            mid_var = [new_var]
            total_constraints = lhs + mid_var + rhs
            if len(total_constraints) > 4:
                all_constraints.append(F())
            else:
                all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p))

    return Conj([Disj(all_constraints), c_boundary]), counter
425
+
426
+
427
@register_transformation_rule(CanReshape)
def generate_reshape(constraint, counter):
    """
    Transform reshape constraints.

    For each possible source rank (1..4), require the source's total
    element count to be compatible with the target shape: equal to the
    target's product when the target is fully static, or divisible by the
    product of the target's static dimensions when the target contains
    exactly one Dyn. Dimensions that are Dyn on the source side are always
    compatible.
    """
    d, counter = gen_tensor_dims(4, counter)

    d1 = d[0]
    d2 = d[1]
    d3 = d[2]
    d4 = d[3]

    target = constraint.target.__args__

    is_fully_static = all(dim != Dyn for dim in target)

    # dynamic tensor
    c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)
    c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq)
    c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq)
    c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq)
    c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq)

    d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq)
    d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq)

    d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq)
    d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq)

    d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
    d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq)

    # BUG FIX: these previously tested d3 instead of d4, so the rank-4
    # case never actually constrained the fourth dimension
    d4_eq_dyn = BinConstraintD(d4, Dyn, op_eq)
    d4_neq_dyn = BinConstraintD(d4, Dyn, op_neq)

    nat_d1 = BinConstraintD(0, d1, op_leq)
    nat_d2 = BinConstraintD(0, d2, op_leq)
    nat_d3 = BinConstraintD(0, d3, op_leq)
    nat_d4 = BinConstraintD(0, d4, op_leq)

    if is_fully_static:
        # size 1 tensor
        c3_tensor1 = Disj([d1_eq_dyn,
                           (Conj([d1_neq_dyn,
                                  BinConstraintD(d1, Prod(target), op_eq)]))])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # size 2 tensor
        all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)])

        # size 3 tensor
        all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)])

        # size 4 tensor
        all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)])

        return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
                     nat_d1, nat_d2, nat_d3, nat_d4]), counter

    # then there must be exactly one occurrence of dyn
    else:
        # the static part of the target shape
        new_target = []

        for dim in target:
            if dim != Dyn:
                new_target.append(dim)

        # tensor 1
        c3_tensor1 = Disj([d1_eq_dyn,
                           (Conj([d1_neq_dyn,
                                  is_dim_div_by_target(new_target, d1)]))])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # tensor 2
        c21 = Disj([d1_eq_dyn, d2_eq_dyn])
        c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))])
        all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])])

        # tensor 3
        c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn])
        c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))])
        all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])])

        # tensor 4
        c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn])
        c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))])
        all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])])

        return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
                     nat_d1, nat_d2, nat_d3, nat_d4]), counter
517
+
518
+
519
@register_transformation_rule(ApplyBroadcasting)
def generate_broadcasting(constraint, counter):
    """
    Expand an ApplyBroadcasting constraint into explicit broadcasting
    possibilities for every tensor rank from 1 to 4, plus the cases
    where either input is fully dynamic.
    """
    e11, e12 = constraint.res1, constraint.res2
    e1, e2 = constraint.input1, constraint.input2

    # If either input is Dyn, broadcasting imposes nothing beyond
    # propagating the inputs directly into the results.
    e1_equal_e11 = BinConstraintT(e1, e11, op_eq)
    e2_equal_e12 = BinConstraintT(e2, e12, op_eq)
    e1_dyn_constraint = Conj([BinConstraintT(e1, Dyn, op_eq), e1_equal_e11, e2_equal_e12])
    e2_dyn_constraint = Conj([BinConstraintT(e2, Dyn, op_eq), e1_equal_e11, e2_equal_e12])

    # rank-1 tensors cannot be padded, so only the no-padding constraint matters
    rank1_no_padding, _, _, nat_dims_1, counter = \
        gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter)

    options = [e1_dyn_constraint, e2_dyn_constraint, rank1_no_padding]
    all_nat_dims = list(nat_dims_1)

    # ranks 2-4 additionally admit padding of either argument
    for rank in (2, 3, 4):
        no_padding, padding_arg1, padding_arg2, nat_dims, counter = \
            gen_broadcasting_constraints(e1, e2, e11, e12, rank, counter)
        options.extend([no_padding, padding_arg1, padding_arg2])
        all_nat_dims.extend(nat_dims)

    return Conj([Disj(options)] + all_nat_dims), counter
574
+
575
+
576
def transform_constraint(constraint: Constraint, counter: int):
    """
    Transforms a constraint into a simpler constraint.
    Ex: precision and consistency are transformed to equality
    Args:
        constraint: constraint to be transformed
        counter: for variable tracking

    Returns: Constraint

    """
    # dispatch on the constraint's concrete type; constraints without a
    # registered rule are already in their simplest form
    rule = _TRANSFORMATION_RULES.get(type(constraint))
    if rule is None:
        return constraint, counter
    return rule(constraint, counter)
592
+
593
+
594
+
595
+
596
def calc_last_two_dims(constraint, d: List[DVar]):
    """
    Generates constraints for the last two dimensions of a convolution or a maxpool output
    Args:
        constraint: CalcConv or CalcMaxPool
        d: The list of output dimensions

    Returns: Constraints for calculating the last two dimensions of the output

    """

    assert isinstance(constraint, (CalcConv, CalcMaxPool))

    # b3/b4 are the input's height/width dimensions (positions 2 and 3 of
    # the matching constraint)
    b3 = constraint.matching_constraint[2]
    b4 = constraint.matching_constraint[3]

    # output dim is Dyn exactly when the corresponding input dim is Dyn
    b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)])
    b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)])

    d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)])
    d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)])

    # transform parameters into tuples in case they are not already
    padding = (constraint.padding, constraint.padding) \
        if isinstance(constraint.padding, int) else constraint.padding
    kernel = (constraint.kernel, constraint.kernel) \
        if isinstance(constraint.kernel, int) else constraint.kernel
    stride = (constraint.stride, constraint.stride) \
        if isinstance(constraint.stride, int) else constraint.stride
    dilation = (constraint.dilation, constraint.dilation) \
        if isinstance(constraint.dilation, int) else constraint.dilation

    # encode the standard output-size arithmetic for dim 2:
    # (b3 + 2*padding - dilation*(kernel - 1) - 1) / stride + 1
    f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add)
    f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul)
    f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div)
    f4 = BinConstraintD(f3, 1, op_add)

    # either both dims are Dyn, or both are static and the formula holds
    c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])])

    # same arithmetic for dim 3, using the second element of each parameter tuple
    f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add)
    f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul)
    f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div)
    f44 = BinConstraintD(f33, 1, op_add)

    c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])])

    return c4, c5
643
+
644
+
645
def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
    """
    Generate all possibilities of being equal or not equal to dyn for my_list
    Args:
        my_list: List of tensor dimensions

    Returns: A list of tuples of constraints. Each tuple corresponds to
        one possibility about the values of the dimension variables
    """
    # each dimension can either be Dyn or not Dyn; take the cross product
    # of that binary choice over all dimensions
    per_dim_choices = [
        [BinConstraintD(dim, Dyn, op_eq), BinConstraintD(dim, Dyn, op_neq)]
        for dim in my_list
    ]
    return list(itertools.product(*per_dim_choices))
663
+
664
+
665
def is_target_div_by_dim(target: List[int], dim: List[DVar]):
    """
    Generate a constraint checking that the product of the target
    dimensions is divisible by the given dimension expression.
    Args:
        target: Target dimensions
        dim: Input dimension expression

    Returns: Constraint asserting Prod(target) % dim == 0

    """
    remainder = BinConstraintD(Prod(target), dim, op_mod)
    return BinConstraintD(remainder, 0, op_eq)
676
+
677
+
678
def is_dim_div_by_target(target: List[int], dim: List[DVar]):
    """
    Generate a constraint checking that the given dimension expression is
    divisible by the product of the target dimensions.
    Args:
        target: Target dimensions
        dim: Input dimension expression

    Returns: Constraint asserting dim % Prod(target) == 0

    """
    remainder = BinConstraintD(dim, Prod(target), op_mod)
    return BinConstraintD(remainder, 0, op_eq)
689
+
690
+
691
def gen_all_reshape_possibilities(list_of_dims, target):
    """
    Consider all possibilities what the input dimensions could be (number or dynamic)
    Then generate the appropriate constraints using multiplication or mod depending on the possibility
    The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn
    for the input. Target is fixed because at most one dimension could be dyn.
    We have different cases for this.

    Args:
        list_of_dims: The input list of dimensions
        target: The tensor we want to reshape to

    Returns: A disjunction of transformed reshape constraints

    """
    all_constraints = []

    for p in generate_all_int_dyn_dim_possibilities(list_of_dims):
        p = list(p)

        # collect the dimensions this possibility asserts are NOT Dyn
        to_multiply = []
        for c in p:
            assert isinstance(c, BinConstraintD)
            if c.op == op_neq:
                to_multiply.append(c.lhs)

        if not to_multiply:
            # every dimension is Dyn: nothing further to check
            extra = []
        elif len(to_multiply) < len(list_of_dims):
            # some dims are Dyn: the known ones must divide the target product
            extra = [is_target_div_by_dim(target, Prod(to_multiply))]
        else:
            # all dims are known: the two products must match exactly
            extra = [BinConstraintD(Prod(list_of_dims), Prod(target), op_eq)]

        all_constraints.append(Conj(p + extra))

    return Disj(all_constraints)
730
+
731
+
732
def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False):
    """
    Apply broadcasting to the 'index' dimension of tensor_input1.
    Args:
        tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1
        tensor_input2: represents the second input
        res1: broadcasted result 1
        res2: broadcasted result 2
        index: the index to broadcast
        padding: If padding was used, then tensor_input1[index] does not exist

    Returns: Constraints broadcasting dimension 'index' of input 1

    """
    # a None entry can only occur when this dimension was created by padding
    if tensor_input1[index] is None:
        assert padding

    # in all cases the two result dimensions agree with each other and
    # with input 2's dimension
    propagate = [BinConstraintD(res1[index], res2[index], op_eq),
                 BinConstraintD(res2[index], tensor_input2[index], op_eq)]

    if padding:
        # the dimension does not exist on input 1, so we cannot constrain it
        return Conj(propagate)

    # otherwise input 1 must be broadcastable here, i.e. equal to 1
    return Conj([BinConstraintD(tensor_input1[index], 1, op_eq)] + propagate)
760
+
761
+
762
def apply_padding(e1_var: TVar,
                  e11: BinConstraintT,
                  e2: BinConstraintT,
                  e12: BinConstraintT,
                  d2: List[DVar],
                  d11: List[DVar],
                  d12: List[DVar],
                  counter: int):
    """
    We are considering the possibility where one input has less dimensions than
    another input, so we apply padding to the broadcasted results

    Args:
        e1_var: Variable representing the first input where padding will be
        e11: constraint of the form e11 = Tensortype[d1, ..., dn]
        e2: constraint of the form e2 = Tensortype[d1, ..., dn]
        e12: constraint of the form e11 = Tensortype[d1, ..., dn]
        d2: Tensor variables for the second input
        d11: Tensor variables for the broadcasted first input
        d12: Tensor variables for the broadcasted second input
        counter: variable tracking

    Returns: A new constraint whose goal is to apply padding to the broadcasted result

    """

    res = []

    # pad the shorter input with None so we can pass it to the broadcasting helper function
    # i ranges over every possible (shorter) rank for input 1, from 1 up to len(d2) - 1
    for i in range(1, len(d2)):

        # fresh dimension variables for the rank-i first input
        d1, counter = gen_tensor_dims(i, counter)

        nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)

        e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)

        # None placeholders stand in for the missing leading dimensions
        simulate_padding = [None] * (len(d2) - i)

        assert len(simulate_padding + d1) == len(d2)

        broadcast_padding = []

        # for every padding size, we also consider broadcasting
        # (the padded positions broadcast against input 2 unconditionally)
        for j in range(len(d2) - i):
            broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True))

        # we consider the possibilities for broadcasting for every dimension. Since we already
        # padded d1, we do not consider it while broadcasting
        # (only the trailing len(d1) positions of d2/d11/d12 align with d1)
        all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1,
                                                                                           d2[(len(d2) - i):],
                                                                                           d11[(len(d2) - i):],
                                                                                           d12[(len(d2) - i):])
        # combine all constraints into a conjunction
        c = Conj([e1, e11, e2, e12,
                  *broadcast_padding,
                  all_broadcasting_possibilities,
                  *nat_constraints
                  ])
        res.append(c)

    # any one of the padded ranks may hold
    return Disj(res), counter
824
+
825
+
826
def no_broadcast_dim_with_index(d1: List[DVar],
                                d2: List[DVar],
                                d3: List[DVar],
                                d4: List[DVar],
                                i: int):
    """
    Args:
        d1: input 1
        d2: input 2
        d3: simulated broadcasting for input 1
        d4: simulated broadcasting for input 2
        i: the rank of the resulting tensor addition

    Returns: Constraints for when no broadcasting occurs
    """
    # no broadcasting means either both dims are 1 or neither is
    both_one = Conj([BinConstraintD(d1[i], 1, op_eq),
                     BinConstraintD(d2[i], 1, op_eq)])
    neither_one = Conj([BinConstraintD(d1[i], 1, op_neq),
                        BinConstraintD(d2[i], 1, op_neq)])

    # and the results simply mirror the inputs
    return Conj([Disj([both_one, neither_one]),
                 BinConstraintD(d1[i], d3[i], op_eq),
                 BinConstraintD(d2[i], d4[i], op_eq)])
851
+
852
+
853
+
854
def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int):
    """
    Generate lists of DVar to represent tensor dimensions
    Args:
        num_tensors: the required number of tensors
        dim_size: the number of dimensions for each tensor
        counter: variable tracking

    Returns: A list of a list of tensor dimensions

    """
    all_dims = []
    # thread the counter through each generation so variables stay unique
    for _ in range(num_tensors):
        new_dims, counter = gen_tensor_dims(dim_size, counter)
        all_dims.append(new_dims)
    return all_dims, counter
872
+
873
+
874
def create_equality_constraints_for_broadcasting(e1: TVar,
                                                 e2: TVar,
                                                 e11: TVar,
                                                 e12: TVar,
                                                 d1: List[DVar],
                                                 d2: List[DVar],
                                                 d11: List[DVar],
                                                 d12: List[DVar]):
    """
    Create equality constraints for when no broadcasting occurs
    Args:
        e1: Input 1
        e2: Input 2
        e11: Broadcasted input 1
        e12: Broadcasted input 2
        d1: Variables that store dimensions for e1
        d2: Variables that store dimensions for e2
        d11: Variables that store dimensions for e11
        d12: Variables that store dimensions for e22

    Returns: Four equality constraints

    """
    # bind each tensor variable to the tensor type built from its dimensions
    return [BinConstraintT(e1, TensorType(d1), op_eq),
            BinConstraintT(e11, TensorType(d11), op_eq),
            BinConstraintT(e2, TensorType(d2), op_eq),
            BinConstraintT(e12, TensorType(d12), op_eq)]
903
+
904
+
905
def gen_consistency_constraints(constraint: Constraint, counter: int):
    """
    Args:
        constraint: Consistency constraint on tensors
        counter: for variable tracking

    Returns: Equality and consistency constraints on dimensions

    """
    all_constraints = []

    # one case per possible tensor rank
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims_lhs, counter = gen_tensor_dims(rank, counter)
        dims_rhs, counter = gen_tensor_dims(rank, counter)

        nat_constraints = gen_nat_constraints(dims_lhs + dims_rhs)

        equalities = [BinConstraintT(constraint.lhs, TensorType(dims_lhs), op_eq),
                      BinConstraintT(constraint.rhs, TensorType(dims_rhs), op_eq)]

        # consistency must hold dimension-wise
        pairwise = [BinConstraintD(a, b, op_consistency)
                    for a, b in zip(dims_lhs, dims_rhs)]

        all_constraints.append(Conj(equalities + pairwise + nat_constraints))

    return all_constraints, counter
931
+
932
+
933
def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
    """
    Args:
        constraint: Greatest upper bound on tensors
        counter: variable tracking

    Returns: A set of equality constraints and DGreatestUpperBound constraints

    """
    all_constraints = []

    # one case per possible tensor rank
    # (the inner loop previously shadowed this loop variable `i`; renamed
    # and replaced the index loop with zip)
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims1, counter = gen_tensor_dims(rank, counter)
        c1tensor = TensorType(dims1)

        dims2, counter = gen_tensor_dims(rank, counter)
        c2tensor = TensorType(dims2)

        dims3, counter = gen_tensor_dims(rank, counter)
        c3tensor = TensorType(dims3)

        # bind the operands and the result to tensor types of this rank
        c = [BinConstraintT(constraint.rhs1, c1tensor, op_eq),
             BinConstraintT(constraint.rhs2, c2tensor, op_eq),
             BinConstraintT(constraint.res, c3tensor, op_eq)] + \
            gen_nat_constraints(dims1 + dims2 + dims3)

        assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__)
        # the greatest upper bound applies dimension-wise
        for res_dim, dim1, dim2 in zip(c3tensor.__args__,
                                       c1tensor.__args__,
                                       c2tensor.__args__):
            c.append(DGreatestUpperBound(res_dim, dim1, dim2))

        all_constraints.append(Conj(c))
    return all_constraints, counter
969
+
970
+
971
def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]):
    """
    Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension.
    We look at all combinations for all dimensions in d1 and d2
    Args:
        d1: input1 dimensions
        d2: input2 dimensions
        d11: broadcasted input1 dimensions
        d12: broadcasted input2 dimensions

    Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions

    """
    per_dim_constraints = []

    # every dimension independently either broadcasts input 1, broadcasts
    # input 2, or does not broadcast at all
    for i in range(len(d1)):
        broadcast_first = broadcast_dim(d1, d2, d11, d12, i)
        broadcast_second = broadcast_dim(d2, d1, d12, d11, i)
        no_broadcast = no_broadcast_dim_with_index(d1, d2, d11, d12, i)
        per_dim_constraints.append(Disj([broadcast_first, broadcast_second, no_broadcast]))

    return Conj(per_dim_constraints)
997
+
998
+
999
def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int):
    """
    Simulates broadcasting on e1 and e2 and returns the results
    respectively in e11 and e12. Because of gradual types,
    e1 and e2 may not be equal. Similarly, e11 and e12 may not
    be equal. e11 and e12 should be guaranteed to be consistent
    as they represent the shapes of the tensors to be added after
    broadcasting.
    Args:
        e1: TVar representing the type of input 1
        e2: TVar representing the type of input 2
        e11: TVar representing the representing broadcasted input 1
        e12: TVar representing the representing broadcasted input 2
        i: The rank of the resulting type of addition
        counter: for variable tracking

    Returns: Simplified broadcasting constraints

    """
    # four dimension lists: inputs 1 and 2, broadcasted results 1 and 2
    dims, counter = gen_lists_of_dims(4, i, counter)
    d1, d2, d3, d4 = dims
    nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims)))

    tensor_eqs = create_equality_constraints_for_broadcasting(e1, e2, e11, e12,
                                                              d1, d2, d3, d4)
    e1_tensor, e11_tensor, e2_tensor, e12_tensor = tensor_eqs

    # without padding, broadcast all possibilities for tensors of size i
    no_padding = Conj(tensor_eqs +
                      [generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)])

    # with padding, broadcast all possibilities for tensors of size i
    padding_arg1, counter = \
        apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter)

    padding_arg2, counter = \
        apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter)

    return no_padding, padding_arg1, padding_arg2, nat_dims_i, counter
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Operator tokens used to tag constraint nodes throughout the
# migrate_gradual_types constraint language.
op_add = '+'   # arithmetic addition
op_sub = '-'   # arithmetic subtraction
op_mul = '*'   # arithmetic multiplication
op_div = '/'   # arithmetic division
op_eq = '='    # equality
op_neq = '!='  # inequality
op_imp = '=>'  # implication
op_matching = '\u22b3'  # (contains)
op_consistency = '~'  # gradual-typing consistency relation
op_precision = '\u2291'  # (square image of or equal to)
op_leq = '\u2264'  # less-than or equal to
op_lt = '<'    # strictly less than
op_gt = '>'    # strictly greater than
op_mod = '%'   # modulo (used for divisibility constraints)
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr
3
+ from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar
4
+ from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim
5
+ from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator
6
+ from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint
7
+ from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt
8
+ from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod
9
+ from torch.fx.tensor_type import TensorType, Dyn
10
+
11
+ try:
12
+ import z3 # type: ignore[import]
13
+ from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D
14
+ HAS_Z3 = True
15
+
16
    def transform_to_z3(constraint, counter, dimension_dict):
        """
        Recursively lower a constraint into a z3 boolean expression.

        Dimensions are encoded as pairs D(flag, value) where flag == 0
        means the dimension is Dyn and flag == 1 means it is static
        (see transform_dimension).

        Args:
            constraint: Conj/Disj/T/F/BinConstraintT/BinConstraintD node
            counter: variable tracking for fresh z3 variables
            dimension_dict: maps dimension-variable names to the counter
                value used for their is-static flag

        Returns: a pair (z3 expression, updated counter)

        Raises: NotImplementedError for constraint forms or operators
            with no z3 lowering
        """
        if isinstance(constraint, Conj):
            conjuncts = []
            for c in constraint.conjucts:
                new_c, counter = transform_to_z3(c, counter, dimension_dict)
                conjuncts.append(new_c)
            return z3.And(conjuncts), counter

        elif isinstance(constraint, Disj):
            disjuncts = []
            for c in constraint.disjuncts:
                new_c, counter = transform_to_z3(c, counter, dimension_dict)
                disjuncts.append(new_c)
            return z3.Or(disjuncts), counter

        elif isinstance(constraint, T):
            return True, counter

        elif isinstance(constraint, F):
            return False, counter

        elif isinstance(constraint, BinConstraintT):
            # only tensor equality is currently lowered
            if constraint.op == op_eq:
                lhs, counter = transform_var(constraint.lhs, counter, dimension_dict)
                rhs, counter = transform_var(constraint.rhs, counter, dimension_dict)
                return (lhs == rhs), counter

            else:
                raise NotImplementedError('Method not yet implemented')

        elif isinstance(constraint, BinConstraintD):
            if constraint.op == op_eq:

                # boolean-variable binding: BVar = <boolean expression>
                if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs):
                    transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict)
                    transformed_lhs = z3.Bool(constraint.lhs.c)
                    return transformed_lhs == transformed_rhs, counter

                elif is_dim(constraint.lhs) and is_dim(constraint.rhs):
                    # with dimension transformations we consider the encoding
                    lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
                    rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
                    return lhs == rhs, counter

                else:
                    # then we have an algebraic expression which means that we disregard the
                    # first element of the encoding
                    lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
                    rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
                    return lhs == rhs, counter

            # The assumption here is that the LHS and RHS must be dimensions
            elif constraint.op == op_neq:
                assert is_dim(constraint.lhs)
                assert is_dim(constraint.rhs)
                lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict)
                rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict)
                if constraint.rhs == Dyn or constraint.lhs == Dyn:
                    # "!= Dyn" means the other side's is-static flag is 1
                    if constraint.rhs == Dyn:
                        return lhs.arg(0) == 1, counter
                    elif constraint.lhs == Dyn:
                        return rhs.arg(0) == 1, counter

                # if one of the instances is a number
                elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int):
                    # the variable side differs from a concrete number when it
                    # is Dyn (flag 0) or static with a different value
                    if isinstance(constraint.lhs, int):
                        return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter

                    elif isinstance(constraint.rhs, int):
                        return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter

                else:
                    # two variables differ when exactly one is Dyn, or both
                    # are static with different values
                    return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]),
                                  z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]),
                                  z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter


            elif constraint.op == op_leq:
                # if the dimensions are not dyn, this will come into effect
                # there would have been another constraint specifying if a given dimension
                # is dyn or not
                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
                lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
                rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
                return lhs <= rhs, counter

            elif constraint.op == op_gt:
                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
                lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
                rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
                return lhs > rhs, counter

            elif constraint.op == op_lt:
                assert is_dim(constraint.lhs) and is_dim(constraint.rhs)
                lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict)
                rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict)
                return lhs < rhs, counter

            else:
                raise NotImplementedError('operation not yet implemented')

        else:
            raise NotImplementedError('Operation not yet implemented')
119
+
120
+
121
+ def transform_var(tensor, counter, dimension_dict):
122
+ """
123
+ Transforms tensor variables to a format understood by z3
124
+ Args:
125
+ tensor: Tensor variable or a tensor type potentially with variable dimensions
126
+ Returns: Transformed variable to a z3 format
127
+
128
+ """
129
+ if isinstance(tensor, TensorType):
130
+ res = []
131
+ for t in tensor.__args__:
132
+ transformed, counter = transform_dimension(t, counter, dimension_dict)
133
+ res.append(transformed)
134
+
135
+ assert len(res) <= 4
136
+ if len(tensor.__args__) == 1:
137
+ return tensor_type.tensor1(res[0]), counter
138
+ elif len(tensor.__args__) == 2:
139
+ return tensor_type.tensor2(res[0], res[1]), counter
140
+ elif len(tensor.__args__) == 3:
141
+ return tensor_type.tensor3(res[0], res[1], res[2]), counter
142
+ elif len(tensor.__args__) == 4:
143
+ return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter
144
+
145
+ elif tensor == Dyn:
146
+ return z3_dyn, counter
147
+
148
+ elif isinstance(tensor, TVar):
149
+ return z3.Const(tensor.tvar, tensor_type), counter
150
+
151
+ def transform_dimension(dimension, counter, dimension_dict):
152
+ """
153
+ Takes a dimension variable or a number and transforms it to a tuple
154
+ according to our scheme
155
+ Args:
156
+ dimension: The dimension to be transformed
157
+ counter: variable tracking
158
+
159
+ Returns: tuple and the current counter
160
+
161
+ """
162
+ if dimension == Dyn:
163
+ counter += 1
164
+ return D(0, z3.Int(counter)), counter
165
+ elif isinstance(dimension, int):
166
+ return D(1, dimension), counter
167
+ elif isinstance(dimension, DVar):
168
+ if dimension.c in dimension_dict:
169
+ return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter
170
+ else:
171
+ counter += 1
172
+ dimension_dict[dimension.c] = counter
173
+ return D(z3.Int(counter), z3.Int(dimension.c)), counter
174
+
175
+
176
+ def transform_algebraic_expression(expr, counter, dimension_dict):
177
+ """
178
+ Transforms an algebraic expression to z3 format
179
+ Args:
180
+ expr: An expression is either a dimension variable or an algebraic-expression
181
+
182
+
183
+ Returns: the transformed expression
184
+
185
+ """
186
+ assert is_algebraic_expression(expr) or is_dim(expr)
187
+
188
+ if is_dim(expr):
189
+ transformed, counter = transform_dimension(expr, counter, dimension_dict)
190
+ return transformed.arg(1), counter
191
+
192
+ elif isinstance(expr, Prod):
193
+
194
+ dims = []
195
+ for dim in expr.products:
196
+ assert is_dim(dim)
197
+ d, counter = transform_dimension(dim, counter, dimension_dict)
198
+ dims.append(d.arg(1))
199
+ return z3.Product(dims), counter
200
+
201
+ elif is_algebraic_expression(expr):
202
+
203
+ lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict)
204
+ rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict)
205
+
206
+ if expr.op == op_sub:
207
+ c = lhs - rhs
208
+
209
+ elif expr.op == op_add:
210
+ c = lhs + rhs
211
+
212
+ elif expr.op == op_div:
213
+ c = lhs / rhs
214
+
215
+ elif expr.op == op_mul:
216
+ c = lhs * rhs
217
+
218
+ elif expr.op == op_mod:
219
+ c = lhs % rhs
220
+
221
+ else:
222
+ raise NotImplementedError('operation not yet implemented')
223
+
224
+ return c, counter
225
+
226
+ else:
227
+ raise RuntimeError
228
+
229
+
230
+ def transform_all_constraints(traced, counter=0):
231
+ """
232
+ Given a trace, generates constraints and transforms them to z3 format
233
+
234
+ """
235
+ dimension_dict = {} # type: ignore[var-annotated]
236
+
237
+ generator = ConstraintGenerator(traced)
238
+ new_constraints, counter = generator.generate_constraints(counter)
239
+
240
+ # print(new_constraints.conjucts[0])
241
+ # print(*new_constraints.conjucts, sep='\n')
242
+
243
+ # transform precision, matching, consistency till obtaining a fixed point
244
+ new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
245
+ # print(new_constraints)
246
+ # print(new_constraints.conjucts)
247
+ # new_constraints.conjucts = new_constraints.conjucts[:-1]
248
+ # print(*new_constraints.conjucts, sep='\n')
249
+
250
+ transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
251
+ # print(transformed)
252
+ return transformed
253
+
254
+ def iterate_till_fixed_point(constraints, counter):
255
+ """
256
+ Transform constraints till reaching a fixed point
257
+ """
258
+ old_c = None
259
+ while old_c != constraints:
260
+ old_c = constraints
261
+ constraints, counter = transform_constraint(constraints, counter)
262
+ return constraints, counter
263
+
264
+ def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0):
265
+ """
266
+ Takes a node and a graph and generates two sets of constraints.
267
+ One set constraints the node's constraints and another set
268
+ constraints the negation of the node's constraints
269
+ Args:
270
+ tracer_root: the root for getting the module instances
271
+ graph: the graph so far in the tracing process
272
+ node: node that represents a conditional
273
+ counter: variable tracking
274
+
275
+ Returns: Two sets of constraints. One with a conjunction with the
276
+ the conditional constraint and the other with a conjunction with
277
+ its negation.
278
+
279
+ """
280
+ dimension_dict = {} # type: ignore[var-annotated]
281
+
282
+ generator = ConstraintGenerator(tracer_root, graph)
283
+ new_constraints, counter = generator.generate_constraints(counter)
284
+
285
+ condition_constraint = new_constraints.conjucts[-1]
286
+
287
+ # we know the constraint is a conjunction where the last constraint is about the conditional
288
+ # so remove the last constraint
289
+ new_constraints.conjucts = new_constraints.conjucts[:-1]
290
+
291
+ # transform precision, matching, consistency till obtaining a fixed point
292
+ new_constraints, counter = iterate_till_fixed_point(new_constraints, counter)
293
+
294
+
295
+ # since the function returns a list of one element, we get the first element
296
+ # we are only interested in the RHS in this case because the LHS just stores
297
+ # the result
298
+
299
+ # we make sure the constraint is of the form:
300
+ # c = b where b is a boolean expression
301
+ # and we consider b (constraint.rhs) for transformation
302
+ assert isinstance(condition_constraint.lhs, BVar)
303
+ assert is_bool_expr(condition_constraint.rhs)
304
+ condition_constraint_rhs = condition_constraint.rhs
305
+
306
+ # transform the condition constraint
307
+ condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter)
308
+
309
+ transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict)
310
+
311
+ transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict)
312
+
313
+ negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint)
314
+
315
+ return z3.And([transformed, transformed_condition_constraint]), \
316
+ z3.And([transformed, negation_transformed_condition_constraint])
317
+
318
+
319
+ def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None):
320
+ """
321
+ Given an IR and a node representing a conditional, evaluate the conditional
322
+ and its negation
323
+ Args:
324
+ tracer_root: Tracer root for module instances
325
+ node: The node to be evaluated
326
+
327
+ Returns: the results of evaluating the condition and the negation with
328
+ the rest of the constraints
329
+
330
+ """
331
+
332
+ transformed_positive, transformed_negative = \
333
+ transform_all_constraints_trace_time(tracer_root, graph, node, counter)
334
+
335
+ s = z3.Solver()
336
+ s.add(transformed_positive)
337
+ if user_constraints is not None:
338
+ s.add(user_constraints)
339
+ condition = s.check()
340
+
341
+ s = z3.Solver()
342
+ s.add(transformed_negative)
343
+ if user_constraints is not None:
344
+ s.add(user_constraints)
345
+ negation = s.check()
346
+ return condition, negation
347
+
348
+ except ImportError:
349
+ HAS_Z3 = False
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/util.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \
3
+ BVar
4
+ from torch.fx.experimental.migrate_gradual_types.operation import op_leq
5
+
6
+
7
+ def gen_tvar(curr):
8
+ """
9
+ Generate a tensor variable
10
+ :param curr: The current counter
11
+ :return: a tensor variable and the updated counter
12
+ """
13
+ curr += 1
14
+ return TVar(curr), curr
15
+
16
+
17
+ def gen_dvar(curr):
18
+ """
19
+ Generate a dimension variable
20
+ :param curr: the current counter
21
+ :return: a dimension variable and an updated counter
22
+ """
23
+ curr += 1
24
+ return DVar(curr), curr
25
+
26
+ def gen_bvar(curr):
27
+ """
28
+ Generate a boolean variable
29
+ :param curr: the current counter
30
+ :return: a boolean variable and an updated counter
31
+ """
32
+ curr += 1
33
+ return BVar(curr), curr
34
+
35
+ def gen_tensor_dims(n, curr):
36
+ """
37
+ Generate a list of tensor dimensions
38
+ :param n: the number of dimensions
39
+ :param curr: the current counter
40
+ :return: a list of dimension variables and an updated counter
41
+ """
42
+ dims = []
43
+ for _ in range(n):
44
+ dvar, curr = gen_dvar(curr)
45
+ dims.append(dvar)
46
+ return dims, curr
47
+
48
+
49
+ def gen_nat_constraints(list_of_dims):
50
+ """
51
+ Generate natural number constraints for dimensions
52
+ """
53
+ return [BinConstraintD(0, d, op_leq) for d in list_of_dims]
.venv/lib/python3.11/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ import z3 # type: ignore[import]
3
+ HAS_Z3 = True
4
+ # dynamic type
5
+ dyn = z3.DeclareSort('Dyn')
6
+ dyn_type = z3.Const('dyn', dyn)
7
+
8
+ # dimension
9
+ dim = z3.Datatype('dim')
10
+ dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort()))
11
+ dim = dim.create()
12
+
13
+ # tensors
14
+ tensor_type = z3.Datatype('TensorType')
15
+ tensor_type.declare('Dyn', ('dyn', dyn))
16
+ tensor_type.declare('tensor1', ('0', dim))
17
+ tensor_type.declare('tensor2', ('0', dim), ('1', dim))
18
+ tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim))
19
+ tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim))
20
+ tensor_type = tensor_type.create()
21
+
22
+ # create dimension
23
+ D = dim.dim
24
+
25
+ z3_dyn = tensor_type.Dyn(dyn_type)
26
+
27
+
28
+ except ImportError:
29
+ HAS_Z3 = False
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # mypy: disable-error-code=attr-defined
2
+ from .core import unify, reify # noqa: F403
3
+ from .more import unifiable # noqa: F403
4
+ from .variable import var, isvar, vars, variables, Var # noqa: F403
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (471 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-311.pyc ADDED
Binary file (4.18 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-311.pyc ADDED
Binary file (402 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-311.pyc ADDED
Binary file (7.09 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-311.pyc ADDED
Binary file (5.25 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-311.pyc ADDED
Binary file (14.8 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-311.pyc ADDED
Binary file (5.2 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-311.pyc ADDED
Binary file (4.42 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/core.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from collections.abc import Iterator # type: ignore[import]
3
+ from functools import partial
4
+
5
+ from .unification_tools import assoc # type: ignore[import]
6
+ from .utils import transitive_get as walk
7
+ from .variable import isvar
8
+ from .dispatch import dispatch
9
+
10
+ __all__ = ["reify", "unify"]
11
+
12
+ ###############
13
+ # Reification #
14
+ ###############
15
+
16
+ @dispatch(Iterator, dict)
17
+ def _reify(t, s):
18
+ return map(partial(reify, s=s), t)
19
+ # return (reify(arg, s) for arg in t)
20
+ _reify
21
+
22
+ @dispatch(tuple, dict) # type: ignore[no-redef]
23
+ def _reify(t, s):
24
+ return tuple(reify(iter(t), s))
25
+ _reify
26
+
27
+ @dispatch(list, dict) # type: ignore[no-redef]
28
+ def _reify(t, s):
29
+ return list(reify(iter(t), s))
30
+ _reify
31
+
32
+ @dispatch(dict, dict) # type: ignore[no-redef]
33
+ def _reify(d, s):
34
+ return {k: reify(v, s) for k, v in d.items()}
35
+ _reify
36
+
37
+ @dispatch(object, dict) # type: ignore[no-redef]
38
+ def _reify(o, s):
39
+ return o # catch all, just return the object
40
+
41
+ def reify(e, s):
42
+ """ Replace variables of expression with substitution
43
+ >>> # xdoctest: +SKIP
44
+ >>> x, y = var(), var()
45
+ >>> e = (1, x, (3, y))
46
+ >>> s = {x: 2, y: 4}
47
+ >>> reify(e, s)
48
+ (1, 2, (3, 4))
49
+ >>> e = {1: x, 3: (y, 5)}
50
+ >>> reify(e, s)
51
+ {1: 2, 3: (4, 5)}
52
+ """
53
+ if isvar(e):
54
+ return reify(s[e], s) if e in s else e
55
+ return _reify(e, s)
56
+
57
+ ###############
58
+ # Unification #
59
+ ###############
60
+
61
+ seq = tuple, list, Iterator
62
+
63
+ @dispatch(seq, seq, dict)
64
+ def _unify(u, v, s):
65
+ if len(u) != len(v):
66
+ return False
67
+ for uu, vv in zip(u, v): # avoiding recursion
68
+ s = unify(uu, vv, s)
69
+ if s is False:
70
+ return False
71
+ return s
72
+ #
73
+ # @dispatch((set, frozenset), (set, frozenset), dict)
74
+ # def _unify(u, v, s):
75
+ # i = u & v
76
+ # u = u - i
77
+ # v = v - i
78
+ # return _unify(sorted(u), sorted(v), s)
79
+ #
80
+ #
81
+ # @dispatch(dict, dict, dict)
82
+ # def _unify(u, v, s):
83
+ # if len(u) != len(v):
84
+ # return False
85
+ # for key, uval in iteritems(u):
86
+ # if key not in v:
87
+ # return False
88
+ # s = unify(uval, v[key], s)
89
+ # if s is False:
90
+ # return False
91
+ # return s
92
+ #
93
+ #
94
+ # @dispatch(object, object, dict)
95
+ # def _unify(u, v, s):
96
+ # return False # catch all
97
+
98
+
99
+ @dispatch(object, object, dict)
100
+ def unify(u, v, s): # no check at the moment
101
+ """ Find substitution so that u == v while satisfying s
102
+ >>> x = var('x')
103
+ >>> unify((1, x), (1, 2), {})
104
+ {~x: 2}
105
+ """
106
+ u = walk(u, s)
107
+ v = walk(v, s)
108
+ if u == v:
109
+ return s
110
+ if isvar(u):
111
+ return assoc(s, u, v)
112
+ if isvar(v):
113
+ return assoc(s, v, u)
114
+ return _unify(u, v, s)
115
+ unify
116
+
117
+ @dispatch(object, object) # type: ignore[no-redef]
118
+ def unify(u, v):
119
+ return unify(u, v, {})
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/dispatch.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from .multipledispatch import dispatch # type: ignore[import]
3
+
4
+ namespace = {} # type: ignore[var-annotated]
5
+
6
+ dispatch = partial(dispatch, namespace=namespace)
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/match.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from .core import unify, reify # type: ignore[attr-defined]
3
+ from .variable import isvar
4
+ from .utils import _toposort, freeze
5
+ from .unification_tools import groupby, first # type: ignore[import]
6
+
7
+
8
+ class Dispatcher:
9
+ def __init__(self, name):
10
+ self.name = name
11
+ self.funcs = {}
12
+ self.ordering = []
13
+
14
+ def add(self, signature, func):
15
+ self.funcs[freeze(signature)] = func
16
+ self.ordering = ordering(self.funcs)
17
+
18
+ def __call__(self, *args, **kwargs):
19
+ func, s = self.resolve(args)
20
+ return func(*args, **kwargs)
21
+
22
+ def resolve(self, args):
23
+ n = len(args)
24
+ for signature in self.ordering:
25
+ if len(signature) != n:
26
+ continue
27
+ s = unify(freeze(args), signature)
28
+ if s is not False:
29
+ result = self.funcs[signature]
30
+ return result, s
31
+ raise NotImplementedError("No match found. \nKnown matches: "
32
+ + str(self.ordering) + "\nInput: " + str(args))
33
+
34
+ def register(self, *signature):
35
+ def _(func):
36
+ self.add(signature, func)
37
+ return self
38
+ return _
39
+
40
+
41
+ class VarDispatcher(Dispatcher):
42
+ """ A dispatcher that calls functions with variable names
43
+ >>> # xdoctest: +SKIP
44
+ >>> d = VarDispatcher('d')
45
+ >>> x = var('x')
46
+ >>> @d.register('inc', x)
47
+ ... def f(x):
48
+ ... return x + 1
49
+ >>> @d.register('double', x)
50
+ ... def f(x):
51
+ ... return x * 2
52
+ >>> d('inc', 10)
53
+ 11
54
+ >>> d('double', 10)
55
+ 20
56
+ """
57
+ def __call__(self, *args, **kwargs):
58
+ func, s = self.resolve(args)
59
+ d = {k.token: v for k, v in s.items()}
60
+ return func(**d)
61
+
62
+
63
+ global_namespace = {} # type: ignore[var-annotated]
64
+
65
+
66
+ def match(*signature, **kwargs):
67
+ namespace = kwargs.get('namespace', global_namespace)
68
+ dispatcher = kwargs.get('Dispatcher', Dispatcher)
69
+
70
+ def _(func):
71
+ name = func.__name__
72
+
73
+ if name not in namespace:
74
+ namespace[name] = dispatcher(name)
75
+ d = namespace[name]
76
+
77
+ d.add(signature, func)
78
+
79
+ return d
80
+ return _
81
+
82
+
83
+ def supercedes(a, b):
84
+ """ ``a`` is a more specific match than ``b`` """
85
+ if isvar(b) and not isvar(a):
86
+ return True
87
+ s = unify(a, b)
88
+ if s is False:
89
+ return False
90
+ s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
91
+ if reify(a, s) == a:
92
+ return True
93
+ if reify(b, s) == b:
94
+ return False
95
+
96
+
97
+ # Taken from multipledispatch
98
+ def edge(a, b, tie_breaker=hash):
99
+ """ A should be checked before B
100
+ Tie broken by tie_breaker, defaults to ``hash``
101
+ """
102
+ if supercedes(a, b):
103
+ if supercedes(b, a):
104
+ return tie_breaker(a) > tie_breaker(b)
105
+ else:
106
+ return True
107
+ return False
108
+
109
+
110
+ # Taken from multipledispatch
111
+ def ordering(signatures):
112
+ """ A sane ordering of signatures to check, first to last
113
+ Topological sort of edges as given by ``edge`` and ``supercedes``
114
+ """
115
+ signatures = list(map(tuple, signatures))
116
+ edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
117
+ edges = groupby(first, edges)
118
+ for s in signatures:
119
+ if s not in edges:
120
+ edges[s] = []
121
+ edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment]
122
+ return _toposort(edges)
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/more.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from .core import unify, reify # type: ignore[attr-defined]
3
+ from .dispatch import dispatch
4
+
5
+
6
+ def unifiable(cls):
7
+ """ Register standard unify and reify operations on class
8
+ This uses the type and __dict__ or __slots__ attributes to define the
9
+ nature of the term
10
+ See Also:
11
+ >>> # xdoctest: +SKIP
12
+ >>> class A(object):
13
+ ... def __init__(self, a, b):
14
+ ... self.a = a
15
+ ... self.b = b
16
+ >>> unifiable(A)
17
+ <class 'unification.more.A'>
18
+ >>> x = var('x')
19
+ >>> a = A(1, 2)
20
+ >>> b = A(1, x)
21
+ >>> unify(a, b, {})
22
+ {~x: 2}
23
+ """
24
+ _unify.add((cls, cls, dict), unify_object)
25
+ _reify.add((cls, dict), reify_object)
26
+
27
+ return cls
28
+
29
+
30
+ #########
31
+ # Reify #
32
+ #########
33
+
34
+
35
+ def reify_object(o, s):
36
+ """ Reify a Python object with a substitution
37
+ >>> # xdoctest: +SKIP
38
+ >>> class Foo(object):
39
+ ... def __init__(self, a, b):
40
+ ... self.a = a
41
+ ... self.b = b
42
+ ... def __str__(self):
43
+ ... return "Foo(%s, %s)"%(str(self.a), str(self.b))
44
+ >>> x = var('x')
45
+ >>> f = Foo(1, x)
46
+ >>> print(f)
47
+ Foo(1, ~x)
48
+ >>> print(reify_object(f, {x: 2}))
49
+ Foo(1, 2)
50
+ """
51
+ if hasattr(o, '__slots__'):
52
+ return _reify_object_slots(o, s)
53
+ else:
54
+ return _reify_object_dict(o, s)
55
+
56
+
57
+ def _reify_object_dict(o, s):
58
+ obj = object.__new__(type(o))
59
+ d = reify(o.__dict__, s)
60
+ if d == o.__dict__:
61
+ return o
62
+ obj.__dict__.update(d)
63
+ return obj
64
+
65
+
66
+ def _reify_object_slots(o, s):
67
+ attrs = [getattr(o, attr) for attr in o.__slots__]
68
+ new_attrs = reify(attrs, s)
69
+ if attrs == new_attrs:
70
+ return o
71
+ else:
72
+ newobj = object.__new__(type(o))
73
+ for slot, attr in zip(o.__slots__, new_attrs):
74
+ setattr(newobj, slot, attr)
75
+ return newobj
76
+
77
+
78
+ @dispatch(slice, dict)
79
+ def _reify(o, s):
80
+ """ Reify a Python ``slice`` object """
81
+ return slice(*reify((o.start, o.stop, o.step), s))
82
+
83
+
84
+ #########
85
+ # Unify #
86
+ #########
87
+
88
+
89
+ def unify_object(u, v, s):
90
+ """ Unify two Python objects
91
+ Unifies their type and ``__dict__`` attributes
92
+ >>> # xdoctest: +SKIP
93
+ >>> class Foo(object):
94
+ ... def __init__(self, a, b):
95
+ ... self.a = a
96
+ ... self.b = b
97
+ ... def __str__(self):
98
+ ... return "Foo(%s, %s)"%(str(self.a), str(self.b))
99
+ >>> x = var('x')
100
+ >>> f = Foo(1, x)
101
+ >>> g = Foo(1, 2)
102
+ >>> unify_object(f, g, {})
103
+ {~x: 2}
104
+ """
105
+ if type(u) != type(v):
106
+ return False
107
+ if hasattr(u, '__slots__'):
108
+ return unify([getattr(u, slot) for slot in u.__slots__],
109
+ [getattr(v, slot) for slot in v.__slots__],
110
+ s)
111
+ else:
112
+ return unify(u.__dict__, v.__dict__, s)
113
+
114
+
115
+ @dispatch(slice, slice, dict)
116
+ def _unify(u, v, s):
117
+ """ Unify a Python ``slice`` object """
118
+ return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s)
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .core import dispatch
2
+ from .dispatcher import (Dispatcher, halt_ordering, restart_ordering,
3
+ MDNotImplementedError)
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (464 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-311.pyc ADDED
Binary file (8.66 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-311.pyc ADDED
Binary file (3.5 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-311.pyc ADDED
Binary file (22.4 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-311.pyc ADDED
Binary file (6.28 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-311.pyc ADDED
Binary file (4.74 kB). View file
 
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from .utils import _toposort, groupby
3
+ from .variadic import isvariadic
4
+ import operator
5
+
6
+ __all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature",
7
+ "edge", "ordering"]
8
+
9
+ class AmbiguityWarning(Warning):
10
+ pass
11
+
12
+
13
+ def supercedes(a, b):
14
+ """ A is consistent and strictly more specific than B """
15
+ if len(a) < len(b):
16
+ # only case is if a is empty and b is variadic
17
+ return not a and len(b) == 1 and isvariadic(b[-1])
18
+ elif len(a) == len(b):
19
+ return all(map(issubclass, a, b))
20
+ else:
21
+ # len(a) > len(b)
22
+ p1 = 0
23
+ p2 = 0
24
+ while p1 < len(a) and p2 < len(b):
25
+ cur_a = a[p1]
26
+ cur_b = b[p2]
27
+ if not (isvariadic(cur_a) or isvariadic(cur_b)):
28
+ if not issubclass(cur_a, cur_b):
29
+ return False
30
+ p1 += 1
31
+ p2 += 1
32
+ elif isvariadic(cur_a):
33
+ assert p1 == len(a) - 1
34
+ return p2 == len(b) - 1 and issubclass(cur_a, cur_b)
35
+ elif isvariadic(cur_b):
36
+ assert p2 == len(b) - 1
37
+ if not issubclass(cur_a, cur_b):
38
+ return False
39
+ p1 += 1
40
+ return p2 == len(b) - 1 and p1 == len(a)
41
+
42
+
43
+ def consistent(a, b):
44
+ """ It is possible for an argument list to satisfy both A and B """
45
+
46
+ # Need to check for empty args
47
+ if not a:
48
+ return not b or isvariadic(b[0])
49
+ if not b:
50
+ return not a or isvariadic(a[0])
51
+
52
+ # Non-empty args check for mutual subclasses
53
+ if len(a) == len(b):
54
+ return all(issubclass(aa, bb) or issubclass(bb, aa)
55
+ for aa, bb in zip(a, b))
56
+ else:
57
+ p1 = 0
58
+ p2 = 0
59
+ while p1 < len(a) and p2 < len(b):
60
+ cur_a = a[p1]
61
+ cur_b = b[p2]
62
+ if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b):
63
+ return False
64
+ if not (isvariadic(cur_a) or isvariadic(cur_b)):
65
+ p1 += 1
66
+ p2 += 1
67
+ elif isvariadic(cur_a):
68
+ p2 += 1
69
+ elif isvariadic(cur_b):
70
+ p1 += 1
71
+ # We only need to check for variadic ends
72
+ # Variadic types are guaranteed to be the last element
73
+ return (isvariadic(cur_a) and p2 == len(b) or # type: ignore[possibly-undefined]
74
+ isvariadic(cur_b) and p1 == len(a)) # type: ignore[possibly-undefined]
75
+
76
+
77
+ def ambiguous(a, b):
78
+ """ A is consistent with B but neither is strictly more specific """
79
+ return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
80
+
81
+
82
+ def ambiguities(signatures):
83
+ """ All signature pairs such that A is ambiguous with B """
84
+ signatures = list(map(tuple, signatures))
85
+ return {(a, b) for a in signatures for b in signatures
86
+ if hash(a) < hash(b)
87
+ and ambiguous(a, b)
88
+ and not any(supercedes(c, a) and supercedes(c, b)
89
+ for c in signatures)}
90
+
91
+
92
+ def super_signature(signatures):
93
+ """ A signature that would break ambiguities """
94
+ n = len(signatures[0])
95
+ assert all(len(s) == n for s in signatures)
96
+
97
+ return [max((type.mro(sig[i]) for sig in signatures), key=len)[0]
98
+ for i in range(n)]
99
+
100
+
101
+ def edge(a, b, tie_breaker=hash):
102
+ """ A should be checked before B
103
+ Tie broken by tie_breaker, defaults to ``hash``
104
+ """
105
+ # A either supercedes B and B does not supercede A or if B does then call
106
+ # tie_breaker
107
+ return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b))
108
+
109
+
110
+ def ordering(signatures):
111
+ """ A sane ordering of signatures to check, first to last
112
+ Topological sort of edges as given by ``edge`` and ``supercedes``
113
+ """
114
+ signatures = list(map(tuple, signatures))
115
+ edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
116
+ edges = groupby(operator.itemgetter(0), edges)
117
+ for s in signatures:
118
+ if s not in edges:
119
+ edges[s] = []
120
+ edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[assignment, attr-defined]
121
+ return _toposort(edges)
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/core.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import inspect
3
+ import sys
4
+
5
+ from .dispatcher import Dispatcher, MethodDispatcher
6
+
7
+ global_namespace = {} # type: ignore[var-annotated]
8
+
9
+ __all__ = ["dispatch", "ismethod"]
10
+
11
+ def dispatch(*types, **kwargs):
12
+ """ Dispatch function on the types of the inputs
13
+ Supports dispatch on all non-keyword arguments.
14
+ Collects implementations based on the function name. Ignores namespaces.
15
+ If ambiguous type signatures occur a warning is raised when the function is
16
+ defined suggesting the additional method to break the ambiguity.
17
+
18
+ Example:
19
+ >>> # xdoctest: +SKIP
20
+ >>> @dispatch(int)
21
+ ... def f(x):
22
+ ... return x + 1
23
+ >>> @dispatch(float)
24
+ ... def f(x):
25
+ ... return x - 1
26
+ >>> # xdoctest: +SKIP
27
+ >>> f(3)
28
+ 4
29
+ >>> f(3.0)
30
+ 2.0
31
+ >>> # Specify an isolated namespace with the namespace keyword argument
32
+ >>> my_namespace = {}
33
+ >>> @dispatch(int, namespace=my_namespace)
34
+ ... def foo(x):
35
+ ... return x + 1
36
+ >>> # Dispatch on instance methods within classes
37
+ >>> class MyClass(object):
38
+ ... @dispatch(list)
39
+ ... def __init__(self, data):
40
+ ... self.data = data
41
+ ... @dispatch(int)
42
+ ... def __init__(self, datum):
43
+ ... self.data = [datum]
44
+ >>> MyClass([1, 2, 3]).data
45
+ [1, 2, 3]
46
+ >>> MyClass(3).data
47
+ [3]
48
+ """
49
+ namespace = kwargs.get('namespace', global_namespace)
50
+
51
+ types = tuple(types)
52
+
53
+ def _df(func):
54
+ name = func.__name__
55
+
56
+ if ismethod(func):
57
+ dispatcher = inspect.currentframe().f_back.f_locals.get( # type: ignore[union-attr]
58
+ name, # type: ignore[union-attr]
59
+ MethodDispatcher(name),
60
+ )
61
+ else:
62
+ if name not in namespace:
63
+ namespace[name] = Dispatcher(name)
64
+ dispatcher = namespace[name]
65
+
66
+ dispatcher.add(types, func)
67
+ return dispatcher
68
+ return _df
69
+
70
+
71
+ def ismethod(func):
72
+ """ Is func a method?
73
+ Note that this has to work as the method is defined but before the class is
74
+ defined. At this stage methods look like functions.
75
+ """
76
+ if hasattr(inspect, "signature"):
77
+ signature = inspect.signature(func)
78
+ return signature.parameters.get('self', None) is not None
79
+ else:
80
+ if sys.version_info.major < 3:
81
+ spec = inspect.getargspec(func) # type: ignore[attr-defined]
82
+ else:
83
+ spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment]
84
+ return spec and spec.args and spec.args[0] == 'self'
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from warnings import warn
3
+ import inspect
4
+ from typing_extensions import deprecated
5
+ from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
6
+ from .utils import expand_tuples
7
+ from .variadic import Variadic, isvariadic
8
+ import itertools as itl
9
+
10
+ __all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter",
11
+ "variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"]
12
+
13
+ class MDNotImplementedError(NotImplementedError):
14
+ """ A NotImplementedError for multiple dispatch """
15
+
16
+
17
+ def ambiguity_warn(dispatcher, ambiguities):
18
+ """ Raise warning when ambiguity is detected
19
+ Parameters
20
+ ----------
21
+ dispatcher : Dispatcher
22
+ The dispatcher on which the ambiguity was detected
23
+ ambiguities : set
24
+ Set of type signature pairs that are ambiguous within this dispatcher
25
+ See Also:
26
+ Dispatcher.add
27
+ warning_text
28
+ """
29
+ warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
30
+
31
+
32
+ @deprecated(
33
+ "`halt_ordering` is deprecated, you can safely remove this call.",
34
+ category=FutureWarning,
35
+ )
36
+ def halt_ordering():
37
+ """Deprecated interface to temporarily disable ordering."""
38
+
39
+
40
+ @deprecated(
41
+ "`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, "
42
+ "you should call the `reorder()` method on each dispatcher.",
43
+ category=FutureWarning,
44
+ )
45
+ def restart_ordering(on_ambiguity=ambiguity_warn):
46
+ """Deprecated interface to temporarily resume ordering."""
47
+
48
+
49
+ def variadic_signature_matches_iter(types, full_signature):
50
+ """Check if a set of input types matches a variadic signature.
51
+ Notes
52
+ -----
53
+ The algorithm is as follows:
54
+ Initialize the current signature to the first in the sequence
55
+ For each type in `types`:
56
+ If the current signature is variadic
57
+ If the type matches the signature
58
+ yield True
59
+ Else
60
+ Try to get the next signature
61
+ If no signatures are left we can't possibly have a match
62
+ so yield False
63
+ Else
64
+ yield True if the type matches the current signature
65
+ Get the next signature
66
+ """
67
+ sigiter = iter(full_signature)
68
+ sig = next(sigiter)
69
+ for typ in types:
70
+ matches = issubclass(typ, sig)
71
+ yield matches
72
+ if not isvariadic(sig):
73
+ # we're not matching a variadic argument, so move to the next
74
+ # element in the signature
75
+ sig = next(sigiter)
76
+ else:
77
+ try:
78
+ sig = next(sigiter)
79
+ except StopIteration:
80
+ assert isvariadic(sig)
81
+ yield True
82
+ else:
83
+ # We have signature items left over, so all of our arguments
84
+ # haven't matched
85
+ yield False
86
+
87
+
88
+ def variadic_signature_matches(types, full_signature):
89
+ # No arguments always matches a variadic signature
90
+ assert full_signature
91
+ return all(variadic_signature_matches_iter(types, full_signature))
92
+
93
+
94
class Dispatcher:
    """ Dispatch methods based on type signature
    Use ``dispatch`` to add implementations
    Examples
    --------
    >>> # xdoctest: +SKIP("bad import name")
    >>> from multipledispatch import dispatch
    >>> @dispatch(int)
    ... def f(x):
    ...     return x + 1
    >>> @dispatch(float)
    ... def f(x):
    ...     return x - 1
    >>> f(3)
    4
    >>> f(3.0)
    2.0
    """
    # funcs: maps type-signature tuples -> implementations.
    # _cache: memoizes the resolved implementation per concrete
    #         argument-type tuple; cleared whenever ``add`` runs.
    # _ordering: lazily computed resolution order (see ``ordering``).
    __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'

    def __init__(self, name, doc=None):
        self.name = self.__name__ = name
        self.funcs = {}
        self.doc = doc

        self._cache = {}

    def register(self, *types, **kwargs):
        """ register dispatcher with new implementation
        >>> # xdoctest: +SKIP
        >>> f = Dispatcher('f')
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1
        >>> @f.register(float)
        ... def dec(x):
        ...     return x - 1
        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]
        >>> f(1)
        2
        >>> f(1.0)
        0.0
        >>> f([1, 2, 3])
        [3, 2, 1]
        """
        def _df(func):
            self.add(types, func, **kwargs)  # type: ignore[call-arg]
            return func
        return _df

    @classmethod
    def get_func_params(cls, func):
        # Ordered view of ``func``'s parameters; returns None when
        # ``inspect.signature`` is unavailable.
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return sig.parameters.values()

    @classmethod
    def get_func_annotations(cls, func):
        """ get annotations of function positional parameters
        """
        params = cls.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            # Only positional parameters take part in dispatch.
            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))

            annotations = tuple(
                param.annotation
                for param in params)

            # Usable as a signature only when every positional parameter
            # carries an annotation.
            if all(ann is not Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func):
        """ Add new types/method pair to dispatcher
        >>> # xdoctest: +SKIP
        >>> D = Dispatcher('add')
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)
        >>> D(1, 2)
        3
        >>> D(1, 2.0)
        Traceback (most recent call last):
        ...
        NotImplementedError: Could not find signature for add: <int, float>
        >>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback
        >>> # with a dispatcher/itself, and a set of ambiguous type signature pairs
        >>> # as inputs. See ``ambiguity_warn`` for an example.
        """
        # Handle annotations
        if not signature:
            annotations = self.get_func_annotations(func)
            if annotations:
                signature = annotations

        # Handle union types
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func)
            return

        new_signature = []

        for index, typ in enumerate(signature, start=1):
            if not isinstance(typ, (type, list)):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError(f"Tried to dispatch on non-type: {typ}\n"
                                f"In signature: <{str_sig}>\n"
                                f"In function: {self.name}")

            # handle variadic signatures
            if isinstance(typ, list):
                if index != len(signature):
                    raise TypeError(
                        'Variadic signature must be the last element'
                    )

                if len(typ) != 1:
                    raise TypeError(
                        'Variadic signature must contain exactly one element. '
                        'To use a variadic union type place the desired types '
                        'inside of a tuple, e.g., [(int, str)]'
                    )
                new_signature.append(Variadic[typ[0]])
            else:
                new_signature.append(typ)

        self.funcs[tuple(new_signature)] = func
        # A new signature invalidates both the memoized resolutions and the
        # cached resolution order.
        self._cache.clear()

        try:
            del self._ordering
        except AttributeError:
            pass

    @property
    def ordering(self):
        # Lazily (re)computed resolution order of registered signatures.
        try:
            return self._ordering
        except AttributeError:
            return self.reorder()

    def reorder(self, on_ambiguity=ambiguity_warn):
        # Recompute the resolution order; ambiguous signature pairs are
        # reported through the ``on_ambiguity`` callback.
        self._ordering = od = ordering(self.funcs)
        amb = ambiguities(self.funcs)
        if amb:
            on_ambiguity(self, amb)
        return od

    def __call__(self, *args, **kwargs):
        types = tuple([type(arg) for arg in args])
        try:
            func = self._cache[types]
        except KeyError as e:
            func = self.dispatch(*types)
            if not func:
                raise NotImplementedError(
                    f'Could not find signature for {self.name}: <{str_signature(types)}>') from e
            self._cache[types] = func
        try:
            return func(*args, **kwargs)

        except MDNotImplementedError as e:
            # The chosen implementation declined; try the remaining matches
            # in resolution order.
            funcs = self.dispatch_iter(*types)
            next(funcs)  # burn first
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    pass

            raise NotImplementedError(
                "Matching functions for "
                f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e

    def __str__(self):
        return f"<dispatched {self.name}>"
    __repr__ = __str__

    def dispatch(self, *types):
        """Determine appropriate implementation for this type signature
        This method is internal.  Users should call this object as a function.
        Implementation resolution occurs within the ``__call__`` method.
        >>> # xdoctest: +SKIP
        >>> from multipledispatch import dispatch
        >>> @dispatch(int)
        ... def inc(x):
        ...     return x + 1
        >>> implementation = inc.dispatch(int)
        >>> implementation(3)
        4
        >>> print(inc.dispatch(float))
        None
        See Also:
            ``multipledispatch.conflict`` - module to determine resolution order
        """
        # Exact match short-circuits the ordered search.
        if types in self.funcs:
            return self.funcs[types]

        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def dispatch_iter(self, *types):
        # Yield every registered implementation matching ``types``, most
        # specific first (per ``self.ordering``).
        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
                result = self.funcs[signature]
                yield result
            elif len(signature) and isvariadic(signature[-1]):
                if variadic_signature_matches(types, signature):
                    result = self.funcs[signature]
                    yield result

    @deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning)
    def resolve(self, types):
        """ Determine appropriate implementation for this type signature
        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        return self.dispatch(*types)

    def __getstate__(self):
        # Only the registry is pickled; ordering and cache are rebuilt.
        return {'name': self.name,
                'funcs': self.funcs}

    def __setstate__(self, d):
        self.name = d['name']
        self.funcs = d['funcs']
        self._ordering = ordering(self.funcs)
        self._cache = {}

    @property
    def __doc__(self):
        # Synthesize a docstring from the registered implementations.
        docs = [f"Multiply dispatched method: {self.name}"]

        if self.doc:
            docs.append(self.doc)

        other = []
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                s = f'Inputs: <{str_signature(sig)}>\n'
                s += '-' * len(s) + '\n'
                s += func.__doc__.strip()
                docs.append(s)
            else:
                other.append(str_signature(sig))

        if other:
            docs.append('Other signatures:\n    ' + '\n    '.join(other))

        return '\n\n'.join(docs)

    def _help(self, *args):
        return self.dispatch(*map(type, args)).__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))
375
+
376
+
377
def source(func):
    """Return the source file path and source code of ``func`` as one string."""
    header = f'File: {inspect.getsourcefile(func)}\n\n'
    return header + inspect.getsource(func)
381
+
382
+
383
class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature
    See Also:
        Dispatcher
    """
    # obj/cls: the instance and class captured by the descriptor protocol
    # for the next call.
    __slots__ = ('obj', 'cls')

    @classmethod
    def get_func_params(cls, func):
        # Skip the implicit first parameter (``self``) so only the real
        # positional arguments participate in dispatch.
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return itl.islice(sig.parameters.values(), 1, None)

    def __get__(self, instance, owner):
        # Descriptor protocol: remember the bound instance/class for the
        # upcoming __call__.
        # NOTE(review): this stores state on the shared dispatcher object,
        # so interleaved attribute access from several instances could
        # clobber ``obj`` — confirm single-threaded, immediate-call usage.
        self.obj = instance
        self.cls = owner
        return self

    def __call__(self, *args, **kwargs):
        # Dispatch on the explicit arguments; the remembered instance is
        # re-inserted as the first argument of the implementation.
        types = tuple([type(arg) for arg in args])
        func = self.dispatch(*types)
        if not func:
            raise NotImplementedError(f'Could not find signature for {self.name}: <{str_signature(types)}>')
        return func(self.obj, *args, **kwargs)
407
+
408
+
409
def str_signature(sig):
    """Render a type signature as a comma-separated list of class names.

    >>> str_signature((int, float))
    'int, float'
    """
    names = [cls.__name__ for cls in sig]
    return ', '.join(names)
415
+
416
+
417
def warning_text(name, amb):
    """Build the warning message describing ambiguous dispatch signatures.

    ``amb`` is an iterable of signature pairs that may dispatch ambiguously;
    the message lists each pair and suggests disambiguating registrations.
    """
    parts = [f"\nAmbiguities exist in dispatched function {name}\n\n",
             "The following signatures may result in ambiguous behavior:\n"]
    for pair in amb:
        rendered = ', '.join('[' + str_signature(s) + ']' for s in pair)
        parts.append("\t" + rendered + "\n")
    parts.append("\n\nConsider making the following additions:\n\n")
    suggestions = ['@dispatch(' + str_signature(super_signature(s))
                   + f')\ndef {name}(...)' for s in amb]
    parts.append('\n\n'.join(suggestions))
    return ''.join(parts)
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from collections import OrderedDict
3
+
4
+ __all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]
5
+
6
def raises(err, lamda):
    """Return True iff calling ``lamda`` raises ``err``.

    Other exception types propagate to the caller unchanged.
    """
    try:
        lamda()
    except err:
        return True
    return False
12
+
13
+
14
def expand_tuples(L):
    """Expand tuple elements of ``L`` into the cross product of choices.

    >>> expand_tuples([1, (2, 3)])
    [(1, 2), (1, 3)]
    >>> expand_tuples([1, 2])
    [(1, 2)]
    """
    if not L:
        return [()]
    head, tail = L[0], L[1:]
    expanded_tail = expand_tuples(tail)
    if isinstance(head, tuple):
        # Each member of the tuple is one alternative for this position.
        return [(option,) + rest for rest in expanded_tail for option in head]
    return [(head,) + rest for rest in expanded_tail]
29
+
30
+
31
+ # Taken from theano/theano/gof/sched.py
32
+ # Avoids licensing issues because this was written by Matthew Rocklin
33
def _toposort(edges):
    """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
    inputs:
        edges - a dict of the form {a: {b, c}} where b and c depend on a
    outputs:
        L - an ordered list of nodes that satisfy the dependencies of edges
    >>> _toposort({1: (2, 3), 2: (3, )})
    [1, 2, 3]
    >>> # Closely follows the wikipedia page [2]
    >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
    >>> # Communications of the ACM
    >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
    """
    incoming_edges = reverse_dict(edges)
    incoming_edges = OrderedDict((k, set(val))
                                 for k, val in incoming_edges.items())
    # Seed the ready set with nodes that have no incoming edges.
    # OrderedDict (instead of a plain set) keeps the output deterministic
    # for a given input ordering.
    S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
    L = []

    while S:
        # popitem() removes the most recently inserted ready node.
        n, _ = S.popitem()
        L.append(n)
        for m in edges.get(n, ()):
            assert n in incoming_edges[m]
            incoming_edges[m].remove(n)
            if not incoming_edges[m]:
                # All of m's prerequisites are emitted; m becomes ready.
                S[m] = None
    # Any node left with incoming edges was never released -> cycle.
    if any(incoming_edges.get(v, None) for v in edges):
        raise ValueError("Input has cycles")
    return L
63
+
64
+
65
def reverse_dict(d):
    """Reverse the direction of a dependence dict.

    >>> d = {'a': (1, 2), 'b': (2, 3), 'c': ()}
    >>> reverse_dict(d)  # doctest: +SKIP
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}

    :note: dict iteration order influences the output order, so the result
        order should be considered nondeterministic across inputs.
    """
    flipped = OrderedDict()  # type: ignore[var-annotated]
    for source_key, targets in d.items():
        for target in targets:
            flipped[target] = flipped.get(target, ()) + (source_key,)
    return flipped
80
+
81
+
82
+ # Taken from toolz
83
+ # Avoids licensing issues because this version was authored by Matthew Rocklin
84
def groupby(func, seq):
    """Group the items of ``seq`` by ``func(item)``.

    >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
    >>> groupby(len, names)  # doctest: +SKIP
    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}

    See Also:
        ``countby``
    """
    groups = OrderedDict()  # type: ignore[var-annotated]
    for item in seq:
        groups.setdefault(func(item), []).append(item)
    return groups
103
+
104
+
105
def typename(type):
    """Get the name of `type`.

    A bare type yields its ``__name__``; a tuple of types yields either the
    single member's name or a parenthesized, comma-separated list.

    >>> typename(int)
    'int'
    >>> typename((int, float))
    '(int, float)'
    """
    try:
        return type.__name__
    except AttributeError:
        parts = [typename(member) for member in type]
        if len(parts) == 1:
            return parts[0]
        return "({})".format(", ".join(parts))
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from .utils import typename
3
+
4
+ __all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
5
+
6
class VariadicSignatureType(type):
    """Metaclass of the types produced by ``Variadic[...]``.

    Instances of this metaclass represent "any number of arguments of these
    types" and customize subclass checking accordingly.
    """
    # checking if subclass is a subclass of self
    def __subclasscheck__(cls, subclass):
        # A variadic subclass is compared member-by-member; a plain type is
        # wrapped in a 1-tuple so both cases iterate uniformly.
        other_type = (subclass.variadic_type if isvariadic(subclass)
                      else (subclass,))
        return subclass is cls or all(
            issubclass(other, cls.variadic_type) for other in other_type  # type: ignore[attr-defined]
        )

    def __eq__(cls, other):
        """
        Return True if other has the same variadic type
        Parameters
        ----------
        other : object (type)
            The object (type) to check
        Returns
        -------
        bool
            Whether or not `other` is equal to `self`
        """
        # Order of members is irrelevant, hence the set comparison.
        return (isvariadic(other) and
                set(cls.variadic_type) == set(other.variadic_type))  # type: ignore[attr-defined]

    def __hash__(cls):
        # frozenset mirrors the order-insensitive equality above so equal
        # variadic types hash equally.
        return hash((type(cls), frozenset(cls.variadic_type)))  # type: ignore[attr-defined]
32
+
33
+
34
def isvariadic(obj):
    """Return True when ``obj`` is a variadic signature type.

    Parameters
    ----------
    obj : type
        The type to check.

    Returns
    -------
    bool
        Whether ``obj`` was produced by ``Variadic[...]``.

    Examples
    --------
    >>> # xdoctest: +SKIP
    >>> isvariadic(int)
    False
    >>> isvariadic(Variadic[int])
    True
    """
    return isinstance(obj, VariadicSignatureType)
53
+
54
+
55
class VariadicSignatureMeta(type):
    """A metaclass that overrides ``__getitem__`` on the class. This is used to
    generate a new type for Variadic signatures. See the Variadic class for
    examples of how this behaves.
    """
    def __getitem__(cls, variadic_type):
        # NOTE(review): ``type(variadic_type)`` is always truthy, so this
        # condition can never be True and the ValueError below is dead code
        # as written — confirm intended validation before tightening it.
        if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
            raise ValueError("Variadic types must be type or tuple of types"
                             " (Variadic[int] or Variadic[(int, float)]")

        # Normalize a single type to a 1-tuple so ``variadic_type`` is
        # always a tuple of member types.
        if not isinstance(variadic_type, tuple):
            variadic_type = variadic_type,
        return VariadicSignatureType(
            f'Variadic[{typename(variadic_type)}]',
            (),
            dict(variadic_type=variadic_type, __slots__=())
        )
72
+
73
+
74
class Variadic(metaclass=VariadicSignatureMeta):
    """A class whose getitem method can be used to generate a new type
    representing a specific variadic signature.
    Examples
    --------
    >>> # xdoctest: +SKIP
    >>> Variadic[int]  # any number of int arguments
    <class 'multipledispatch.variadic.Variadic[int]'>
    >>> Variadic[(int, str)]  # any number of one of int or str arguments
    <class 'multipledispatch.variadic.Variadic[(int, str)]'>
    >>> issubclass(int, Variadic[int])
    True
    >>> issubclass(int, Variadic[(int, str)])
    True
    >>> issubclass(str, Variadic[(int, str)])
    True
    >>> issubclass(float, Variadic[(int, str)])
    False
    """
    # Intentionally empty: all behavior lives in
    # VariadicSignatureMeta.__getitem__.
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/unification_tools.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ import collections
3
+ import operator
4
+ from functools import reduce
5
+ from collections.abc import Mapping
6
+
7
+ __all__ = ['merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
8
+ 'valfilter', 'keyfilter', 'itemfilter',
9
+ 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in']
10
+
11
+
12
def _get_factory(f, kwargs):
    """Pop the ``factory`` option from ``kwargs`` (default ``dict``).

    Raises TypeError naming ``f`` if any other keyword argument remains,
    mimicking a normal unexpected-keyword error.
    """
    factory = kwargs.pop('factory', dict)
    if kwargs:
        unexpected = kwargs.popitem()[0]
        raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{unexpected}'")
    return factory
17
+
18
+
19
def merge(*dicts, **kwargs):
    """ Merge a collection of dictionaries

    >>> merge({1: 'one'}, {2: 'two'})
    {1: 'one', 2: 'two'}

    Later dictionaries have precedence

    >>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
    {1: 2, 3: 3, 4: 4}

    See Also:
        merge_with
    """
    # A single non-mapping argument is treated as an iterable of mappings.
    if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
        dicts = dicts[0]
    factory = _get_factory(merge, kwargs)

    merged = factory()
    for mapping in dicts:
        merged.update(mapping)
    return merged
41
+
42
+
43
def merge_with(func, *dicts, **kwargs):
    """ Merge dictionaries and apply function to combined values

    A key may occur in more than one dict, and all values mapped from the key
    will be passed to the function as a list, such as func([val1, val2, ...]).

    >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
    {1: 11, 2: 22}

    See Also:
        merge
    """
    # A single non-mapping argument is treated as an iterable of mappings.
    if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
        dicts = dicts[0]
    factory = _get_factory(merge_with, kwargs)

    collected = factory()
    for mapping in dicts:
        for key, value in mapping.items():
            if key in collected:
                collected[key].append(value)
            else:
                collected[key] = [value]
    return valmap(func, collected, factory)
70
+
71
+
72
def valmap(func, d, factory=dict):
    """Return a new mapping with ``func`` applied to every value of ``d``.

    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
    >>> valmap(sum, bills)  # doctest: +SKIP
    {'Alice': 65, 'Bob': 45}

    See Also:
        keymap
        itemmap
    """
    mapped = factory()
    for key, value in d.items():
        mapped[key] = func(value)
    return mapped
86
+
87
+
88
def keymap(func, d, factory=dict):
    """Return a new mapping with ``func`` applied to every key of ``d``.

    >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
    >>> keymap(str.lower, bills)  # doctest: +SKIP
    {'alice': [20, 15, 30], 'bob': [10, 35]}

    See Also:
        valmap
        itemmap
    """
    mapped = factory()
    for key, value in d.items():
        mapped[func(key)] = value
    return mapped
102
+
103
+
104
def itemmap(func, d, factory=dict):
    """Return a new mapping with ``func`` applied to every item of ``d``.

    ``func`` receives each ``(key, value)`` pair and must return a new pair.

    >>> accountids = {"Alice": 10, "Bob": 20}
    >>> itemmap(reversed, accountids)  # doctest: +SKIP
    {10: "Alice", 20: "Bob"}

    See Also:
        keymap
        valmap
    """
    mapped = factory()
    for item in d.items():
        new_key, new_value = func(item)
        mapped[new_key] = new_value
    return mapped
118
+
119
+
120
def valfilter(predicate, d, factory=dict):
    """Keep the items of ``d`` whose *value* satisfies ``predicate``.

    >>> iseven = lambda x: x % 2 == 0
    >>> valfilter(iseven, {1: 2, 2: 3, 3: 4, 4: 5})
    {1: 2, 3: 4}

    See Also:
        keyfilter
        itemfilter
        valmap
    """
    kept = factory()
    kept.update((k, v) for k, v in d.items() if predicate(v))
    return kept
138
+
139
+
140
def keyfilter(predicate, d, factory=dict):
    """Keep the items of ``d`` whose *key* satisfies ``predicate``.

    >>> iseven = lambda x: x % 2 == 0
    >>> keyfilter(iseven, {1: 2, 2: 3, 3: 4, 4: 5})
    {2: 3, 4: 5}

    See Also:
        valfilter
        itemfilter
        keymap
    """
    kept = factory()
    kept.update((k, v) for k, v in d.items() if predicate(k))
    return kept
158
+
159
+
160
def itemfilter(predicate, d, factory=dict):
    """Keep the ``(key, value)`` pairs of ``d`` accepted by ``predicate``.

    ``predicate`` receives the whole item tuple.

    >>> def isvalid(item):
    ...     k, v = item
    ...     return k % 2 == 0 and v < 4
    >>> itemfilter(isvalid, {1: 2, 2: 3, 3: 4, 4: 5})
    {2: 3}

    See Also:
        keyfilter
        valfilter
        itemmap
    """
    kept = factory()
    kept.update(pair for pair in d.items() if predicate(pair))
    return kept
182
+
183
+
184
def assoc(d, key, value, factory=dict):
    """Return a copy of ``d`` with ``d[key]`` set to ``value``.

    The input dictionary is not modified.

    >>> assoc({'x': 1}, 'x', 2)
    {'x': 2}
    >>> assoc({'x': 1}, 'y', 3)  # doctest: +SKIP
    {'x': 1, 'y': 3}
    """
    updated = factory()
    updated.update(d)
    updated[key] = value
    return updated
198
+
199
+
200
def dissoc(d, *keys, **kwargs):
    """Return a copy of ``d`` with the given key(s) removed.

    The input dictionary is not modified; missing keys are ignored.

    >>> dissoc({'x': 1, 'y': 2}, 'y')
    {'x': 1}
    >>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
    {}
    >>> dissoc({'x': 1}, 'y')  # Ignores missing keys
    {'x': 1}
    """
    factory = _get_factory(dissoc, kwargs)
    result = factory()

    if len(keys) < len(d) * .6:
        # Few removals: cheaper to copy everything then delete.
        result.update(d)
        for key in keys:
            result.pop(key, None)
    else:
        # Many removals: cheaper to copy only the surviving keys.
        survivors = set(d)
        survivors.difference_update(keys)
        for key in survivors:
            result[key] = d[key]
    return result
227
+
228
+
229
def assoc_in(d, keys, value, factory=dict):
    """Return a copy of ``d`` with the nested location ``keys`` set to ``value``.

    >>> purchase = {'name': 'Alice',
    ...             'order': {'items': ['Apple', 'Orange'],
    ...                       'costs': [0.50, 1.25]},
    ...             'credit card': '5555-1234-1234-1234'}
    >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00])  # doctest: +SKIP
    {'credit card': '5555-1234-1234-1234',
     'name': 'Alice',
     'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
    """
    # Reuse update_in with a constant-valued updater.
    return update_in(d, keys, lambda _current: value, value, factory)
242
+
243
+
244
def update_in(d, keys, func, default=None, factory=dict):
    """ Update value in a (potentially) nested dictionary

    inputs:
    d - dictionary on which to operate
    keys - list or tuple giving the location of the value to be changed in d
    func - function to operate on that value

    If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the
    original dictionary with v replaced by func(v), but does not mutate the
    original dictionary.

    If k0 is not a key in d, update_in creates nested dictionaries to the depth
    specified by the keys, with the innermost value set to func(default).

    >>> inc = lambda x: x + 1
    >>> update_in({'a': 0}, ['a'], inc)
    {'a': 1}

    >>> transaction = {'name': 'Alice',
    ...                'purchase': {'items': ['Apple', 'Orange'],
    ...                             'costs': [0.50, 1.25]},
    ...                'credit card': '5555-1234-1234-1234'}
    >>> update_in(transaction, ['purchase', 'costs'], sum)  # doctest: +SKIP
    {'credit card': '5555-1234-1234-1234',
     'name': 'Alice',
     'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}

    >>> # updating a value when k0 is not in d
    >>> update_in({}, [1, 2, 3], str, default="bar")
    {1: {2: {3: 'bar'}}}
    >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
    {1: 'foo', 2: {3: {4: 1}}}
    """
    ks = iter(keys)
    k = next(ks)

    # rv is the new outermost dict; inner tracks the dict currently being
    # rebuilt as we descend one level per key.
    rv = inner = factory()
    rv.update(d)

    for key in ks:
        if k in d:
            # Copy the existing level so the original is never mutated.
            d = d[k]
            dtemp = factory()
            dtemp.update(d)
        else:
            # Missing level: materialize an empty dict for both the source
            # walk and the copy being built.
            d = dtemp = factory()

        # Splice the fresh copy into the parent, then descend into it.
        inner[k] = inner = dtemp
        k = key

    # k is now the innermost key; apply func to the existing value or to
    # the supplied default.
    if k in d:
        inner[k] = func(d[k])
    else:
        inner[k] = func(default)
    return rv
300
+
301
+
302
def get_in(keys, coll, default=None, no_default=False):
    """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.

    If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
    ``no_default`` is specified, then it raises KeyError or IndexError.

    ``get_in`` is a generalization of ``operator.getitem`` for nested data
    structures such as dictionaries and lists.

    >>> transaction = {'name': 'Alice',
    ...                'purchase': {'items': ['Apple', 'Orange'],
    ...                             'costs': [0.50, 1.25]},
    ...                'credit card': '5555-1234-1234-1234'}
    >>> get_in(['purchase', 'items', 0], transaction)
    'Apple'
    >>> get_in(['name'], transaction)
    'Alice'
    >>> get_in(['purchase', 'total'], transaction)
    >>> get_in(['purchase', 'total'], transaction, 0)
    0
    >>> get_in(['y'], {}, no_default=True)
    Traceback (most recent call last):
    ...
    KeyError: 'y'

    See Also:
        itertoolz.get
        operator.getitem
    """
    try:
        current = coll
        for key in keys:
            current = current[key]
        return current
    except (KeyError, IndexError, TypeError):
        if no_default:
            raise
        return default
339
+
340
+
341
def getter(index):
    """Build an item-accessor for ``index``.

    A list index always yields a tuple-returning accessor (empty tuple for an
    empty list); any other index behaves like ``operator.itemgetter``.
    """
    if not isinstance(index, list):
        return operator.itemgetter(index)
    if len(index) == 1:
        sole = index[0]
        return lambda record: (record[sole],)
    if index:
        return operator.itemgetter(*index)
    return lambda record: ()
352
+
353
+
354
def groupby(key, seq):
    """ Group a collection by a key function

    >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
    >>> groupby(len, names)  # doctest: +SKIP
    {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}

    Non-callable keys imply grouping on a member.

    >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
    ...                    {'name': 'Bob', 'gender': 'M'}])  # doctest: +SKIP
    {'F': [{'gender': 'F', 'name': 'Alice'}],
     'M': [{'gender': 'M', 'name': 'Bob'}]}

    Not to be confused with ``itertools.groupby``

    See Also:
        countby
    """
    if not callable(key):
        # Group on a member/index instead of a computed key.
        key = getter(key)
    buckets = collections.defaultdict(list)  # type: ignore[var-annotated]
    for item in seq:
        buckets[key(item)].append(item)
    return dict(buckets)
388
+
389
+
390
def first(seq):
    """Return the first element of ``seq``.

    Raises StopIteration when ``seq`` is empty.

    >>> first('ABC')
    'A'
    """
    iterator = iter(seq)
    return next(iterator)
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/utils.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ __all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
3
def hashable(x):
    """Return True when ``x`` can be hashed, False otherwise."""
    try:
        hash(x)
    except TypeError:
        return False
    return True
9
+
10
+
11
def transitive_get(key, d):
    """Follow chained lookups in ``d`` until the value is no longer a key.

    >>> d = {1: 2, 2: 3, 3: 4}
    >>> d.get(1)
    2
    >>> transitive_get(1, d)
    4
    """
    current = key
    while hashable(current) and current in d:
        current = d[current]
    return current
22
+
23
+
24
def raises(err, lamda):
    """Return True iff calling ``lamda`` raises ``err``.

    Other exception types propagate to the caller unchanged.
    """
    try:
        lamda()
    except err:
        return True
    return False
30
+
31
+
32
+ # Taken from theano/theano/gof/sched.py
33
+ # Avoids licensing issues because this was written by Matthew Rocklin
34
def _toposort(edges):
    """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
    inputs:
        edges - a dict of the form {a: {b, c}} where b and c depend on a
    outputs:
        L - an ordered list of nodes that satisfy the dependencies of edges
    >>> # xdoctest: +SKIP
    >>> _toposort({1: (2, 3), 2: (3, )})
    [1, 2, 3]
    Closely follows the wikipedia page [2]
    [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
    Communications of the ACM
    [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
    """
    incoming_edges = reverse_dict(edges)
    incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
    # Seed the ready set with nodes that have no incoming edges.  Plain set
    # here (unlike the OrderedDict variant elsewhere), so the emitted order
    # among peers is unspecified.
    S = ({v for v in edges if v not in incoming_edges})
    L = []

    while S:
        n = S.pop()
        L.append(n)
        for m in edges.get(n, ()):
            assert n in incoming_edges[m]
            incoming_edges[m].remove(n)
            if not incoming_edges[m]:
                # All of m's prerequisites are emitted; m becomes ready.
                S.add(m)
    # Any remaining incoming edge means a node was never released -> cycle.
    if any(incoming_edges.get(v, None) for v in edges):
        raise ValueError("Input has cycles")
    return L
64
+
65
+
66
def reverse_dict(d):
    """Reverse the direction of a dependence dict.

    >>> d = {'a': (1, 2), 'b': (2, 3), 'c': ()}
    >>> reverse_dict(d)  # doctest: +SKIP
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}

    :note: dict iteration order influences the output order, so the result
        order should be considered nondeterministic across inputs.
    """
    flipped = {}  # type: ignore[var-annotated]
    for source_key, targets in d.items():
        for target in targets:
            flipped[target] = flipped.get(target, ()) + (source_key,)
    return flipped
81
+
82
+
83
def xfail(func):
    """Assert that ``func`` raises; raise if it unexpectedly succeeds.

    Calls ``func`` and swallows any exception it raises (the expected
    failure).  If ``func`` returns without raising, an exception is raised
    to flag the "expected failure" that actually passed.

    Fix: the previous version raised ``Exception("XFailed test passed")``
    *inside* the ``try`` block, so that exception was immediately swallowed
    by its own ``except Exception: pass`` handler and the check was a no-op.
    Raising after the handler makes the intended signal observable.
    """
    try:
        func()
    except Exception:
        # Expected failure occurred — nothing to report.
        return
    raise Exception("XFailed test passed")  # pragma:nocover # noqa: TRY002
89
+
90
+
91
def freeze(d):
    """ Freeze container to hashable form
    >>> freeze(1)
    1
    >>> freeze([1, 2])
    (1, 2)
    >>> freeze({1: 2})  # doctest: +SKIP
    frozenset([(1, 2)])
    """
    if isinstance(d, dict):
        return frozenset(freeze(item) for item in d.items())
    if isinstance(d, set):
        return frozenset(freeze(member) for member in d)
    if isinstance(d, (tuple, list)):
        return tuple(freeze(member) for member in d)
    # Scalars (and anything already hashable) pass through untouched.
    return d
.venv/lib/python3.11/site-packages/torch/fx/experimental/unification/variable.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mypy: allow-untyped-defs
2
+ from contextlib import contextmanager
3
+ from .utils import hashable
4
+ from .dispatch import dispatch
5
+
6
# Module-level set of tokens currently declared as logic variables by the
# ``variables`` context manager below; consulted by the ``isvar`` overload
# for plain (non-Var) objects.
_global_logic_variables = set()  # type: ignore[var-annotated]
# Short internal alias for the set above.
_glv = _global_logic_variables
8
+
9
+
10
class Var:
    """A logic variable identified by a token.

    Constructed with no arguments, a fresh auto-numbered token ("_1",
    "_2", ...) is drawn from a class-level counter. With exactly one
    argument, that argument becomes the token. With several arguments,
    the whole argument tuple is used as the token.

    Equality and hashing are based on the pair (exact type, token).
    """

    # Class-level counter used to name anonymous variables.
    _id = 1

    def __new__(cls, *token):
        if not token:
            token = f"_{Var._id}"  # type: ignore[assignment]
            Var._id += 1
        elif len(token) == 1:
            (token,) = token

        instance = object.__new__(cls)
        instance.token = token  # type: ignore[attr-defined]
        return instance

    def __str__(self):
        return f"~{self.token}"  # type: ignore[attr-defined]

    __repr__ = __str__

    def __eq__(self, other):
        if type(self) != type(other):
            return False
        return self.token == other.token  # type: ignore[attr-defined]

    def __hash__(self):
        return hash((type(self), self.token))  # type: ignore[attr-defined]
35
+
36
+
37
def var():
    """Return a factory callable that builds :class:`Var` instances."""
    def _factory(*tokens):
        return Var(*tokens)
    return _factory
39
+
40
+
41
def vars():
    """Return a factory mapping ``n`` to a list of ``n`` results of ``var()``."""
    return lambda n: [var() for _ in range(n)]
43
+
44
+
45
@dispatch(Var)
def isvar(v):
    """Dispatch overload: any ``Var`` instance is a logic variable."""
    return True

# NOTE(review): bare expression statement with no runtime effect — it appears
# intended to reference the name so static tools do not flag the redefinition
# by the ``object`` overload below; TODO confirm original intent.
isvar
50
+
51
+
52
@dispatch(object)  # type: ignore[no-redef]
def isvar(o):
    """Dispatch overload: a plain object is a logic variable only while it
    has been declared via the ``variables`` context manager (i.e. it is a
    hashable member of the global logic-variable set)."""
    if not _glv:
        return False
    return hashable(o) and o in _glv
55
+
56
+
57
@contextmanager
def variables(*variables):
    """
    Context manager that temporarily declares tokens as logic variables.

    While the block is active, the given tokens are added to the
    module-level set of global logic variables consulted by ``isvar``;
    on exit the previous contents of that set are restored, even if the
    body raises.

    Example:
        >>> # xdoctest: +SKIP("undefined vars")
        >>> from __future__ import with_statement
        >>> with variables(1):
        ...     print(isvar(1))
        True
        >>> print(isvar(1))
        False
        >>> # Normal approach
        >>> from unification import unify
        >>> x = var('x')
        >>> unify(x, 1)
        {~x: 1}
        >>> # Context Manager approach
        >>> with variables('x'):
        ...     print(unify('x', 1))
        {'x': 1}
    """
    saved_variables = _global_logic_variables.copy()
    _global_logic_variables.update(set(variables))
    try:
        yield
    finally:
        # Restore the exact pre-entry contents of the global set.
        _global_logic_variables.clear()
        _global_logic_variables.update(saved_variables)
.venv/lib/python3.11/site-packages/torch/fx/passes/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import graph_drawer
2
+ from . import graph_manipulation
3
+ from . import net_min_base
4
+ from . import operator_support
5
+ from . import param_fetch
6
+ from . import reinplace
7
+ from . import runtime_assert
8
+ from . import shape_prop
9
+ from . import split_module
10
+ from . import split_utils
11
+ from . import splitter_base
12
+ from . import tools_common
.venv/lib/python3.11/site-packages/torch/fx/passes/annotate_getitem_nodes.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+
3
+ import torch
4
+
5
+
6
def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
    """
    Annotate the type of getitem nodes, inferred from the type of sequence node.
    If sequence node is not annotated with a type, do nothing.
    Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.

    This is helpful since annotations on local names within function are lost during FX transforms.
    Adding back known type annotation for getitem nodes to improve jit scriptability.

    Args:
        graph (Graph): The graph to be annotated
    """
    for node in graph.nodes:
        if node.target != operator.getitem:
            continue
        sequence_node, index_node = node.args
        seq_type = sequence_node.type
        if not seq_type:
            # Sequence node carries no annotation; nothing to infer from.
            continue
        if hasattr(seq_type, "_name"):
            # Parameterized typing containers (typing.Tuple / typing.List).
            element_types = seq_type.__args__
            if seq_type._name == "Tuple":
                homogeneous = len(element_types) == 2 and isinstance(
                    element_types[1], type(...)
                )
                if homogeneous:
                    # Tuple[T, ...]: every element has type T.
                    node.type = element_types[0]
                else:
                    assert len(element_types) > index_node
                    node.type = element_types[index_node]
            elif seq_type._name == "List":
                assert len(element_types) == 1
                node.type = element_types[0]
        elif hasattr(seq_type, "__annotations__"):
            # NamedTuple-style classes: resolve the field type by position.
            if seq_type == torch.Tensor:
                continue
            field_name = seq_type._fields[index_node]
            node.type = seq_type.__annotations__[field_name]
.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__init__.py ADDED
File without changes
.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
.venv/lib/python3.11/site-packages/torch/fx/passes/dialect/common/__init__.py ADDED
File without changes