xiaoanyu123 commited on
Commit
b5fdc16
·
verified ·
1 Parent(s): f134ab5

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. pythonProject/.venv/Lib/site-packages/onnxscript/_framework_apis/torch_2_5.py +117 -0
  2. pythonProject/.venv/Lib/site-packages/onnxscript/onnx_opset/_impl/__pycache__/opset10.cpython-310.pyc +0 -0
  3. pythonProject/.venv/Lib/site-packages/onnxscript/onnx_opset/_impl/__pycache__/opset12.cpython-310.pyc +0 -0
  4. pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/basic_rules.py +321 -0
  5. pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/broadcast_to_matmul.py +178 -0
  6. pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/cast_constant_of_shape.py +46 -0
  7. pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/collapse_slices.py +107 -0
  8. pythonProject/.venv/Lib/site-packages/onnxscript/utils/__init__.py +0 -0
  9. pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  10. pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/evaluation_utils.cpython-310.pyc +0 -0
  11. pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/timing_utils.cpython-310.pyc +0 -0
  12. pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/utils.cpython-310.pyc +0 -0
  13. pythonProject/.venv/Lib/site-packages/onnxscript/utils/timing_utils.py +33 -0
  14. pythonProject/.venv/Lib/site-packages/onnxscript/utils/utils.py +84 -0
  15. pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__init__.py +179 -0
  16. pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/__init__.cpython-310.pyc +0 -0
  17. pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/_c_api_utils.cpython-310.pyc +0 -0
  18. pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/_version_converter.cpython-310.pyc +0 -0
  19. pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/_c_api_utils.py +77 -0
  20. pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/_version_converter.py +339 -0
pythonProject/.venv/Lib/site-packages/onnxscript/_framework_apis/torch_2_5.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Stable APIs for PyTorch 2.5."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "check_model",
9
+ "convert_version",
10
+ "get_torchlib_ops",
11
+ "optimize",
12
+ "save_model_with_external_data",
13
+ ]
14
+
15
+ import dataclasses
16
+ import os
17
+ import pathlib
18
+ from typing import Callable
19
+
20
+ from onnxscript import ir, optimizer, version_converter
21
+ from onnxscript.function_libs.torch_lib import registration
22
+
23
+
24
+ @dataclasses.dataclass(frozen=True)
25
+ class _OnnxFunctionMeta:
26
+ """A wrapper of onnx-script function with additional metadata.
27
+
28
+ qualified_name: The qualified name of the aten operator.
29
+ function: The onnx-script function.
30
+ domain: The domain of the function.
31
+ name: The name of the function.
32
+ is_complex: Whether the function is a complex function.
33
+ """
34
+
35
+ qualified_name: str
36
+ function: Callable
37
+ domain: str
38
+ name: str
39
+ is_complex: bool = False
40
+
41
+
42
+ def optimize(model: ir.Model) -> ir.Model:
43
+ """Optimize the model."""
44
+ # Internal flag. Will go away.
45
+ enabled = os.getenv("TORCH_ONNX_ENABLE_OPTIMIZATION") == "1"
46
+ if enabled:
47
+ optimizer.optimize_ir(model)
48
+ return model
49
+
50
+
51
+ def convert_version(model: ir.Model, target_version: int) -> ir.Model:
52
+ """Convert the model to the specified ONNX opset version."""
53
+ # Internal flag. Will go away.
54
+ enabled = os.getenv("TORCH_ONNX_ENABLE_VERSION_CONVERSION") == "1"
55
+ if enabled:
56
+ version_converter.convert_version(model, target_version)
57
+ return model
58
+
59
+
60
+ def check_model(model: ir.Model) -> None:
61
+ """Check the model."""
62
+
63
+ del model # Unused yet
64
+
65
+
66
+ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None:
67
+ """Save the model with external data. The model is unchanged after saving."""
68
+
69
+ # TODO(#1835): Decide if we want to externalize large attributes as well
70
+ for value in model.graph.initializers.values():
71
+ if value.const_value is None:
72
+ raise ValueError(
73
+ "The model contains uninitialized initializer values. "
74
+ "Please make sure all initializer values are initialized."
75
+ )
76
+ destination_path = pathlib.Path(model_path)
77
+ data_path = f"{destination_path.name}.data"
78
+
79
+ ir.save(model, model_path, external_data=data_path)
80
+
81
+
82
+ def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
83
+ # Trigger op registration
84
+ from onnxscript.function_libs.torch_lib import ( # pylint: disable=import-outside-toplevel
85
+ ops,
86
+ )
87
+
88
+ del ops # Unused
89
+
90
+ torchlib_registry = registration.default_registry
91
+ function_metas = []
92
+
93
+ for qualified_name, aten_overloads_func in torchlib_registry.items():
94
+ if qualified_name.startswith("internal::"):
95
+ # Skip the custom defined internal functions
96
+ continue
97
+
98
+ for overload_func in aten_overloads_func.overloads:
99
+ function_meta = _OnnxFunctionMeta(
100
+ qualified_name=qualified_name,
101
+ function=overload_func,
102
+ domain=overload_func.function_ir.domain,
103
+ name=overload_func.name,
104
+ is_complex=False,
105
+ )
106
+ function_metas.append(function_meta)
107
+ for complex_func in aten_overloads_func.complex:
108
+ function_meta = _OnnxFunctionMeta(
109
+ qualified_name=qualified_name,
110
+ function=complex_func,
111
+ domain=complex_func.function_ir.domain,
112
+ name=complex_func.name,
113
+ is_complex=True,
114
+ )
115
+ function_metas.append(function_meta)
116
+
117
+ return function_metas
pythonProject/.venv/Lib/site-packages/onnxscript/onnx_opset/_impl/__pycache__/opset10.cpython-310.pyc ADDED
Binary file (49.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnxscript/onnx_opset/_impl/__pycache__/opset12.cpython-310.pyc ADDED
Binary file (40 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/basic_rules.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Basic rewrite rules for general optimization patterns.
4
+
5
+ This module contains fundamental optimization rules that are generally applicable
6
+ to most ONNX models, including cast elimination, transpose simplification,
7
+ shape operation fusion, and other common patterns.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import ClassVar, Sequence
13
+
14
+ from onnxscript import ir
15
+ from onnxscript.rewriter import _ir_utils as ir_utils
16
+ from onnxscript.rewriter._basics import MatchResult
17
+ from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
18
+
19
+
20
+ class SqueezeReshape(RewriteRuleClassBase):
21
+ """Replaces ``Reshape(Squeeze(x), [-1]])`` with ``Identity(x)`` for 1D x.
22
+
23
+ This pattern arises from the translation of pytorch symints.
24
+ """
25
+
26
+ def __init__(self):
27
+ super().__init__("SqueezeReshape1d", remove_nodes=False)
28
+
29
+ def pattern(self, op, x):
30
+ return op.Reshape(op.Squeeze(x), [-1])
31
+
32
+ def rewrite(self, op, x: ir.Value):
33
+ return op.Identity(x)
34
+
35
+ def check(self, context, x) -> MatchResult:
36
+ del context # Unused
37
+ check_result = MatchResult()
38
+ if not ir_utils.has_rank(x, 1):
39
+ return check_result.fail("Input is not 1D")
40
+ return check_result
41
+
42
+
43
+ class CastIdentity(RewriteRuleClassBase):
44
+ """Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
45
+
46
+ def pattern(self, op, x, to):
47
+ return op.Cast(x, to=to)
48
+
49
+ def rewrite(self, op, x: ir.Value, to: ir.Attr):
50
+ return op.Identity(x)
51
+
52
+ def check(self, context, x, to) -> MatchResult:
53
+ check_result = MatchResult()
54
+ if x.dtype != to.as_int():
55
+ return check_result.fail("Input and output types are not the same")
56
+ return check_result
57
+
58
+
59
+ class CastCast(RewriteRuleClassBase):
60
+ """Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""
61
+
62
+ # Simplify "cast type1 => type2 => type3" to "cast type1 => type3".
63
+ # This rule is not valid for all combinations of types: e.g.,
64
+ # it is not valid for float32 => float16 => float32 or float32 => int32 => string.
65
+ # TODO: fill out the list of allowed combinations: the following is just a couple
66
+ # that shows up in practice where it is valid
67
+ _allowed_type2_type3: ClassVar = frozenset(
68
+ {
69
+ (ir.DataType.FLOAT, ir.DataType.FLOAT16),
70
+ (ir.DataType.FLOAT, ir.DataType.BFLOAT16),
71
+ }
72
+ )
73
+
74
+ def pattern(self, op, x, to, to_ignored):
75
+ return op.Cast(op.Cast(x, to=to_ignored), to=to)
76
+
77
+ def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult:
78
+ check_result = MatchResult()
79
+ type2 = to_ignored.as_int()
80
+ type3 = to.as_int()
81
+ if (type2, type3) not in self._allowed_type2_type3:
82
+ return check_result.fail(
83
+ f"Intermediate cast elimination not recognized as valid from {type2} to {type3}. "
84
+ f"Cast-Cast rule may be incomplete for this combination."
85
+ )
86
+ return check_result
87
+
88
+ def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
89
+ return op.Cast(x, to=to)
90
+
91
+
92
+ class ExpandIdentity(RewriteRuleClassBase):
93
+ """Replaces ``Expand(..., shape)`` by ``Identity`` if possible."""
94
+
95
+ def pattern(self, op, x, shape):
96
+ return op.Expand(x, shape)
97
+
98
+ def rewrite(self, op, x: ir.Value, shape: ir.Value):
99
+ return op.Identity(x)
100
+
101
+ def check(self, context, x, shape) -> MatchResult:
102
+ check_result = MatchResult()
103
+ if shape.const_value is None:
104
+ # Shape is not a constant and cannot be guessed.
105
+ return check_result.fail("Shape is not a constant and cannot be guessed.")
106
+ if (x_shape := x.shape) is None:
107
+ # We don't know the shape of the input
108
+ return check_result.fail("Input shape is not known.")
109
+ if x_shape.dims != tuple(shape.const_value.numpy().tolist()):
110
+ return check_result.fail(
111
+ f"Input shape {x_shape.dims} does not match the shape {shape.const_value.numpy().tolist()}."
112
+ )
113
+ return check_result
114
+
115
+
116
+ class ReshapeReshape(RewriteRuleClassBase):
117
+ """Replaces ``Reshape(Reshape(X, ...), shape)`` by ``Reshape(X, shape)``.
118
+ The pattern matches only if second reshape reshapes into a shape
119
+ with positive values.
120
+ """
121
+
122
+ def pattern(self, op, x, shape_ignored, shape):
123
+ return op.Reshape(op.Reshape(x, shape_ignored), shape)
124
+
125
+ def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
126
+ return op.Reshape(x, shape)
127
+
128
+ def check(self, context, x, shape_ignored, shape) -> MatchResult:
129
+ check_result = MatchResult()
130
+ if shape_ignored.const_value is None:
131
+ return check_result.fail("Shape ignored is not a constant.")
132
+ if shape.const_value is None:
133
+ return check_result.fail("Shape is not a constant.")
134
+ if shape.const_value.numpy().min() <= 0:
135
+ return check_result.fail("Shape has non-positive values.")
136
+ return check_result
137
+
138
+
139
+ class SlicesSplit(RewriteRuleClassBase):
140
+ """Replaces ``Slice(x, ...), Slice(x, ...)``
141
+ by ``Split(x, ...)`` if possible.
142
+ """
143
+
144
+ def pattern(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
145
+ return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1)
146
+
147
+ def check(self, context, x, begin0, end0, axes0, begin1, end1, axes1) -> MatchResult:
148
+ check_result = MatchResult()
149
+ if (
150
+ axes0.const_value is None
151
+ or axes1.const_value is None
152
+ or axes0.const_value.numpy().tolist() != axes1.const_value.numpy().tolist()
153
+ ):
154
+ return check_result.fail("Axes are not equal or not constant.")
155
+ axes = axes0.const_value.numpy().tolist()
156
+ if len(axes) != 1:
157
+ return check_result.fail("Axes has more than one dimension.")
158
+ if x.shape:
159
+ rk = len(x.shape)
160
+ else:
161
+ rk = x.rank
162
+ if axes[0] != -1 and axes[0] != rk - 1:
163
+ return check_result.fail("Axes is not -1 or last dimension.")
164
+ if (
165
+ begin0.const_value is None
166
+ or end0.const_value is None
167
+ or begin1.const_value is None
168
+ or end1.const_value is None
169
+ ):
170
+ return check_result.fail("Begin or end are not constant values.")
171
+ if begin0.const_value.numpy().tolist() != [0]:
172
+ return check_result.fail("First begin value is not 0.")
173
+ e0, b1, e1 = (
174
+ end0.const_value.numpy().tolist(),
175
+ begin1.const_value.numpy().tolist(),
176
+ end1.const_value.numpy().tolist(),
177
+ )
178
+ if e0[0] != b1[0]:
179
+ return check_result.fail("End0 is not equal to Begin1.")
180
+ shape = x.shape
181
+ if shape is None:
182
+ return check_result.fail("Shape is not known.")
183
+ last_dim = shape[-1]
184
+ if not isinstance(last_dim, int):
185
+ return check_result.fail("Last dimension is not known.")
186
+ if last_dim != e1[0]:
187
+ return check_result.fail("Last dimension is not equal to End1.")
188
+ if last_dim // 2 != b1[0]:
189
+ return check_result.fail("Last dimension is not equal to Begin1.")
190
+ return check_result
191
+
192
+ def rewrite(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
193
+ return op.Split(x, num_outputs=2, axis=-1, _outputs=2)
194
+
195
+
196
+ class TransposeIdentity(RewriteRuleClassBase):
197
+ """Replaces ``Transpose(. perm=perm)``
198
+ when the permutation is identity.
199
+ """
200
+
201
+ def pattern(self, op, x, perm):
202
+ return op.Transpose(x, perm=perm)
203
+
204
+ def check(self, context, x: ir.Value, perm: ir.Attr) -> MatchResult:
205
+ check_result = MatchResult()
206
+ if perm.is_ref():
207
+ return check_result.fail("Permutation is a reference attribute.")
208
+ if perm.type == ir.AttributeType.INTS:
209
+ perm_ints = tuple(perm.as_ints())
210
+ if perm_ints == tuple(range(len(perm_ints))):
211
+ return check_result
212
+ return check_result.fail("Permutation is not identity.")
213
+
214
+ def rewrite(self, op, x: ir.Value, perm: ir.Attr):
215
+ return op.Identity(x)
216
+
217
+
218
+ class TransposeTranspose(RewriteRuleClassBase):
219
+ """Replaces ``Transpose(Transpose(., perm=perm1), perm=perm2)``
220
+ when both permutations are inverse.
221
+ """
222
+
223
+ def pattern(self, op, x, perm1, perm2):
224
+ return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)
225
+
226
+ def check(self, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> MatchResult:
227
+ check_result = MatchResult()
228
+ if perm1.is_ref() or perm2.is_ref():
229
+ return check_result.fail("Permutation is a reference attribute.")
230
+ return check_result
231
+
232
+ def _apply_transpose(self, perm: Sequence[int], on: list[int]) -> list[int]:
233
+ assert len(perm) == len(on), "length mismatch"
234
+ res = [-1 for i in on]
235
+ for i, p in enumerate(perm):
236
+ res[i] = on[p]
237
+ return res
238
+
239
+ def _apply_transposes(
240
+ self, perms: list[Sequence[int]], on: list[int] | None = None
241
+ ) -> list[int]:
242
+ if on is None:
243
+ on = list(range(len(perms[0])))
244
+ for p in perms:
245
+ on = self._apply_transpose(p, on)
246
+ return on
247
+
248
+ def rewrite(self, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
249
+ first = list(range(len(perm1.as_ints())))
250
+ last = self._apply_transposes([perm1.as_ints(), perm2.as_ints()])
251
+ if first == last:
252
+ return op.Identity(x)
253
+ return op.Transpose(x, perm=last)
254
+
255
+
256
+ class UnsqueezeUnsqueeze(RewriteRuleClassBase):
257
+ """Replaces ``Unsqueeze(Unsqueeze(., axes1), axes2)`` with one Unsqueeze."""
258
+
259
+ def pattern(self, op, x, axes1, axes2):
260
+ return op.Unsqueeze(op.Unsqueeze(x, axes1), axes2)
261
+
262
+ def rewrite(self, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
263
+ v1 = ir_utils.get_singleton_value(axes1)
264
+ v2 = ir_utils.get_singleton_value(axes2)
265
+ axes = [v1, v2] if v1 < v2 else [v2, v1 + 1]
266
+ return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64)))
267
+
268
+ def check(self, context, x, axes1, axes2) -> MatchResult:
269
+ check_result = MatchResult()
270
+ del context # Unused
271
+ del x # Unused
272
+ # Currently restricted to single element positive axis
273
+ v1 = ir_utils.get_singleton_value(axes1)
274
+ v2 = ir_utils.get_singleton_value(axes2)
275
+ if v1 is None or v2 is None:
276
+ return check_result.fail("Axes are not constant.")
277
+ if (v1 < 0) or (v2 < 0):
278
+ return check_result.fail("Axes are negative.")
279
+ return check_result
280
+
281
+
282
+ # Create rule instances
283
+ cast_cast_rule = CastCast.rule()
284
+ cast_identity_rule = CastIdentity.rule()
285
+ expand_identity_rule = ExpandIdentity.rule()
286
+ reshape_reshape_rule = ReshapeReshape.rule()
287
+ slice_split_rule = SlicesSplit.rule()
288
+ transpose_identity_rule = TransposeIdentity.rule()
289
+ transpose_transpose_rule = TransposeTranspose.rule()
290
+ unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()
291
+ squeeze_reshape_1d_rule = SqueezeReshape.rule()
292
+
293
+
294
+ def basic_optimization_rules() -> RewriteRuleSet:
295
+ """Returns a set of basic optimization rules.
296
+
297
+ These rules perform fundamental optimizations such as:
298
+ - Eliminating redundant cast operations
299
+ - Simplifying consecutive operations of the same type
300
+ - Removing identity operations
301
+ - Optimizing shape manipulation operations
302
+
303
+ These rules are generally safe to apply as a first optimization pass
304
+ before other more specialized optimizations.
305
+
306
+ Returns:
307
+ RewriteRuleSet: A collection of basic optimization rules
308
+ """
309
+ return RewriteRuleSet(
310
+ [
311
+ cast_cast_rule,
312
+ cast_identity_rule,
313
+ expand_identity_rule,
314
+ reshape_reshape_rule,
315
+ slice_split_rule,
316
+ transpose_identity_rule,
317
+ transpose_transpose_rule,
318
+ unsqueeze_unsqueeze_rule,
319
+ squeeze_reshape_1d_rule,
320
+ ]
321
+ )
pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/broadcast_to_matmul.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+
7
+ from onnxscript import ir
8
+ from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def check_if_not_need_reshape(
14
+ context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
15
+ ) -> bool:
16
+ """Condition to check if we need to replace the pattern.
17
+
18
+ If matmul broadcasting is enough, then we don't need the reshapes.
19
+
20
+ To validate this, we need to check the following:
21
+ 1. Input shapes check: input_a and input_b should be broadcastable
22
+ 2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)
23
+
24
+ If the above are true, then we don't need the reshapes.
25
+
26
+ Returns:
27
+ True if we need to replace the pattern, False otherwise.
28
+ """
29
+ del context # Reserved for future extensions
30
+
31
+ input_a_shape = input_a.shape
32
+ input_b_shape = input_b.shape
33
+ shape_c_tensor = shape_c.const_value
34
+ if shape_c_tensor is None:
35
+ logger.info("The value 'shape_c' is not statically known.")
36
+ return False
37
+
38
+ if len(shape_c_tensor.shape) != 1:
39
+ logger.info(
40
+ "Unexpected final shape. The shape of 'shape' value is %s",
41
+ shape_c_tensor.shape,
42
+ )
43
+ return False
44
+
45
+ # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape
46
+ # information. So, we need to check if the shape is None and return False.
47
+ if input_a_shape is None or input_b_shape is None:
48
+ logger.info("Shape information is not available for the inputs and outputs.")
49
+ return False
50
+ if any(isinstance(dim, ir.SymbolicDim) for dim in input_a_shape):
51
+ logger.info("Symbolic dimensions are not yet supported.")
52
+ return False
53
+ if any(isinstance(dim, ir.SymbolicDim) for dim in input_b_shape):
54
+ logger.info("Symbolic dimensions are not yet supported.")
55
+ return False
56
+ input_a_shape = input_a_shape.numpy() # type: ignore[assignment]
57
+ input_b_shape = input_b_shape.numpy() # type: ignore[assignment]
58
+ shape_c = shape_c_tensor.numpy().tolist() # type: ignore[assignment]
59
+
60
+ a_rank = len(input_a_shape)
61
+ b_rank = len(input_b_shape)
62
+
63
+ # 1. Check if input shapes are broadcastable
64
+ # 1.a. If the first input is 1-D, check whether
65
+ # the dim matches the last second dim of the second input.
66
+ mimic_matmul_broadcast_behavior_a = False
67
+ mimic_matmul_broadcast_behavior_b = False
68
+ if a_rank < 2:
69
+ if b_rank < 2:
70
+ logger.info("Optimization of dot product is not supported yet.")
71
+ return False
72
+ if input_a_shape[-1] != input_b_shape[-2]:
73
+ logger.info("Original shape is not MatMul compatible.")
74
+ return False
75
+ else:
76
+ input_a_shape = [1, *input_a_shape] # type: ignore[assignment]
77
+ a_rank = len(input_a_shape)
78
+ mimic_matmul_broadcast_behavior_a = True
79
+ # 1.b. If the second input is 1-D, check whether
80
+ # the dim matches the last dim of the first input.
81
+ if b_rank < 2:
82
+ if input_b_shape[-1] != input_a_shape[-1]:
83
+ logger.info("Original shape is not MatMul compatible.")
84
+ return False
85
+ else:
86
+ input_b_shape = [*input_b_shape, 1] # type: ignore[assignment]
87
+ b_rank = len(input_b_shape)
88
+ mimic_matmul_broadcast_behavior_b = True
89
+ # 1.c. If both inputs are at least 2-D, check whether
90
+ # the last dimension of the first input matches the second
91
+ # last dimension of the second input, and shape[:-2] are
92
+ # broadcastable.
93
+ input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]]
94
+ input_b_shape_except_last_dim = input_b_shape[:-1]
95
+ broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]
96
+ for idx, (dim_from_a, dim_from_b) in enumerate(
97
+ zip(
98
+ reversed(input_a_shape_except_second_last_dim),
99
+ reversed(input_b_shape_except_last_dim),
100
+ )
101
+ ):
102
+ if dim_from_a not in {1, dim_from_b}:
103
+ logger.info("Original shape is not broadcastable.")
104
+ return False
105
+ elif idx > 0:
106
+ broadcast_matmul_output_shape = [
107
+ max(dim_from_a, dim_from_b), # type: ignore[type-var]
108
+ *broadcast_matmul_output_shape,
109
+ ]
110
+
111
+ # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)
112
+ # Prepend the broadcast_matmul_output_shape with the longer shape of input
113
+ if a_rank > b_rank:
114
+ longer_shape = input_a_shape
115
+ shorter_shape = input_b_shape
116
+ else:
117
+ longer_shape = input_b_shape
118
+ shorter_shape = input_a_shape
119
+ broadcast_matmul_output_shape = [
120
+ *longer_shape[: -len(shorter_shape)],
121
+ *broadcast_matmul_output_shape,
122
+ ]
123
+ if mimic_matmul_broadcast_behavior_b and b_rank == 2 and input_b_shape[-1] == 1:
124
+ # If input_b is expanded to 2-D, then we need to remove the last dimension
125
+ broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
126
+ if mimic_matmul_broadcast_behavior_a and a_rank == 2 and input_a_shape[0] == 1:
127
+ # If input_a is expanded to 2-D, then we need to remove the first dimension
128
+ # of input_a, which would be the -2nd dimension of the output shape.
129
+ broadcast_matmul_output_shape.pop(-2)
130
+ if shape_c != broadcast_matmul_output_shape:
131
+ logger.info(
132
+ "Final output shape is not the same. Expected %s vs actual %s",
133
+ shape_c,
134
+ broadcast_matmul_output_shape,
135
+ )
136
+ return False
137
+
138
+ return True
139
+
140
+
141
+ def _two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c):
142
+ # TODO: Modified from `value_ints` to `value` to match pattern in benchmark models.
143
+ # This implementation misses pattern of Constants with `value_ints` attribute.
144
+ # See more at https://github.com/microsoft/onnx-rewriter/issues/191.
145
+ # A better solution is to improve pattern matching and avoid depending on writing
146
+ # Constants in pattern. See https://github.com/microsoft/onnx-rewriter/issues/192.
147
+ reshape_a = op.Reshape(input_a, shape_a)
148
+ reshape_b = op.Reshape(input_b, shape_b)
149
+ matmul = op.MatMul(reshape_a, reshape_b)
150
+ return op.Reshape(matmul, shape_c)
151
+
152
+
153
+ def _matmul(op, input_a, input_b, **_):
154
+ return op.MatMul(input_a, input_b)
155
+
156
+
157
+ def _one_reshape_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_c):
158
+ reshape_a = op.Reshape(input_a, shape_a)
159
+ matmul = op.MatMul(reshape_a, input_b)
160
+ return op.Reshape(matmul, shape_c)
161
+
162
+
163
+ # Register the rewrite rules
164
+ two_reshapes_matmul_reshape_rule = RewriteRule(
165
+ _two_reshapes_matmul_reshape_pattern,
166
+ _matmul,
167
+ check_if_not_need_reshape,
168
+ )
169
+ one_reshape_matmul_reshape_rule = RewriteRule(
170
+ _one_reshape_matmul_reshape_pattern,
171
+ _matmul,
172
+ # We can use the same check_if_not_need_reshape function for both the rules,
173
+ # as one_reshape_matmul_reshape_pattern is a subset of _two_reshapes_matmul_reshape_pattern.
174
+ check_if_not_need_reshape,
175
+ )
176
+
177
+ # NOTE: The order of the rules is important. Larger pattern should be checked first.
178
+ rules = RewriteRuleSet([two_reshapes_matmul_reshape_rule, one_reshape_matmul_reshape_rule])
pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/cast_constant_of_shape.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+
7
+ from onnxscript import ir
8
+ from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def cast_constant_of_shape(op, shape, scalar, dtype):
14
+ constant = op.ConstantOfShape(shape, value=scalar)
15
+ return op.Cast(constant, to=dtype)
16
+
17
+
18
+ def fused_cast_constant_of_shape(op, shape: ir.Value, scalar: ir.Attr, dtype: ir.Attr, **_):
19
+ # Cast scalar (a TensorProto attribute) to the specified dtype
20
+ scalar_value = scalar.value.numpy().item()
21
+ cast_value = ir.tensor([scalar_value], dtype=ir.DataType(dtype.as_int()))
22
+ return op.ConstantOfShape(shape, value=cast_value)
23
+
24
+
25
+ def cast_constant_of_shape_without_value(op, shape, dtype):
26
+ constant = op.ConstantOfShape(shape)
27
+ return op.Cast(constant, to=dtype)
28
+
29
+
30
+ def fused_cast_constant_of_shape_without_value(op, shape, dtype, **_):
31
+ zero = ir.tensor([0], dtype=ir.DataType(dtype.as_int()))
32
+ return op.ConstantOfShape(shape, value=zero)
33
+
34
+
35
+ cast_constant_of_shape_rule = RewriteRule(cast_constant_of_shape, fused_cast_constant_of_shape)
36
+
37
+ cast_constant_of_shape_without_value_rule = RewriteRule(
38
+ cast_constant_of_shape_without_value, fused_cast_constant_of_shape_without_value
39
+ )
40
+
41
+ rules = RewriteRuleSet(
42
+ [
43
+ cast_constant_of_shape_rule,
44
+ cast_constant_of_shape_without_value_rule,
45
+ ]
46
+ )
pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/collapse_slices.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+
7
+ from onnxscript import ir
8
+ from onnxscript.rewriter._ir_utils import is_singleton_value
9
+ from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet
10
+
11
+ logger = logging.getLogger(__name__)
12
+ _INT64_MAX = 9223372036854775807
13
+
14
+
15
+ def _check_if_redundant_slice(
16
+ context,
17
+ data: ir.Value,
18
+ starts: ir.Value,
19
+ ends: ir.Value,
20
+ axes: ir.Value,
21
+ steps: ir.Value,
22
+ **_,
23
+ ) -> bool:
24
+ """If the starts is 0, and the ends is equal to or grater than the shape of the specified axis, then the slice is redundant."""
25
+ del context # Reserved for future extensions
26
+
27
+ starts_const = starts.const_value
28
+ ends_const = ends.const_value
29
+ axes_const = axes.const_value
30
+ steps_const = steps.const_value
31
+
32
+ if starts_const is None or ends_const is None or axes_const is None or steps_const is None:
33
+ logger.info("The value 'start', 'end', 'axis', 'step' is not statically known.")
34
+ return False
35
+
36
+ # Check if the values are scalar
37
+ if starts_const.numpy().size != 1: # type: ignore[union-attr]
38
+ logger.info("The value 'start' is not a scalar.")
39
+ return False
40
+ if ends_const.numpy().size != 1: # type: ignore[union-attr]
41
+ logger.info("The value 'end' is not a scalar.")
42
+ return False
43
+ if axes_const.numpy().size != 1: # type: ignore[union-attr]
44
+ logger.info("The value 'axis' is not a scalar.")
45
+ return False
46
+ if steps_const.numpy().size != 1: # type: ignore[union-attr]
47
+ logger.info("The value 'step' is not a scalar.")
48
+ return False
49
+
50
+ if steps_const.numpy().item() != 1:
51
+ logger.info("The value 'step' is not 1.")
52
+ return False
53
+ # starts is 0
54
+ if starts_const.numpy().item() != 0:
55
+ logger.info("The value 'start' is not 0.")
56
+ return False
57
+ # In case data.shape is not statically known, we still can tell the slice is redundant if ends is sys.maxsize
58
+ if ends_const.numpy().item() == _INT64_MAX:
59
+ return True
60
+ if data.shape is None or data.shape.is_dynamic(axes_const.numpy().item()):
61
+ logger.info("The value 'data' shape is not statically known.")
62
+ return False
63
+ if ends_const.numpy().item() < data.shape[axes_const.numpy().item()]:
64
+ logger.info("The value 'end' is less than the shape of the specified axis.")
65
+ return False
66
+
67
+ return True
68
+
69
+
70
+ def _identity_to_itself(op, data, **_):
71
+ """Return the input data as the output."""
72
+ return op.Identity(data)
73
+
74
+
75
+ def _potential_redundant_slice(op, data, starts, ends, axes, steps):
76
+ """To identify a slice op"""
77
+ return op.Slice(data, starts, ends, axes, steps, _outputs=["slice_output"])
78
+
79
+
80
+ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_):
81
+ """Check if the shape of the slice output is the same as the data."""
82
+ if data.shape is None or slice_output.shape is None:
83
+ return False
84
+
85
+ if not is_singleton_value(steps, 1):
86
+ return False
87
+
88
+ return data.shape == slice_output.shape
89
+
90
+
91
+ # Register the rewrite rules
92
+ remove_redundant_slice = RewriteRule(
93
+ _potential_redundant_slice,
94
+ _identity_to_itself,
95
+ _check_if_redundant_slice,
96
+ )
97
+
98
+ remove_redundant_slice2 = RewriteRule(
99
+ _potential_redundant_slice,
100
+ _identity_to_itself,
101
+ _same_shape,
102
+ )
103
+
104
+ # NOTE: The second rule subsumes the first one. So, we may be able to remove the first one,
105
+ # provided shape-inference is run before the rewriter and computes the shape of the slice output.
106
+
107
+ rules = RewriteRuleSet([remove_redundant_slice, remove_redundant_slice2])
pythonProject/.venv/Lib/site-packages/onnxscript/utils/__init__.py ADDED
File without changes
pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (173 Bytes). View file
 
pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/evaluation_utils.cpython-310.pyc ADDED
Binary file (2.51 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/timing_utils.cpython-310.pyc ADDED
Binary file (853 Bytes). View file
 
pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.86 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnxscript/utils/timing_utils.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ import time
4
+
5
+ import onnx
6
+
7
+ from onnxscript import optimizer
8
+
9
+
10
+ def timeit(f, message):
11
+ def timed(*args, **kw):
12
+ ts = time.time()
13
+ result = f(*args, **kw)
14
+ te = time.time()
15
+ print(f"{message} time: {te - ts}")
16
+ return result
17
+
18
+ return timed
19
+
20
+
21
+ load = timeit(onnx.load, "Load")
22
+
23
+ save = timeit(onnx.save, "Save")
24
+
25
+ infer = timeit(onnx.shape_inference.infer_shapes, "Infer")
26
+
27
+ fold_constants = timeit(optimizer.fold_constants, "Fold Constants")
28
+
29
+ remove_unused = timeit(optimizer.remove_unused_nodes, "Remove Unused")
30
+
31
+ optimize = timeit(optimizer.optimize, "Optimize")
32
+
33
+ # rewrite = timeit(all_rules.apply_to_model, "Rewrite")
pythonProject/.venv/Lib/site-packages/onnxscript/utils/utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import onnx
8
+
9
+
10
+ def normalize_domain(d: str) -> str:
11
+ return "" if d == "ai.onnx" else d
12
+
13
+
14
+ def is_onnx_domain(d: str) -> bool:
15
+ return normalize_domain(d) == ""
16
+
17
+
18
+ def is_onnx_op(node: onnx.NodeProto, op_type: str) -> bool:
19
+ return is_onnx_domain(node.domain) and node.op_type == op_type
20
+
21
+
22
+ def is_control_flow_op(node: onnx.NodeProto) -> bool:
23
+ return any(attr.HasField("g") or len(attr.graphs) > 0 for attr in node.attribute)
24
+
25
+
26
+ def get_node_attr_value(node: onnx.NodeProto, attr_name: str, default: Any) -> Any:
27
+ matching = [x for x in node.attribute if x.name == attr_name]
28
+ if len(matching) > 1:
29
+ raise ValueError(f"Node has multiple attributes with name {attr_name}")
30
+ if len(matching) < 1:
31
+ return default
32
+ return onnx.helper.get_attribute_value(matching[0])
33
+
34
+
35
+ def get_initializer_type(initializer: onnx.TensorProto) -> onnx.TypeProto:
36
+ type = onnx.TypeProto()
37
+ type.tensor_type.elem_type = initializer.data_type
38
+ dims = type.tensor_type.shape.dim
39
+ for dim in initializer.dims:
40
+ dims.add().dim_value = dim
41
+ return type
42
+
43
+
44
+ def get_constant_node_value(node: onnx.NodeProto, name: str) -> onnx.TensorProto | None:
45
+ if (
46
+ node.op_type != "Constant"
47
+ or node.domain not in {"", "ai.onnx"}
48
+ or len(node.attribute) != 1
49
+ ):
50
+ return None
51
+ attr = node.attribute[0]
52
+ if attr.ref_attr_name:
53
+ return None
54
+ attr_name = attr.name
55
+ value = onnx.helper.get_attribute_value(attr)
56
+
57
+ if isinstance(value, onnx.TensorProto):
58
+ # Two names exist in this case: we use tensorproto as is (with original name)
59
+ return value
60
+ shape: list[int]
61
+ if attr_name == "value_int":
62
+ dtype = onnx.TensorProto.INT64
63
+ shape = []
64
+ value = [value]
65
+ elif attr_name == "value_float":
66
+ dtype = onnx.TensorProto.FLOAT
67
+ shape = []
68
+ value = [value]
69
+ elif attr_name == "value_string":
70
+ dtype = onnx.TensorProto.STRING
71
+ shape = []
72
+ value = [value]
73
+ elif attr_name == "value_ints":
74
+ dtype = onnx.TensorProto.INT64
75
+ shape = [len(value)]
76
+ elif attr_name == "value_floats":
77
+ dtype = onnx.TensorProto.FLOAT
78
+ shape = [len(value)]
79
+ elif attr_name == "value_strings":
80
+ dtype = onnx.TensorProto.STRING
81
+ shape = [len(value)]
82
+ else:
83
+ return None # sparse tensors not handled
84
+ return onnx.helper.make_tensor(name, dtype, shape, value)
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__init__.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ from __future__ import annotations
4
+
5
+ __all__ = [
6
+ "ConvertVersionPass",
7
+ "convert_version",
8
+ ]
9
+
10
+ import logging
11
+
12
+ import onnx
13
+ import onnx_ir.passes.common as common_passes
14
+
15
+ from onnxscript import ir
16
+ from onnxscript.version_converter import _c_api_utils, _version_converter
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class ConvertVersionPass(ir.passes.InPlacePass):
22
+ """Convert the model to the specified ONNX opset version.
23
+
24
+ This pass leverages the onnxscript version converter to convert the model. If
25
+ the conversion is not supported, it falls back to the onnx C API to convert
26
+ the model. This pass is in-place.
27
+
28
+ The pass is an no-op if the c-api fails.
29
+
30
+ Attributes:
31
+ target_version: The target ONNX opset version to convert the model to.
32
+ fallback: Whether to fallback to the onnx version converter if the
33
+ target version is not supported. Default is False.
34
+ """
35
+
36
+ def __init__(self, target_version: int, fallback: bool = False) -> None:
37
+ super().__init__()
38
+ self.target_version = target_version
39
+ self.fallback = fallback
40
+ self.convert_pass = ir.passes.Sequential(
41
+ common_passes.InlinePass(),
42
+ _ConvertVersionPassRequiresInline(
43
+ target_version=target_version,
44
+ fallback=fallback,
45
+ ),
46
+ common_passes.RemoveUnusedNodesPass(),
47
+ common_passes.RemoveUnusedFunctionsPass(),
48
+ common_passes.RemoveUnusedOpsetsPass(),
49
+ )
50
+
51
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
52
+ return self.convert_pass(model)
53
+
54
+
55
+ class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass):
56
+ """Convert the model to the specified ONNX opset version.
57
+
58
+ This pass leverages the onnxscript version converter to convert the model. If
59
+ the conversion is not supported, it falls back to the onnx C API to convert
60
+ the model. This pass is in-place.
61
+
62
+ The pass is an no-op if the c-api fails.
63
+
64
+ Attributes:
65
+ target_version: The target ONNX opset version to convert the model to.
66
+ fallback: Whether to fallback to the onnx version converter if the
67
+ target version is not supported.
68
+ """
69
+
70
+ def __init__(self, target_version: int, fallback: bool) -> None:
71
+ super().__init__()
72
+ self.target_version = target_version
73
+ self.fallback = fallback
74
+
75
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
76
+ if model.functions:
77
+ raise ValueError(
78
+ "The model contains functions. The version conversion pass does not support "
79
+ "functions. Please use `common_passes.InlinePass` to inline the "
80
+ f"functions before applying this pass ({self.__class__.__name__})."
81
+ )
82
+ if "" in model.graph.opset_imports:
83
+ onnx_opset_version = model.graph.opset_imports[""]
84
+ if onnx_opset_version == self.target_version:
85
+ # No need to convert the version
86
+ return ir.passes.PassResult(model, False)
87
+
88
+ # When fallback is disabled, always use the onnxscript version converter;
89
+ # When fallback is enabled, use the onnxscript version converter
90
+ # if the target version is supported. Otherwise, use the onnx C API
91
+ # to convert the model.
92
+ if not self.fallback or _version_converter.version_supported(
93
+ model, self.target_version
94
+ ):
95
+ _version_converter.convert_version(
96
+ model,
97
+ target_version=self.target_version,
98
+ )
99
+ return ir.passes.PassResult(model, True)
100
+
101
+ if not self.fallback:
102
+ logger.warning(
103
+ "The model version conversion is not supported by the onnxscript version converter "
104
+ "and fallback is disabled. The model was not modified"
105
+ " (target version: %d). "
106
+ "Set fallback=True to enable fallback to the onnx c-api version converter.",
107
+ self.target_version,
108
+ )
109
+ return ir.passes.PassResult(model, False)
110
+ else:
111
+ logger.warning(
112
+ "The model version conversion is not supported by the onnxscript version converter "
113
+ "and fallback is enabled. The model will be converted using the onnx C API "
114
+ "(target version: %d).",
115
+ self.target_version,
116
+ )
117
+
118
+ # If the onnxscript version converter does not support the conversion,
119
+ # we can use the onnx C API to convert the model
120
+ def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto:
121
+ """Partial function to check the model."""
122
+ return onnx.version_converter.convert_version(
123
+ proto, target_version=self.target_version
124
+ )
125
+
126
+ try:
127
+ converted_proto = _c_api_utils.call_onnx_api(
128
+ func=_partial_convert_version, model=model
129
+ )
130
+ except Exception as e: # pylint: disable=broad-exception-caught
131
+ logger.warning(
132
+ "Failed to convert the model to the target version %d using the ONNX C API. "
133
+ "The model was not modified",
134
+ self.target_version,
135
+ exc_info=e,
136
+ )
137
+ return ir.passes.PassResult(model, False)
138
+
139
+ converted_model = ir.from_proto(converted_proto)
140
+
141
+ # Recover the initializers in the converted model
142
+ for input in converted_model.graph.inputs:
143
+ if input.name in model.graph.initializers:
144
+ input.const_value = model.graph.initializers[input.name].const_value
145
+ converted_model.graph.register_initializer(input)
146
+ user_inputs = converted_model.graph.inputs[: len(model.graph.inputs)]
147
+ converted_model.graph.inputs.clear()
148
+ converted_model.graph.inputs.extend(user_inputs)
149
+
150
+ # Return the converted graph to the original model to keep the pass in-place
151
+ model.graph = converted_model.graph
152
+ return ir.passes.PassResult(model, True)
153
+
154
+
155
+ def convert_version(
156
+ model: ir.Model | onnx.ModelProto, target_version: int, fallback=None
157
+ ) -> None:
158
+ """Convert the model to the specified ONNX opset version.
159
+
160
+ Args:
161
+ model: The model to convert.
162
+ target_version: The target ONNX opset version.
163
+ fallback: Whether to fallback to the onnx version converter if the
164
+ target version is not supported. Default is False.
165
+ """
166
+ if isinstance(model, onnx.ModelProto):
167
+ model_proto = model
168
+ model = ir.from_proto(model)
169
+ else:
170
+ model_proto = None
171
+
172
+ assert isinstance(model, ir.Model)
173
+ ConvertVersionPass(target_version=target_version, fallback=fallback)(model)
174
+
175
+ if model_proto is not None:
176
+ # Update the model proto in-place
177
+ model_proto.graph.Clear()
178
+ del model_proto.functions[:]
179
+ model_proto.graph.CopyFrom(ir.to_proto(model.graph))
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (5.72 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/_c_api_utils.cpython-310.pyc ADDED
Binary file (2.11 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/_version_converter.cpython-310.pyc ADDED
Binary file (9.96 kB). View file
 
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/_c_api_utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Utilities for interfacing with onnx C APIs."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ from typing import TYPE_CHECKING, Callable, TypeVar
9
+
10
+ from onnxscript import ir
11
+
12
+ if TYPE_CHECKING:
13
+ import onnx
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+ # Temporarily remove initializers larger than this size to keep model size down
18
+ # for the onnx.shape_inference call because it needs to serialize the model
19
+ _BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB
20
+ _R = TypeVar("_R")
21
+
22
+
23
+ def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R:
24
+ """Call an ONNX C API function by temporarily removing initializers.
25
+
26
+ This is necessary because the ONNX C API does not support large models
27
+ with initializers that have large tensor values. The input model is left
28
+ unchanged no matter the call succeeds or not.
29
+
30
+ Args:
31
+ func: Partially applied function that takes a model proto and returns anything.
32
+ model: The IR model to pass to the API function.
33
+
34
+ Returns:
35
+ The resulting ModelProto that contains the result of the API call.
36
+ """
37
+
38
+ # Store the original initializer values so they can be restored
39
+ initializer_values = tuple(model.graph.initializers.values())
40
+ tensors = {v.name: v.const_value for v in initializer_values}
41
+ original_inputs_len = len(model.graph.inputs)
42
+
43
+ # Turn the initializers into inputs and clear the initializers
44
+ # to limit the model size
45
+ for initializer in initializer_values:
46
+ # Make sure the initializer has its shape/type set
47
+ assert initializer.const_value is not None
48
+ if initializer.shape is None:
49
+ initializer.shape = initializer.const_value.shape # type: ignore[assignment]
50
+ if initializer.dtype is None:
51
+ initializer.dtype = initializer.const_value.dtype
52
+ if initializer not in model.graph.inputs:
53
+ model.graph.inputs.append(initializer)
54
+ if initializer.const_value.size > _BIG_TENSOR_SIZE_LIMIT:
55
+ # Temporarily remove the initializer value to reduce model size
56
+ # for onnx.shape_inference
57
+ initializer.const_value = None
58
+ assert initializer.name is not None
59
+ model.graph.initializers.pop(initializer.name)
60
+
61
+ proto = ir.serde.serialize_model(model)
62
+
63
+ try:
64
+ # Call the ONNX C API function
65
+ result = func(proto)
66
+ finally:
67
+ # Restore the original initializer values so the model is unchanged
68
+ for initializer in initializer_values:
69
+ initializer.const_value = tensors[initializer.name]
70
+ model.graph.register_initializer(initializer)
71
+
72
+ # Restore the original inputs
73
+ inputs = model.graph.inputs[:original_inputs_len]
74
+ model.graph.inputs.clear()
75
+ model.graph.inputs.extend(inputs)
76
+
77
+ return result
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/_version_converter.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Convert the model to the specified ONNX opset version."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import dataclasses
8
+ import functools
9
+ import logging
10
+ from typing import Callable, Sequence, Union
11
+
12
+ import onnx_ir.convenience as ir_convenience
13
+
14
+ import onnxscript.ir._tape as _tape
15
+ from onnxscript import ir
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ SUPPORTED_MAX_ONNX_OPSET = 23
21
+ SUPPORTED_MIN_ONNX_OPSET = 18
22
+
23
+
24
+ def _get_onnx_opset_version(model: ir.Model) -> int | None:
25
+ """Get the ONNX opset version imported by the model."""
26
+ model_version1 = model.opset_imports.get("")
27
+ model_version2 = model.opset_imports.get("ai.onnx")
28
+ if model_version1 is not None and model_version2 is not None:
29
+ if model_version1 != model_version2:
30
+ raise ValueError(
31
+ f"Model imports multiple onnx opsets: {model_version1} and {model_version2}."
32
+ )
33
+ return model_version1 or model_version2
34
+
35
+
36
+ def _set_onnx_opset_version(model: ir.Model, version: int) -> None:
37
+ """Set the ONNX opset version imported by the model."""
38
+ if "ai.onnx" in model.opset_imports:
39
+ del model.opset_imports["ai.onnx"]
40
+ model.opset_imports[""] = version
41
+
42
+
43
+ class VersionConverterError(RuntimeError):
44
+ """Raised when an node's version cannot be upgraded/downgraded successfully."""
45
+
46
+
47
+ @dataclasses.dataclass
48
+ class Replacement:
49
+ """A replacement for a node in the graph."""
50
+
51
+ new_outputs: Sequence[ir.Value]
52
+ new_nodes: Sequence[ir.Node]
53
+
54
+
55
+ # A version-adapter function takes a node, a RewriterContext and returns
56
+ # a Replacement for the node or None (if no replacement is needed).
57
+
58
+ RewriterContext = _tape.Builder
59
+ ReturnValue = Union[Sequence[ir.Value], ir.Value, None]
60
+ AdapterFunction = Callable[[ir.Node, RewriterContext], ReturnValue]
61
+
62
+
63
+ def version_supported(model: ir.Model, target_version: int) -> bool:
64
+ """Check if the target version is supported by the current version."""
65
+ if "" in model.graph.opset_imports:
66
+ current_version = model.graph.opset_imports[""]
67
+ else:
68
+ return True
69
+ return (
70
+ SUPPORTED_MIN_ONNX_OPSET
71
+ <= current_version
72
+ <= target_version
73
+ <= SUPPORTED_MAX_ONNX_OPSET
74
+ )
75
+
76
+
77
+ class AdapterRegistry:
78
+ """A class that maintains a registry of adapters for ops."""
79
+
80
+ def __init__(self):
81
+ self.op_adapters: dict[tuple[str, str, int, bool], AdapterFunction] = {}
82
+
83
+ def lookup_adapters(
84
+ self,
85
+ domain: str,
86
+ opname: str,
87
+ original_version: int,
88
+ up_conversion: bool = True,
89
+ ) -> AdapterFunction | None:
90
+ adapter_func = self.op_adapters.get((domain, opname, original_version, up_conversion))
91
+ if adapter_func is not None:
92
+ return adapter_func
93
+ return None
94
+
95
+ def register(
96
+ self, opname: str, domain: str = "", node_version=None, up_conversion=True
97
+ ) -> Callable[[AdapterFunction], AdapterFunction]:
98
+ """Register an adapter based on the domain, operator type, node version and whether to upgrade/downgrade node version"""
99
+
100
+ def decorator(function: AdapterFunction) -> AdapterFunction:
101
+ @functools.wraps(function)
102
+ def wrapped_function(*args, **kwargs):
103
+ return function(*args, **kwargs)
104
+
105
+ self.op_adapters[(domain, opname, node_version, up_conversion)] = function
106
+ return wrapped_function
107
+
108
+ return decorator
109
+
110
+
111
+ registry: AdapterRegistry = AdapterRegistry()
112
+
113
+ register = registry.register
114
+
115
+
116
+ def _get_input(node: ir.Node, index: int) -> ir.Value | None:
117
+ if index < len(node.inputs):
118
+ return node.inputs[index]
119
+ return None
120
+
121
+
122
+ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None:
123
+ if name in node.attributes:
124
+ attr = node.attributes[name]
125
+ if not isinstance(attr, ir.Attr):
126
+ return None
127
+ attr_val = attr.value
128
+ if isinstance(attr_val, int):
129
+ return attr_val
130
+ # This is an invalid model: attribute has invalid/unexpected type.
131
+ # For now, we just return None. We could raise an error too.
132
+ return None
133
+ return default
134
+
135
+
136
+ def _get_str_attribute(node: ir.Node, name: str, default: str | None = None) -> str | None:
137
+ if name in node.attributes:
138
+ attr = node.attributes[name]
139
+ if not isinstance(attr, ir.Attr):
140
+ return None
141
+ attr_val = attr.value
142
+ if isinstance(attr_val, str):
143
+ return attr_val
144
+ # This is an invalid model: attribute has invalid/unexpected type.
145
+ # For now, we just return None. We could raise an error too.
146
+ return None
147
+ return default
148
+
149
+
150
+ ## Op-specific adapters
151
+
152
+ # Opset 19 -> 20
153
+
154
+
155
+ @register("DFT", node_version=19, up_conversion=True)
156
+ def dft_19_20(node: ir.Node, op):
157
+ input = node.inputs[0]
158
+ inverse = _get_int_attribute(node, "inverse", 0)
159
+ onesided = _get_int_attribute(node, "onesided", 0)
160
+ axis = _get_int_attribute(node, "axis", None)
161
+ if axis is not None:
162
+ axis_value = op.Constant(value_int=axis)
163
+ return op.DFT(input, axis_value, inverse=inverse, onesided=onesided)
164
+ return None
165
+
166
+
167
+ @register("GridSample", node_version=19, up_conversion=True)
168
+ def gridsample_19_20(node: ir.Node, op):
169
+ x = node.inputs[0]
170
+ grid = node.inputs[1]
171
+ align_corners = _get_int_attribute(node, "align_corners", 0)
172
+ mode = _get_str_attribute(node, "mode", "linear")
173
+ padding_mode = _get_str_attribute(node, "padding_mode", "zeros")
174
+ if mode == "bilinear":
175
+ return op.GridSample(
176
+ x, grid, align_corners=align_corners, mode="linear", padding_mode=padding_mode
177
+ )
178
+ elif mode == "bicubic":
179
+ return op.GridSample(
180
+ x, grid, align_corners=align_corners, mode="cubic", padding_mode=padding_mode
181
+ )
182
+ return None
183
+
184
+
185
+ # Opset 20 -> 21
186
+
187
+
188
+ @register("GroupNormalization", node_version=20, up_conversion=True)
189
+ def groupnormalization_20_21(node: ir.Node, op):
190
+ x = _get_input(node, 0)
191
+ scale = _get_input(node, 1)
192
+ bias = _get_input(node, 2)
193
+ if x is None or scale is None or bias is None:
194
+ raise VersionConverterError(f"Missing input for {node}")
195
+
196
+ x_shape = x.shape
197
+ if x_shape is None:
198
+ raise VersionConverterError(f"Missing required shape for {x}")
199
+ num_channels = x_shape[1]
200
+ if not isinstance(num_channels, int):
201
+ return None
202
+
203
+ scale_shape = scale.shape
204
+ bias_shape = bias.shape
205
+ if scale_shape is None or bias_shape is None:
206
+ return None
207
+ if not isinstance(scale_shape[0], int) or not isinstance(bias_shape[0], int):
208
+ return None
209
+
210
+ num_groups = _get_int_attribute(node, "num_groups", None)
211
+ if num_groups is None:
212
+ raise VersionConverterError("Missing required attribute: num_groups")
213
+ if (
214
+ num_groups != num_channels
215
+ and num_groups == scale_shape[0]
216
+ and num_groups == bias_shape[0]
217
+ ):
218
+ reshape_1_sizes = op.Constant(value_ints=[-1, 1])
219
+ reshape_2_sizes = op.Constant(value_ints=[-1])
220
+ c_div = int(num_channels / num_groups)
221
+ expand_sizes = op.Constant(value_ints=[1, c_div])
222
+
223
+ # Modify scale input
224
+ scale_reshape_1 = op.Reshape(scale, reshape_1_sizes)
225
+ scale_expand = op.Expand(scale_reshape_1, expand_sizes)
226
+ scale_reshape_2 = op.Reshape(scale_expand, reshape_2_sizes)
227
+
228
+ # Modify bias input
229
+ bias_reshape_1 = op.Reshape(bias, reshape_1_sizes)
230
+ bias_expand = op.Expand(bias_reshape_1, expand_sizes)
231
+ bias_reshape_2 = op.Reshape(bias_expand, reshape_2_sizes)
232
+
233
+ return op.GroupNormalization(x, scale_reshape_2, bias_reshape_2, num_groups=num_groups)
234
+ return None
235
+
236
+
237
+ class _VersionConverter:
238
+ def __init__(self, target_version: int):
239
+ self._target_version = target_version
240
+
241
+ def process_node(
242
+ self, node: ir.Node, from_version: int, up_conversion: bool = True
243
+ ) -> Replacement | None:
244
+ assert node.domain == ""
245
+ adapter = registry.lookup_adapters(
246
+ node.domain, node.op_type, from_version, up_conversion
247
+ )
248
+ if adapter is None:
249
+ return None
250
+ context = RewriterContext()
251
+ output = adapter(node, context)
252
+ if output is not None:
253
+ if isinstance(output, ir.Value):
254
+ output = [output]
255
+ return Replacement(output, context.nodes)
256
+ return None
257
+
258
+ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None:
259
+ logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)
260
+
261
+ ir_convenience.replace_nodes_and_values(
262
+ root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
263
+ )
264
+
265
+ def visit_attribute(self, attr: ir.Attr) -> None:
266
+ if attr.is_ref():
267
+ return
268
+ if attr.type == ir.AttributeType.GRAPH:
269
+ self.visit_graph(attr.as_graph())
270
+ elif attr.type == ir.AttributeType.GRAPHS:
271
+ for graph in attr.as_graphs():
272
+ self.visit_graph(graph)
273
+
274
+ def visit_node(
275
+ self,
276
+ node: ir.Node,
277
+ root: ir.Graph | ir.Function,
278
+ from_version: int,
279
+ up_conversion: bool = True,
280
+ ) -> None:
281
+ if up_conversion:
282
+ to_version = from_version + 1
283
+ else:
284
+ to_version = from_version - 1
285
+ replacement = self.process_node(node, from_version, up_conversion)
286
+ if replacement is None:
287
+ # No change. Process attributes.
288
+ for attr in node.attributes.values():
289
+ self.visit_attribute(attr)
290
+ node.version = to_version
291
+ else:
292
+ for new_node in replacement.new_nodes:
293
+ # TODO: control-flow
294
+ new_node.version = to_version
295
+ self.replace_node(node, replacement, root)
296
+
297
+ def visit_graph(self, graph: ir.Graph) -> None:
298
+ for node in graph:
299
+ if node.domain != "":
300
+ continue
301
+ node_version = node.version or self._default_onnx_opset
302
+ if node_version is None:
303
+ raise VersionConverterError(f"Node {node} has no version.")
304
+ # Iterate each node from current node version -> target version
305
+ # and updating node based on the correct adapter
306
+ # Up-conversion [ver->ver+1] or down-conversion [ver->ver-1]
307
+ # TODO(shubhambhokare1): Remove once down-conversion adapters are supoorted
308
+ if self._target_version < node_version:
309
+ raise VersionConverterError(
310
+ f"Target opset: {self._target_version} less than node version: {node.version}, "
311
+ "downstream version conversion not currently handled."
312
+ )
313
+ for from_version in range(node_version, self._target_version):
314
+ try:
315
+ self.visit_node(node, graph, from_version, up_conversion=True)
316
+ except VersionConverterError as e:
317
+ logger.warning(
318
+ "Skipping version conversion for node %s due to exception: %s",
319
+ node.op_type,
320
+ e,
321
+ )
322
+
323
+ def visit_model(self, model: ir.Model) -> None:
324
+ self._default_onnx_opset = _get_onnx_opset_version(model)
325
+ self.visit_graph(model.graph)
326
+ _set_onnx_opset_version(model, self._target_version)
327
+
328
+
329
+ def convert_version(model: ir.Model, target_version: int) -> None:
330
+ """Convert the model to the specified ONNX opset version."""
331
+ if (target_version > SUPPORTED_MAX_ONNX_OPSET) or (
332
+ target_version < SUPPORTED_MIN_ONNX_OPSET
333
+ ):
334
+ raise ValueError(
335
+ f"Target opset version {target_version} is not supported. "
336
+ f"Supported range: {SUPPORTED_MIN_ONNX_OPSET} to {SUPPORTED_MAX_ONNX_OPSET}."
337
+ )
338
+ version_converter = _VersionConverter(target_version=target_version)
339
+ version_converter.visit_model(model)