Add files using upload-large-folder tool
Browse files- pythonProject/.venv/Lib/site-packages/onnxscript/_framework_apis/torch_2_5.py +117 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/onnx_opset/_impl/__pycache__/opset10.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/onnx_opset/_impl/__pycache__/opset12.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/basic_rules.py +321 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/broadcast_to_matmul.py +178 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/cast_constant_of_shape.py +46 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/collapse_slices.py +107 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/utils/__init__.py +0 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/evaluation_utils.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/timing_utils.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/utils.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/utils/timing_utils.py +33 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/utils/utils.py +84 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__init__.py +179 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/_c_api_utils.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/_version_converter.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/_c_api_utils.py +77 -0
- pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/_version_converter.py +339 -0
pythonProject/.venv/Lib/site-packages/onnxscript/_framework_apis/torch_2_5.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
"""Stable APIs for PyTorch 2.5."""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"check_model",
|
| 9 |
+
"convert_version",
|
| 10 |
+
"get_torchlib_ops",
|
| 11 |
+
"optimize",
|
| 12 |
+
"save_model_with_external_data",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
import dataclasses
|
| 16 |
+
import os
|
| 17 |
+
import pathlib
|
| 18 |
+
from typing import Callable
|
| 19 |
+
|
| 20 |
+
from onnxscript import ir, optimizer, version_converter
|
| 21 |
+
from onnxscript.function_libs.torch_lib import registration
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclasses.dataclass(frozen=True)
class _OnnxFunctionMeta:
    """Immutable metadata record for one onnx-script function.

    qualified_name: The qualified name of the aten operator.
    function: The onnx-script function.
    domain: The domain of the function.
    name: The name of the function.
    is_complex: Whether the function is a complex function.
    """

    qualified_name: str
    function: Callable
    domain: str
    name: str
    is_complex: bool = False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def optimize(model: ir.Model) -> ir.Model:
    """Optimize the model.

    The optimization pass only runs when the internal environment flag
    ``TORCH_ONNX_ENABLE_OPTIMIZATION`` is set to ``"1"``; otherwise the model
    is returned untouched. (Internal flag. Will go away.)
    """
    if os.getenv("TORCH_ONNX_ENABLE_OPTIMIZATION") == "1":
        optimizer.optimize_ir(model)
    return model
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def convert_version(model: ir.Model, target_version: int) -> ir.Model:
    """Convert the model to the specified ONNX opset version.

    Conversion only runs when the internal environment flag
    ``TORCH_ONNX_ENABLE_VERSION_CONVERSION`` is set to ``"1"``; otherwise the
    model is returned unchanged. (Internal flag. Will go away.)
    """
    if os.getenv("TORCH_ONNX_ENABLE_VERSION_CONVERSION") == "1":
        version_converter.convert_version(model, target_version)
    return model
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def check_model(model: ir.Model) -> None:
    """Check the model.

    Currently a placeholder: no validation is performed yet.
    """
    del model  # Unused yet
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None:
    """Save the model with external data. The model is unchanged after saving.

    Raises:
        ValueError: If any graph initializer lacks a concrete ``const_value``.
    """
    # TODO(#1835): Decide if we want to externalize large attributes as well
    has_uninitialized = any(
        value.const_value is None for value in model.graph.initializers.values()
    )
    if has_uninitialized:
        raise ValueError(
            "The model contains uninitialized initializer values. "
            "Please make sure all initializer values are initialized."
        )
    destination_path = pathlib.Path(model_path)
    # Tensor payloads go to a sibling file named after the model file.
    data_path = f"{destination_path.name}.data"
    ir.save(model, model_path, external_data=data_path)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
    """Collect metadata for every registered torchlib op, regular and complex."""
    # Trigger op registration
    from onnxscript.function_libs.torch_lib import (  # pylint: disable=import-outside-toplevel
        ops,
    )

    del ops  # Unused

    registry = registration.default_registry
    function_metas: list[_OnnxFunctionMeta] = []

    for qualified_name, overloads_entry in registry.items():
        # Skip the custom defined internal functions
        if qualified_name.startswith("internal::"):
            continue

        for func in overloads_entry.overloads:
            function_metas.append(
                _OnnxFunctionMeta(
                    qualified_name=qualified_name,
                    function=func,
                    domain=func.function_ir.domain,
                    name=func.name,
                    is_complex=False,
                )
            )
        for func in overloads_entry.complex:
            function_metas.append(
                _OnnxFunctionMeta(
                    qualified_name=qualified_name,
                    function=func,
                    domain=func.function_ir.domain,
                    name=func.name,
                    is_complex=True,
                )
            )

    return function_metas
|
pythonProject/.venv/Lib/site-packages/onnxscript/onnx_opset/_impl/__pycache__/opset10.cpython-310.pyc
ADDED
|
Binary file (49.3 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/onnxscript/onnx_opset/_impl/__pycache__/opset12.cpython-310.pyc
ADDED
|
Binary file (40 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/basic_rules.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
"""Basic rewrite rules for general optimization patterns.
|
| 4 |
+
|
| 5 |
+
This module contains fundamental optimization rules that are generally applicable
|
| 6 |
+
to most ONNX models, including cast elimination, transpose simplification,
|
| 7 |
+
shape operation fusion, and other common patterns.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from typing import ClassVar, Sequence
|
| 13 |
+
|
| 14 |
+
from onnxscript import ir
|
| 15 |
+
from onnxscript.rewriter import _ir_utils as ir_utils
|
| 16 |
+
from onnxscript.rewriter._basics import MatchResult
|
| 17 |
+
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SqueezeReshape(RewriteRuleClassBase):
    """Rewrites ``Reshape(Squeeze(x), [-1]])`` into ``Identity(x)`` for 1D x.

    This pattern arises from the translation of pytorch symints.
    """

    def __init__(self):
        super().__init__("SqueezeReshape1d", remove_nodes=False)

    def pattern(self, op, x):
        return op.Reshape(op.Squeeze(x), [-1])

    def rewrite(self, op, x: ir.Value):
        return op.Identity(x)

    def check(self, context, x) -> MatchResult:
        del context  # Unused
        result = MatchResult()
        # The rewrite is a no-op only when x is already a 1-D value.
        if ir_utils.has_rank(x, 1):
            return result
        return result.fail("Input is not 1D")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class CastIdentity(RewriteRuleClassBase):
    """Rewrites a no-op ``Cast(., to=to)`` into ``Identity`` when possible."""

    def pattern(self, op, x, to):
        return op.Cast(x, to=to)

    def rewrite(self, op, x: ir.Value, to: ir.Attr):
        return op.Identity(x)

    def check(self, context, x, to) -> MatchResult:
        result = MatchResult()
        # Only a cast to the value's own dtype is removable.
        if x.dtype == to.as_int():
            return result
        return result.fail("Input and output types are not the same")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class CastCast(RewriteRuleClassBase):
    """Rewrites ``Cast(Cast(X, ...), to=to)`` into a single ``Cast(X, to=to)``."""

    # Simplify "cast type1 => type2 => type3" to "cast type1 => type3".
    # The fusion is NOT valid for every pair: e.g. float32 => float16 => float32
    # loses precision, and float32 => int32 => string is not equivalent.
    # TODO: fill out the list of allowed combinations: the following is just a couple
    # that shows up in practice where it is valid
    _allowed_type2_type3: ClassVar = frozenset(
        {
            (ir.DataType.FLOAT, ir.DataType.FLOAT16),
            (ir.DataType.FLOAT, ir.DataType.BFLOAT16),
        }
    )

    def pattern(self, op, x, to, to_ignored):
        return op.Cast(op.Cast(x, to=to_ignored), to=to)

    def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult:
        result = MatchResult()
        type2 = to_ignored.as_int()  # intermediate dtype
        type3 = to.as_int()  # final dtype
        if (type2, type3) in self._allowed_type2_type3:
            return result
        return result.fail(
            f"Intermediate cast elimination not recognized as valid from {type2} to {type3}. "
            f"Cast-Cast rule may be incomplete for this combination."
        )

    def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
        return op.Cast(x, to=to)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class ExpandIdentity(RewriteRuleClassBase):
    """Rewrites ``Expand(..., shape)`` into ``Identity`` when the expand is a no-op."""

    def pattern(self, op, x, shape):
        return op.Expand(x, shape)

    def rewrite(self, op, x: ir.Value, shape: ir.Value):
        return op.Identity(x)

    def check(self, context, x, shape) -> MatchResult:
        result = MatchResult()
        if shape.const_value is None:
            # A dynamic target shape cannot be proven equal to the input shape.
            return result.fail("Shape is not a constant and cannot be guessed.")
        x_shape = x.shape
        if x_shape is None:
            # Without static input shape information the rewrite cannot be validated.
            return result.fail("Input shape is not known.")
        if x_shape.dims != tuple(shape.const_value.numpy().tolist()):
            return result.fail(
                f"Input shape {x_shape.dims} does not match the shape {shape.const_value.numpy().tolist()}."
            )
        return result
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class ReshapeReshape(RewriteRuleClassBase):
    """Collapses ``Reshape(Reshape(X, ...), shape)`` into ``Reshape(X, shape)``.

    The pattern matches only if second reshape reshapes into a shape
    with positive values.
    """

    def pattern(self, op, x, shape_ignored, shape):
        return op.Reshape(op.Reshape(x, shape_ignored), shape)

    def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
        return op.Reshape(x, shape)

    def check(self, context, x, shape_ignored, shape) -> MatchResult:
        result = MatchResult()
        if shape_ignored.const_value is None:
            return result.fail("Shape ignored is not a constant.")
        if shape.const_value is None:
            return result.fail("Shape is not a constant.")
        # 0 and -1 entries make the final shape depend on the intermediate
        # shape, so the inner Reshape cannot be dropped in that case.
        if shape.const_value.numpy().min() <= 0:
            return result.fail("Shape has non-positive values.")
        return result
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class SlicesSplit(RewriteRuleClassBase):
    """Replaces ``Slice(x, ...), Slice(x, ...)``
    by ``Split(x, ...)`` if possible.

    Matches two Slice nodes that consume the same input ``x`` and, when the
    two slices exactly halve the last axis, rewrites them into one
    ``Split(x, num_outputs=2, axis=-1)``.
    """

    def pattern(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
        return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1)

    def check(self, context, x, begin0, end0, axes0, begin1, end1, axes1) -> MatchResult:
        check_result = MatchResult()
        # Both slices must use the same, statically-known axes.
        if (
            axes0.const_value is None
            or axes1.const_value is None
            or axes0.const_value.numpy().tolist() != axes1.const_value.numpy().tolist()
        ):
            return check_result.fail("Axes are not equal or not constant.")
        axes = axes0.const_value.numpy().tolist()
        # Only single-axis slices are handled.
        if len(axes) != 1:
            return check_result.fail("Axes has more than one dimension.")
        if x.shape:
            rk = len(x.shape)
        else:
            # NOTE(review): falls back to x.rank when shape is falsy — presumably
            # covers values with known rank but unknown dims; confirm against ir.Value.
            rk = x.rank
        # The sliced axis must be the last one (Split below uses axis=-1).
        if axes[0] != -1 and axes[0] != rk - 1:
            return check_result.fail("Axes is not -1 or last dimension.")
        # All slice boundaries must be static to prove the halving.
        if (
            begin0.const_value is None
            or end0.const_value is None
            or begin1.const_value is None
            or end1.const_value is None
        ):
            return check_result.fail("Begin or end are not constant values.")
        # First slice must start at 0.
        if begin0.const_value.numpy().tolist() != [0]:
            return check_result.fail("First begin value is not 0.")
        e0, b1, e1 = (
            end0.const_value.numpy().tolist(),
            begin1.const_value.numpy().tolist(),
            end1.const_value.numpy().tolist(),
        )
        # The slices must be contiguous: first ends where the second begins.
        if e0[0] != b1[0]:
            return check_result.fail("End0 is not equal to Begin1.")
        shape = x.shape
        if shape is None:
            return check_result.fail("Shape is not known.")
        last_dim = shape[-1]
        # Symbolic last dimension cannot be checked against the slice bounds.
        if not isinstance(last_dim, int):
            return check_result.fail("Last dimension is not known.")
        # Second slice must run to the end of the axis...
        if last_dim != e1[0]:
            return check_result.fail("Last dimension is not equal to End1.")
        # ...and the split point must be exactly the midpoint (equal halves).
        if last_dim // 2 != b1[0]:
            return check_result.fail("Last dimension is not equal to Begin1.")
        return check_result

    def rewrite(self, op, x, begin0, end0, axes0, begin1, end1, axes1):
        # Two equal halves of the last axis == Split with two outputs.
        return op.Split(x, num_outputs=2, axis=-1, _outputs=2)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class TransposeIdentity(RewriteRuleClassBase):
    """Rewrites ``Transpose(. perm=perm)`` into ``Identity``
    when the permutation is identity.
    """

    def pattern(self, op, x, perm):
        return op.Transpose(x, perm=perm)

    def check(self, context, x: ir.Value, perm: ir.Attr) -> MatchResult:
        result = MatchResult()
        if perm.is_ref():
            return result.fail("Permutation is a reference attribute.")
        if perm.type == ir.AttributeType.INTS:
            values = perm.as_ints()
            # Identity permutation: each axis maps to itself.
            if all(axis == position for position, axis in enumerate(values)):
                return result
        return result.fail("Permutation is not identity.")

    def rewrite(self, op, x: ir.Value, perm: ir.Attr):
        return op.Identity(x)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class TransposeTranspose(RewriteRuleClassBase):
    """Fuses ``Transpose(Transpose(., perm=perm1), perm=perm2)`` into one op.

    When the two permutations compose to the identity, the pair collapses
    to ``Identity``; otherwise a single Transpose with the composed
    permutation is emitted.
    """

    def pattern(self, op, x, perm1, perm2):
        return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)

    def check(self, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> MatchResult:
        result = MatchResult()
        if perm1.is_ref() or perm2.is_ref():
            return result.fail("Permutation is a reference attribute.")
        return result

    def _apply_transpose(self, perm: Sequence[int], on: list[int]) -> list[int]:
        # Applying perm selects on[p] for each target position.
        assert len(perm) == len(on), "length mismatch"
        return [on[p] for p in perm]

    def _apply_transposes(
        self, perms: list[Sequence[int]], on: list[int] | None = None
    ) -> list[int]:
        composed = list(range(len(perms[0]))) if on is None else on
        for perm in perms:
            composed = self._apply_transpose(perm, composed)
        return composed

    def rewrite(self, op, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr):
        identity = list(range(len(perm1.as_ints())))
        composed = self._apply_transposes([perm1.as_ints(), perm2.as_ints()])
        if composed == identity:
            return op.Identity(x)
        return op.Transpose(x, perm=composed)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class UnsqueezeUnsqueeze(RewriteRuleClassBase):
    """Replaces ``Unsqueeze(Unsqueeze(., axes1), axes2)`` with one Unsqueeze."""

    def pattern(self, op, x, axes1, axes2):
        return op.Unsqueeze(op.Unsqueeze(x, axes1), axes2)

    def rewrite(self, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
        first = ir_utils.get_singleton_value(axes1)
        second = ir_utils.get_singleton_value(axes2)
        # The second axis is expressed in the once-unsqueezed frame: when it
        # does not come strictly after the first, the first axis shifts by one.
        combined = [first, second] if first < second else [second, first + 1]
        return op.Unsqueeze(x, op.Constant(value=ir.tensor(combined, dtype=ir.DataType.INT64)))

    def check(self, context, x, axes1, axes2) -> MatchResult:
        del context  # Unused
        del x  # Unused
        result = MatchResult()
        # Currently restricted to single element positive axis
        first = ir_utils.get_singleton_value(axes1)
        second = ir_utils.get_singleton_value(axes2)
        if first is None or second is None:
            return result.fail("Axes are not constant.")
        if first < 0 or second < 0:
            return result.fail("Axes are negative.")
        return result
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# Create rule instances
# Each RewriteRuleClassBase subclass is materialized once via its inherited
# `rule()` factory (defined on RewriteRuleClassBase — presumably returns a
# RewriteRule; confirm in _rewrite_rule). These module-level singletons are
# the building blocks consumed by basic_optimization_rules() below.
cast_cast_rule = CastCast.rule()
cast_identity_rule = CastIdentity.rule()
expand_identity_rule = ExpandIdentity.rule()
reshape_reshape_rule = ReshapeReshape.rule()
slice_split_rule = SlicesSplit.rule()
transpose_identity_rule = TransposeIdentity.rule()
transpose_transpose_rule = TransposeTranspose.rule()
unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()
squeeze_reshape_1d_rule = SqueezeReshape.rule()
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def basic_optimization_rules() -> RewriteRuleSet:
    """Return the default set of basic optimization rules.

    The set bundles fundamental, generally-applicable rewrites:
    - Eliminating redundant cast operations
    - Simplifying consecutive operations of the same type
    - Removing identity operations
    - Optimizing shape manipulation operations

    These rules are generally safe to apply as a first optimization pass
    before other more specialized optimizations.

    Returns:
        RewriteRuleSet: A collection of basic optimization rules
    """
    selected_rules = [
        cast_cast_rule,
        cast_identity_rule,
        expand_identity_rule,
        reshape_reshape_rule,
        slice_split_rule,
        transpose_identity_rule,
        transpose_transpose_rule,
        unsqueeze_unsqueeze_rule,
        squeeze_reshape_1d_rule,
    ]
    return RewriteRuleSet(selected_rules)
|
pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/broadcast_to_matmul.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
from onnxscript import ir
|
| 8 |
+
from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def check_if_not_need_reshape(
    context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
) -> bool:
    """Condition to check if we need to replace the pattern.

    If matmul broadcasting is enough, then we don't need the reshapes.

    To validate this, we need to check the following:
    1. Input shapes check: input_a and input_b should be broadcastable
    2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)

    If the above are true, then we don't need the reshapes.

    Returns:
        True if we need to replace the pattern, False otherwise.
    """
    del context  # Reserved for future extensions

    input_a_shape = input_a.shape
    input_b_shape = input_b.shape
    shape_c_tensor = shape_c.const_value
    if shape_c_tensor is None:
        logger.info("The value 'shape_c' is not statically known.")
        return False

    # shape_c must be a 1-D tensor of dimension sizes.
    if len(shape_c_tensor.shape) != 1:
        logger.info(
            "Unexpected final shape. The shape of 'shape' value is %s",
            shape_c_tensor.shape,
        )
        return False

    # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape
    # information. So, we need to check if the shape is None and return False.
    if input_a_shape is None or input_b_shape is None:
        logger.info("Shape information is not available for the inputs and outputs.")
        return False
    if any(isinstance(dim, ir.SymbolicDim) for dim in input_a_shape):
        logger.info("Symbolic dimensions are not yet supported.")
        return False
    if any(isinstance(dim, ir.SymbolicDim) for dim in input_b_shape):
        logger.info("Symbolic dimensions are not yet supported.")
        return False
    # From here on the names are rebound to plain python containers of ints.
    input_a_shape = input_a_shape.numpy()  # type: ignore[assignment]
    input_b_shape = input_b_shape.numpy()  # type: ignore[assignment]
    shape_c = shape_c_tensor.numpy().tolist()  # type: ignore[assignment]

    a_rank = len(input_a_shape)
    b_rank = len(input_b_shape)

    # 1. Check if input shapes are broadcastable
    # 1.a. If the first input is 1-D, check whether
    # the dim matches the last second dim of the second input.
    mimic_matmul_broadcast_behavior_a = False
    mimic_matmul_broadcast_behavior_b = False
    if a_rank < 2:
        if b_rank < 2:
            logger.info("Optimization of dot product is not supported yet.")
            return False
        if input_a_shape[-1] != input_b_shape[-2]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            # Promote 1-D input_a to 2-D by prepending 1, as MatMul does.
            input_a_shape = [1, *input_a_shape]  # type: ignore[assignment]
            a_rank = len(input_a_shape)
            mimic_matmul_broadcast_behavior_a = True
    # 1.b. If the second input is 1-D, check whether
    # the dim matches the last dim of the first input.
    if b_rank < 2:
        if input_b_shape[-1] != input_a_shape[-1]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            # Promote 1-D input_b to 2-D by appending 1, as MatMul does.
            input_b_shape = [*input_b_shape, 1]  # type: ignore[assignment]
            b_rank = len(input_b_shape)
            mimic_matmul_broadcast_behavior_b = True
    # 1.c. If both inputs are at least 2-D, check whether
    # the last dimension of the first input matches the second
    # last dimension of the second input, and shape[:-2] are
    # broadcastable.
    input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]]
    input_b_shape_except_last_dim = input_b_shape[:-1]
    broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]
    # Walk batch dims right-to-left; idx 0 is the contracted K dim pair.
    for idx, (dim_from_a, dim_from_b) in enumerate(
        zip(
            reversed(input_a_shape_except_second_last_dim),
            reversed(input_b_shape_except_last_dim),
        )
    ):
        if dim_from_a not in {1, dim_from_b}:
            logger.info("Original shape is not broadcastable.")
            return False
        elif idx > 0:
            # Accumulate the broadcast batch dim in front of [M, N].
            broadcast_matmul_output_shape = [
                max(dim_from_a, dim_from_b),  # type: ignore[type-var]
                *broadcast_matmul_output_shape,
            ]

    # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)
    # Prepend the broadcast_matmul_output_shape with the longer shape of input
    if a_rank > b_rank:
        longer_shape = input_a_shape
        shorter_shape = input_b_shape
    else:
        longer_shape = input_b_shape
        shorter_shape = input_a_shape
    broadcast_matmul_output_shape = [
        *longer_shape[: -len(shorter_shape)],
        *broadcast_matmul_output_shape,
    ]
    if mimic_matmul_broadcast_behavior_b and b_rank == 2 and input_b_shape[-1] == 1:
        # If input_b is expanded to 2-D, then we need to remove the last dimension
        broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
    if mimic_matmul_broadcast_behavior_a and a_rank == 2 and input_a_shape[0] == 1:
        # If input_a is expanded to 2-D, then we need to remove the first dimension
        # of input_a, which would be the -2nd dimension of the output shape.
        broadcast_matmul_output_shape.pop(-2)
    if shape_c != broadcast_matmul_output_shape:
        logger.info(
            "Final output shape is not the same. Expected %s vs actual %s",
            shape_c,
            broadcast_matmul_output_shape,
        )
        return False

    return True
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c):
    """Pattern: both MatMul operands are reshaped, and the product is reshaped again."""
    # TODO: Modified from `value_ints` to `value` to match pattern in benchmark models.
    # This implementation misses pattern of Constants with `value_ints` attribute.
    # See more at https://github.com/microsoft/onnx-rewriter/issues/191.
    # A better solution is to improve pattern matching and avoid depending on writing
    # Constants in pattern. See https://github.com/microsoft/onnx-rewriter/issues/192.
    lhs = op.Reshape(input_a, shape_a)
    rhs = op.Reshape(input_b, shape_b)
    product = op.MatMul(lhs, rhs)
    return op.Reshape(product, shape_c)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _matmul(op, input_a, input_b, **_):
    """Replacement: a single MatMul; extra matched values are ignored."""
    return op.MatMul(input_a, input_b)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _one_reshape_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_c):
|
| 158 |
+
reshape_a = op.Reshape(input_a, shape_a)
|
| 159 |
+
matmul = op.MatMul(reshape_a, input_b)
|
| 160 |
+
return op.Reshape(matmul, shape_c)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# Register the rewrite rules
# Rule: Reshape(A) @ Reshape(B) followed by Reshape  ==>  A @ B, guarded by
# check_if_not_need_reshape, which verifies the reshapes do not change the
# result MatMul would produce on the original operands.
two_reshapes_matmul_reshape_rule = RewriteRule(
    _two_reshapes_matmul_reshape_pattern,
    _matmul,
    check_if_not_need_reshape,
)
# Rule: Reshape(A) @ B followed by Reshape  ==>  A @ B, same guard.
one_reshape_matmul_reshape_rule = RewriteRule(
    _one_reshape_matmul_reshape_pattern,
    _matmul,
    # We can use the same check_if_not_need_reshape function for both the rules,
    # as one_reshape_matmul_reshape_pattern is a subset of _two_reshapes_matmul_reshape_pattern.
    check_if_not_need_reshape,
)

# NOTE: The order of the rules is important. Larger pattern should be checked first.
rules = RewriteRuleSet([two_reshapes_matmul_reshape_rule, one_reshape_matmul_reshape_rule])
|
pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/cast_constant_of_shape.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
from onnxscript import ir
|
| 8 |
+
from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def cast_constant_of_shape(op, shape, scalar, dtype):
    """Pattern: a ConstantOfShape whose output is immediately Cast to ``dtype``."""
    filled = op.ConstantOfShape(shape, value=scalar)
    return op.Cast(filled, to=dtype)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def fused_cast_constant_of_shape(op, shape: ir.Value, scalar: ir.Attr, dtype: ir.Attr, **_):
    """Rewrite target: fold the Cast into ConstantOfShape by pre-casting the fill value."""
    # Pull the python scalar out of the TensorProto attribute, then rebuild it
    # as a one-element tensor of the Cast's target dtype.
    fill = scalar.value.numpy().item()
    target_dtype = ir.DataType(dtype.as_int())
    return op.ConstantOfShape(shape, value=ir.tensor([fill], dtype=target_dtype))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def cast_constant_of_shape_without_value(op, shape, dtype):
    """Pattern: a value-less ConstantOfShape (implicit zero fill) followed by Cast."""
    filled = op.ConstantOfShape(shape)
    return op.Cast(filled, to=dtype)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def fused_cast_constant_of_shape_without_value(op, shape, dtype, **_):
    """Rewrite target: emit the implicit zero fill directly in the Cast's dtype."""
    zero_fill = ir.tensor([0], dtype=ir.DataType(dtype.as_int()))
    return op.ConstantOfShape(shape, value=zero_fill)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Rule: Cast(ConstantOfShape(shape, value=scalar)) -> ConstantOfShape with a pre-cast value.
cast_constant_of_shape_rule = RewriteRule(cast_constant_of_shape, fused_cast_constant_of_shape)

# Rule: Cast(ConstantOfShape(shape)) -> ConstantOfShape filled with 0 of the Cast's dtype.
cast_constant_of_shape_without_value_rule = RewriteRule(
    cast_constant_of_shape_without_value, fused_cast_constant_of_shape_without_value
)

rules = RewriteRuleSet(
    [
        cast_constant_of_shape_rule,
        cast_constant_of_shape_without_value_rule,
    ]
)
|
pythonProject/.venv/Lib/site-packages/onnxscript/rewriter/collapse_slices.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
from onnxscript import ir
|
| 8 |
+
from onnxscript.rewriter._ir_utils import is_singleton_value
|
| 9 |
+
from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
_INT64_MAX = 9223372036854775807
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _check_if_redundant_slice(
    context,
    data: ir.Value,
    starts: ir.Value,
    ends: ir.Value,
    axes: ir.Value,
    steps: ir.Value,
    **_,
) -> bool:
    """Return True if the Slice is a no-op.

    The slice is redundant when it starts at 0, steps by 1, and ends at or
    greater than the size of the sliced axis — or ends at INT64_MAX, which
    means "to the end" regardless of whether the data shape is known.
    """
    del context  # Reserved for future extensions

    # All four slice parameters must be statically known constants.
    consts = {
        "start": starts.const_value,
        "end": ends.const_value,
        "axis": axes.const_value,
        "step": steps.const_value,
    }
    if any(value is None for value in consts.values()):
        logger.info("The value 'start', 'end', 'axis', 'step' is not statically known.")
        return False

    # Materialize each constant exactly once (the original code called
    # .numpy() on every use) and require each to be a scalar.
    arrays = {name: value.numpy() for name, value in consts.items()}  # type: ignore[union-attr]
    for name, array in arrays.items():
        if array.size != 1:
            logger.info("The value '%s' is not a scalar.", name)
            return False

    start = arrays["start"].item()
    end = arrays["end"].item()
    axis = arrays["axis"].item()
    step = arrays["step"].item()

    if step != 1:
        logger.info("The value 'step' is not 1.")
        return False
    if start != 0:
        logger.info("The value 'start' is not 0.")
        return False
    # Even when data.shape is unknown, ends == INT64_MAX always means "slice to the end".
    if end == _INT64_MAX:
        return True
    if data.shape is None or data.shape.is_dynamic(axis):
        logger.info("The value 'data' shape is not statically known.")
        return False
    if end < data.shape[axis]:
        logger.info("The value 'end' is less than the shape of the specified axis.")
        return False

    return True
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _identity_to_itself(op, data, **_):
|
| 71 |
+
"""Return the input data as the output."""
|
| 72 |
+
return op.Identity(data)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _potential_redundant_slice(op, data, starts, ends, axes, steps):
|
| 76 |
+
"""To identify a slice op"""
|
| 77 |
+
return op.Slice(data, starts, ends, axes, steps, _outputs=["slice_output"])
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_):
    """Match condition: the Slice output shape equals the input shape (unit steps only)."""
    # Both shapes must be statically known to compare them.
    shapes_known = data.shape is not None and slice_output.shape is not None
    if not shapes_known:
        return False
    # A non-unit (e.g. negative) step could preserve the shape while reordering
    # elements, so additionally require steps == [1].
    return is_singleton_value(steps, 1) and data.shape == slice_output.shape
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# Register the rewrite rules
# Rule 1: proves redundancy from the static start/end/axes/steps constants.
remove_redundant_slice = RewriteRule(
    _potential_redundant_slice,
    _identity_to_itself,
    _check_if_redundant_slice,
)

# Rule 2: proves redundancy by comparing the inferred input/output shapes directly.
remove_redundant_slice2 = RewriteRule(
    _potential_redundant_slice,
    _identity_to_itself,
    _same_shape,
)

# NOTE: The second rule subsumes the first one. So, we may be able to remove the first one,
# provided shape-inference is run before the rewriter and computes the shape of the slice output.

rules = RewriteRuleSet([remove_redundant_slice, remove_redundant_slice2])
|
pythonProject/.venv/Lib/site-packages/onnxscript/utils/__init__.py
ADDED
|
File without changes
|
pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (173 Bytes). View file
|
|
|
pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/evaluation_utils.cpython-310.pyc
ADDED
|
Binary file (2.51 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/timing_utils.cpython-310.pyc
ADDED
|
Binary file (853 Bytes). View file
|
|
|
pythonProject/.venv/Lib/site-packages/onnxscript/utils/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (2.86 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/onnxscript/utils/timing_utils.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import onnx
|
| 6 |
+
|
| 7 |
+
from onnxscript import optimizer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def timeit(f, message):
    """Wrap *f* so each call prints ``"<message> time: <seconds>"`` and returns f's result.

    Args:
        f: The callable to instrument.
        message: Label printed before the elapsed time.

    Returns:
        A wrapper with the same call signature and return value as ``f``.
    """
    import functools

    @functools.wraps(f)  # preserve __name__/__doc__ of the wrapped callable
    def timed(*args, **kw):
        # perf_counter is monotonic and high-resolution; time.time() is
        # wall-clock and can jump backwards on clock adjustments.
        ts = time.perf_counter()
        result = f(*args, **kw)
        te = time.perf_counter()
        print(f"{message} time: {te - ts}")
        return result

    return timed
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Pre-instrumented versions of the common model-processing entry points:
# drop-in replacements that print elapsed time on every call.
load = timeit(onnx.load, "Load")

save = timeit(onnx.save, "Save")

infer = timeit(onnx.shape_inference.infer_shapes, "Infer")

fold_constants = timeit(optimizer.fold_constants, "Fold Constants")

remove_unused = timeit(optimizer.remove_unused_nodes, "Remove Unused")

optimize = timeit(optimizer.optimize, "Optimize")

# rewrite = timeit(all_rules.apply_to_model, "Rewrite")
|
pythonProject/.venv/Lib/site-packages/onnxscript/utils/utils.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import onnx
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def normalize_domain(d: str) -> str:
    """Map the "ai.onnx" alias to the canonical empty-string ONNX domain."""
    if d == "ai.onnx":
        return ""
    return d
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def is_onnx_domain(d: str) -> bool:
    """Return True if *d* names the standard ONNX domain ("" or its "ai.onnx" alias)."""
    return d in ("", "ai.onnx")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def is_onnx_op(node: onnx.NodeProto, op_type: str) -> bool:
    """Return True if *node* is the given op type from the standard ONNX domain."""
    if node.op_type != op_type:
        return False
    return is_onnx_domain(node.domain)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def is_control_flow_op(node: onnx.NodeProto) -> bool:
    """Return True if any attribute of *node* carries a subgraph (single or repeated)."""
    for attr in node.attribute:
        if attr.HasField("g") or len(attr.graphs) > 0:
            return True
    return False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_node_attr_value(node: onnx.NodeProto, attr_name: str, default: Any) -> Any:
    """Return the decoded value of *attr_name* on *node*, or *default* if absent.

    Raises:
        ValueError: If the attribute appears more than once on the node.
    """
    matching = [attr for attr in node.attribute if attr.name == attr_name]
    if not matching:
        return default
    if len(matching) > 1:
        raise ValueError(f"Node has multiple attributes with name {attr_name}")
    return onnx.helper.get_attribute_value(matching[0])
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_initializer_type(initializer: onnx.TensorProto) -> onnx.TypeProto:
    """Build a TypeProto (tensor element type + static shape) describing *initializer*."""
    # Renamed from `type`, which shadowed the builtin.
    type_proto = onnx.TypeProto()
    type_proto.tensor_type.elem_type = initializer.data_type
    dims = type_proto.tensor_type.shape.dim
    for dim in initializer.dims:
        # Initializer dims are concrete ints, so every dim is a dim_value.
        dims.add().dim_value = dim
    return type_proto
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_constant_node_value(node: onnx.NodeProto, name: str) -> onnx.TensorProto | None:
    """Extract the tensor held by an ONNX Constant node as a TensorProto.

    Args:
        node: Candidate node; must be a single-attribute ``Constant`` from the
            standard domain.
        name: Name given to the tensor built from scalar/list attributes.

    Returns:
        The tensor value, or None if the node is not a Constant, uses an
        attribute reference, or holds an unsupported (e.g. sparse) attribute.
    """
    if (
        node.op_type != "Constant"
        or node.domain not in {"", "ai.onnx"}
        or len(node.attribute) != 1
    ):
        return None
    attr = node.attribute[0]
    if attr.ref_attr_name:
        # Attribute references cannot be resolved without the enclosing function.
        return None
    value = onnx.helper.get_attribute_value(attr)

    if isinstance(value, onnx.TensorProto):
        # Two names exist in this case: we use tensorproto as is (with original name)
        return value

    # Maps the attribute name to (tensor dtype, is_scalar). Scalars get an
    # empty shape and must be wrapped in a list for make_tensor.
    spec = {
        "value_int": (onnx.TensorProto.INT64, True),
        "value_float": (onnx.TensorProto.FLOAT, True),
        "value_string": (onnx.TensorProto.STRING, True),
        "value_ints": (onnx.TensorProto.INT64, False),
        "value_floats": (onnx.TensorProto.FLOAT, False),
        "value_strings": (onnx.TensorProto.STRING, False),
    }.get(attr.name)
    if spec is None:
        return None  # sparse tensors not handled
    dtype, is_scalar = spec
    shape: list[int]
    if is_scalar:
        shape = []
        value = [value]
    else:
        shape = [len(value)]
    return onnx.helper.make_tensor(name, dtype, shape, value)
|
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__init__.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"ConvertVersionPass",
|
| 7 |
+
"convert_version",
|
| 8 |
+
]
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
import onnx
|
| 13 |
+
import onnx_ir.passes.common as common_passes
|
| 14 |
+
|
| 15 |
+
from onnxscript import ir
|
| 16 |
+
from onnxscript.version_converter import _c_api_utils, _version_converter
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ConvertVersionPass(ir.passes.InPlacePass):
    """Convert the model to the specified ONNX opset version.

    This pass leverages the onnxscript version converter to convert the model. If
    the conversion is not supported, it falls back to the onnx C API to convert
    the model. This pass is in-place.

    The pass is a no-op if the c-api fails.

    Attributes:
        target_version: The target ONNX opset version to convert the model to.
        fallback: Whether to fallback to the onnx version converter if the
            target version is not supported. Default is False.
    """

    def __init__(self, target_version: int, fallback: bool = False) -> None:
        super().__init__()
        self.target_version = target_version
        self.fallback = fallback
        # The conversion step requires an inlined model, so inline first, then
        # convert, then prune whatever the conversion left unused. Order matters.
        self.convert_pass = ir.passes.Sequential(
            common_passes.InlinePass(),
            _ConvertVersionPassRequiresInline(
                target_version=target_version,
                fallback=fallback,
            ),
            # Conversion can orphan nodes, functions, and opset imports.
            common_passes.RemoveUnusedNodesPass(),
            common_passes.RemoveUnusedFunctionsPass(),
            common_passes.RemoveUnusedOpsetsPass(),
        )

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        # Delegate entirely to the composed sequential pass.
        return self.convert_pass(model)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass):
    """Convert the model to the specified ONNX opset version.

    This pass leverages the onnxscript version converter to convert the model. If
    the conversion is not supported, it falls back to the onnx C API to convert
    the model. This pass is in-place. The model must already be inlined
    (no functions); use :class:`ConvertVersionPass` for the composed pipeline.

    The pass is a no-op if the c-api fails.

    Attributes:
        target_version: The target ONNX opset version to convert the model to.
        fallback: Whether to fallback to the onnx version converter if the
            target version is not supported.
    """

    def __init__(self, target_version: int, fallback: bool) -> None:
        super().__init__()
        self.target_version = target_version
        self.fallback = fallback

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        # Functions are not supported; callers must inline first.
        if model.functions:
            raise ValueError(
                "The model contains functions. The version conversion pass does not support "
                "functions. Please use `common_passes.InlinePass` to inline the "
                f"functions before applying this pass ({self.__class__.__name__})."
            )
        if "" in model.graph.opset_imports:
            onnx_opset_version = model.graph.opset_imports[""]
            if onnx_opset_version == self.target_version:
                # No need to convert the version
                return ir.passes.PassResult(model, False)

        # When fallback is disabled, always use the onnxscript version converter;
        # When fallback is enabled, use the onnxscript version converter
        # if the target version is supported. Otherwise, use the onnx C API
        # to convert the model.
        if not self.fallback or _version_converter.version_supported(
            model, self.target_version
        ):
            _version_converter.convert_version(
                model,
                target_version=self.target_version,
            )
            return ir.passes.PassResult(model, True)

        # NOTE(review): this branch looks unreachable — when fallback is falsy
        # the condition above (`not self.fallback or ...`) already returned.
        # Confirm intended behavior before relying on this warning path.
        if not self.fallback:
            logger.warning(
                "The model version conversion is not supported by the onnxscript version converter "
                "and fallback is disabled. The model was not modified"
                " (target version: %d). "
                "Set fallback=True to enable fallback to the onnx c-api version converter.",
                self.target_version,
            )
            return ir.passes.PassResult(model, False)
        else:
            logger.warning(
                "The model version conversion is not supported by the onnxscript version converter "
                "and fallback is enabled. The model will be converted using the onnx C API "
                "(target version: %d).",
                self.target_version,
            )

        # If the onnxscript version converter does not support the conversion,
        # we can use the onnx C API to convert the model
        def _partial_convert_version(proto: onnx.ModelProto) -> onnx.ModelProto:
            """Partial function to check the model."""
            return onnx.version_converter.convert_version(
                proto, target_version=self.target_version
            )

        try:
            # call_onnx_api strips big initializers before serializing so the
            # C API does not choke on large models, and restores them after.
            converted_proto = _c_api_utils.call_onnx_api(
                func=_partial_convert_version, model=model
            )
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.warning(
                "Failed to convert the model to the target version %d using the ONNX C API. "
                "The model was not modified",
                self.target_version,
                exc_info=e,
            )
            # C API failed: leave the model untouched (documented no-op behavior).
            return ir.passes.PassResult(model, False)

        converted_model = ir.from_proto(converted_proto)

        # Recover the initializers in the converted model: call_onnx_api turned
        # them into value-less graph inputs, so re-attach the stored tensors.
        for input in converted_model.graph.inputs:
            if input.name in model.graph.initializers:
                input.const_value = model.graph.initializers[input.name].const_value
                converted_model.graph.register_initializer(input)
        # Drop the synthetic inputs again, keeping only the user-facing ones.
        user_inputs = converted_model.graph.inputs[: len(model.graph.inputs)]
        converted_model.graph.inputs.clear()
        converted_model.graph.inputs.extend(user_inputs)

        # Return the converted graph to the original model to keep the pass in-place
        model.graph = converted_model.graph
        return ir.passes.PassResult(model, True)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def convert_version(
    model: ir.Model | onnx.ModelProto, target_version: int, fallback: bool = False
) -> None:
    """Convert the model in place to the specified ONNX opset version.

    Args:
        model: The model to convert (IR model or ModelProto).
        target_version: The target ONNX opset version.
        fallback: Whether to fallback to the onnx version converter if the
            target version is not supported. Default is False.
            (Previously defaulted to None; False matches the documented default.)
    """
    if isinstance(model, onnx.ModelProto):
        # Round-trip through the IR so the pass can run; write back below.
        model_proto = model
        model = ir.from_proto(model)
    else:
        model_proto = None

    assert isinstance(model, ir.Model)
    ConvertVersionPass(target_version=target_version, fallback=fallback)(model)

    if model_proto is not None:
        # Update the model proto in-place
        model_proto.graph.Clear()
        del model_proto.functions[:]
        model_proto.graph.CopyFrom(ir.to_proto(model.graph))
        # Also sync the opset imports: conversion changes the model-level
        # opset_imports, which the graph copy above does not carry.
        del model_proto.opset_import[:]
        model_proto.opset_import.extend(
            onnx.helper.make_opsetid(domain, version)
            for domain, version in model.opset_imports.items()
        )
|
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (5.72 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/_c_api_utils.cpython-310.pyc
ADDED
|
Binary file (2.11 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/__pycache__/_version_converter.cpython-310.pyc
ADDED
|
Binary file (9.96 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/_c_api_utils.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
"""Utilities for interfacing with onnx C APIs."""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from typing import TYPE_CHECKING, Callable, TypeVar
|
| 9 |
+
|
| 10 |
+
from onnxscript import ir
|
| 11 |
+
|
| 12 |
+
if TYPE_CHECKING:
|
| 13 |
+
import onnx
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
# Temporarily remove initializers larger than this size to keep model size down
|
| 18 |
+
# for the onnx.shape_inference call because it needs to serialize the model
|
| 19 |
+
_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB
|
| 20 |
+
_R = TypeVar("_R")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R:
    """Call an ONNX C API function by temporarily removing initializers.

    This is necessary because the ONNX C API does not support large models
    with initializers that have large tensor values. The input model is left
    unchanged no matter the call succeeds or not.

    Args:
        func: Partially applied function that takes a model proto and returns anything.
        model: The IR model to pass to the API function.

    Returns:
        The resulting ModelProto that contains the result of the API call.
    """

    # Store the original initializer values so they can be restored
    initializer_values = tuple(model.graph.initializers.values())
    tensors = {v.name: v.const_value for v in initializer_values}
    original_inputs_len = len(model.graph.inputs)

    # Turn the initializers into inputs and clear the initializers
    # to limit the model size. NOTE: the graph is mutated here and restored
    # in the `finally` block below, even if `func` raises.
    for initializer in initializer_values:
        # Make sure the initializer has its shape/type set
        assert initializer.const_value is not None
        if initializer.shape is None:
            initializer.shape = initializer.const_value.shape  # type: ignore[assignment]
        if initializer.dtype is None:
            initializer.dtype = initializer.const_value.dtype
        if initializer not in model.graph.inputs:
            # Promote to a graph input so the serialized proto stays valid
            # once the tensor payload is dropped.
            model.graph.inputs.append(initializer)
        if initializer.const_value.size > _BIG_TENSOR_SIZE_LIMIT:
            # Temporarily remove the initializer value to reduce model size
            # for onnx.shape_inference
            initializer.const_value = None
            assert initializer.name is not None
            model.graph.initializers.pop(initializer.name)

    proto = ir.serde.serialize_model(model)

    try:
        # Call the ONNX C API function
        result = func(proto)
    finally:
        # Restore the original initializer values so the model is unchanged
        for initializer in initializer_values:
            initializer.const_value = tensors[initializer.name]
            model.graph.register_initializer(initializer)

        # Restore the original inputs (the promoted initializers were appended
        # after them, so a prefix slice recovers the user-facing inputs).
        inputs = model.graph.inputs[:original_inputs_len]
        model.graph.inputs.clear()
        model.graph.inputs.extend(inputs)

    return result
|
pythonProject/.venv/Lib/site-packages/onnxscript/version_converter/_version_converter.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
"""Convert the model to the specified ONNX opset version."""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import dataclasses
|
| 8 |
+
import functools
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Callable, Sequence, Union
|
| 11 |
+
|
| 12 |
+
import onnx_ir.convenience as ir_convenience
|
| 13 |
+
|
| 14 |
+
import onnxscript.ir._tape as _tape
|
| 15 |
+
from onnxscript import ir
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
SUPPORTED_MAX_ONNX_OPSET = 23
|
| 21 |
+
SUPPORTED_MIN_ONNX_OPSET = 18
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_onnx_opset_version(model: ir.Model) -> int | None:
|
| 25 |
+
"""Get the ONNX opset version imported by the model."""
|
| 26 |
+
model_version1 = model.opset_imports.get("")
|
| 27 |
+
model_version2 = model.opset_imports.get("ai.onnx")
|
| 28 |
+
if model_version1 is not None and model_version2 is not None:
|
| 29 |
+
if model_version1 != model_version2:
|
| 30 |
+
raise ValueError(
|
| 31 |
+
f"Model imports multiple onnx opsets: {model_version1} and {model_version2}."
|
| 32 |
+
)
|
| 33 |
+
return model_version1 or model_version2
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _set_onnx_opset_version(model: ir.Model, version: int) -> None:
|
| 37 |
+
"""Set the ONNX opset version imported by the model."""
|
| 38 |
+
if "ai.onnx" in model.opset_imports:
|
| 39 |
+
del model.opset_imports["ai.onnx"]
|
| 40 |
+
model.opset_imports[""] = version
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class VersionConverterError(RuntimeError):
    """Raised when a node's version cannot be upgraded/downgraded successfully."""
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclasses.dataclass
class Replacement:
    """A replacement for a node in the graph."""

    # Values that take the place of the original node's outputs.
    new_outputs: Sequence[ir.Value]
    # Nodes inserted into the graph to produce ``new_outputs``.
    new_nodes: Sequence[ir.Node]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# A version-adapter function takes a node, a RewriterContext and returns
# a Replacement for the node or None (if no replacement is needed).

# Builder adapters use to emit replacement nodes.
RewriterContext = _tape.Builder
# What an adapter may return: replacement output value(s), or None for "no change".
ReturnValue = Union[Sequence[ir.Value], ir.Value, None]
AdapterFunction = Callable[[ir.Node, RewriterContext], ReturnValue]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def version_supported(model: ir.Model, target_version: int) -> bool:
    """Check if the target version is supported by the current version."""
    opset_imports = model.graph.opset_imports
    # Models without an explicit default-domain opset import are accepted as-is.
    if "" not in opset_imports:
        return True
    current_version = opset_imports[""]
    # Only upward conversion within the tested opset window is supported.
    return (
        SUPPORTED_MIN_ONNX_OPSET
        <= current_version
        <= target_version
        <= SUPPORTED_MAX_ONNX_OPSET
    )
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class AdapterRegistry:
    """A class that maintains a registry of adapters for ops."""

    def __init__(self):
        # Keyed by (domain, op name, node version, up_conversion flag).
        self.op_adapters: dict[tuple[str, str, int, bool], AdapterFunction] = {}

    def lookup_adapters(
        self,
        domain: str,
        opname: str,
        original_version: int,
        up_conversion: bool = True,
    ) -> AdapterFunction | None:
        """Return the adapter registered for this exact key, or None."""
        return self.op_adapters.get((domain, opname, original_version, up_conversion))

    def register(
        self, opname: str, domain: str = "", node_version=None, up_conversion=True
    ) -> Callable[[AdapterFunction], AdapterFunction]:
        """Register an adapter based on the domain, operator type, node version and whether to upgrade/downgrade node version"""

        def decorator(adapter_fn: AdapterFunction) -> AdapterFunction:
            # The registry stores the raw adapter; the decorated name in the
            # module is bound to a transparent wrapper.
            @functools.wraps(adapter_fn)
            def _wrapper(*args, **kwargs):
                return adapter_fn(*args, **kwargs)

            self.op_adapters[(domain, opname, node_version, up_conversion)] = adapter_fn
            return _wrapper

        return decorator
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# Module-level singleton registry populated by the @register decorators below.
registry: AdapterRegistry = AdapterRegistry()

# Convenience alias so adapters can be declared with a bare @register(...).
register = registry.register
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _get_input(node: ir.Node, index: int) -> ir.Value | None:
|
| 117 |
+
if index < len(node.inputs):
|
| 118 |
+
return node.inputs[index]
|
| 119 |
+
return None
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> int | None:
    """Return the integer value of attribute *name*, or *default* when absent.

    Returns None (rather than raising) when the attribute is present but is
    not a plain int attribute -- i.e. the model is invalid/unexpected.
    """
    if name not in node.attributes:
        return default
    attr = node.attributes[name]
    if not isinstance(attr, ir.Attr):
        return None
    value = attr.value
    # Invalid model if the value is not an int: signal with None instead of
    # raising, for now.
    return value if isinstance(value, int) else None
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _get_str_attribute(node: ir.Node, name: str, default: str | None = None) -> str | None:
    """Return the string value of attribute *name*, or *default* when absent.

    Returns None (rather than raising) when the attribute is present but is
    not a plain string attribute -- i.e. the model is invalid/unexpected.
    """
    if name not in node.attributes:
        return default
    attr = node.attributes[name]
    if not isinstance(attr, ir.Attr):
        return None
    value = attr.value
    # Invalid model if the value is not a str: signal with None instead of
    # raising, for now.
    return value if isinstance(value, str) else None
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
## Op-specific adapters
|
| 151 |
+
|
| 152 |
+
# Opset 19 -> 20
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@register("DFT", node_version=19, up_conversion=True)
def dft_19_20(node: ir.Node, op):
    """Upgrade DFT 19 -> 20: the `axis` attribute becomes an input tensor."""
    signal = node.inputs[0]
    inverse = _get_int_attribute(node, "inverse", 0)
    onesided = _get_int_attribute(node, "onesided", 0)
    axis = _get_int_attribute(node, "axis", None)
    if axis is None:
        # No explicit axis attribute: the node is valid as-is in opset 20.
        return None
    axis_value = op.Constant(value_int=axis)
    return op.DFT(signal, axis_value, inverse=inverse, onesided=onesided)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@register("GridSample", node_version=19, up_conversion=True)
def gridsample_19_20(node: ir.Node, op):
    """Upgrade GridSample 19 -> 20: rename mode "bilinear"->"linear", "bicubic"->"cubic"."""
    x = node.inputs[0]
    grid = node.inputs[1]
    align_corners = _get_int_attribute(node, "align_corners", 0)
    mode = _get_str_attribute(node, "mode", "linear")
    padding_mode = _get_str_attribute(node, "padding_mode", "zeros")
    # Only the two renamed modes require a rewrite; anything else passes through.
    new_mode = {"bilinear": "linear", "bicubic": "cubic"}.get(mode)
    if new_mode is None:
        return None
    return op.GridSample(
        x, grid, align_corners=align_corners, mode=new_mode, padding_mode=padding_mode
    )
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# Opset 20 -> 21
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@register("GroupNormalization", node_version=20, up_conversion=True)
def groupnormalization_20_21(node: ir.Node, op):
    """Upgrade GroupNormalization 20 -> 21.

    In opset 21 scale/bias are per-channel (length C) instead of per-group
    (length num_groups); expand per-group inputs by repeating each group value
    across its channels: [G] -> [G, 1] -> [G, C/G] -> [C].
    """
    x = _get_input(node, 0)
    scale = _get_input(node, 1)
    bias = _get_input(node, 2)
    if x is None or scale is None or bias is None:
        raise VersionConverterError(f"Missing input for {node}")

    x_shape = x.shape
    if x_shape is None:
        raise VersionConverterError(f"Missing required shape for {x}")
    num_channels = x_shape[1]
    # All of the relevant dimensions must be statically known ints.
    if not isinstance(num_channels, int):
        return None

    scale_shape = scale.shape
    bias_shape = bias.shape
    if scale_shape is None or bias_shape is None:
        return None
    if not isinstance(scale_shape[0], int) or not isinstance(bias_shape[0], int):
        return None

    num_groups = _get_int_attribute(node, "num_groups", None)
    if num_groups is None:
        raise VersionConverterError("Missing required attribute: num_groups")

    # Rewrite only when scale/bias are genuinely per-group; when groups equal
    # channels the shapes already match the opset-21 contract.
    if num_groups == num_channels:
        return None
    if num_groups != scale_shape[0] or num_groups != bias_shape[0]:
        return None

    reshape_1_sizes = op.Constant(value_ints=[-1, 1])
    reshape_2_sizes = op.Constant(value_ints=[-1])
    c_div = int(num_channels / num_groups)
    expand_sizes = op.Constant(value_ints=[1, c_div])

    def _to_per_channel(value):
        # [G] -> [G, 1] -> [G, C/G] -> [C]
        as_column = op.Reshape(value, reshape_1_sizes)
        repeated = op.Expand(as_column, expand_sizes)
        return op.Reshape(repeated, reshape_2_sizes)

    per_channel_scale = _to_per_channel(scale)
    per_channel_bias = _to_per_channel(bias)
    return op.GroupNormalization(
        x, per_channel_scale, per_channel_bias, num_groups=num_groups
    )
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class _VersionConverter:
    """Walks a model graph and steps default-domain nodes toward a target opset.

    Each node is converted one opset version at a time via the adapters in
    ``registry``. Only up-conversion (older -> newer opset) is implemented.
    """

    def __init__(self, target_version: int):
        # Opset version every default-domain node should end up at.
        self._target_version = target_version

    def process_node(
        self, node: ir.Node, from_version: int, up_conversion: bool = True
    ) -> Replacement | None:
        """Apply the adapter registered for this node at `from_version`, if any.

        Returns a Replacement (new outputs + tape-recorded new nodes) when the
        adapter rewrote the node, or None when no adapter is registered or the
        adapter decided no rewrite is needed.
        """
        assert node.domain == ""
        adapter = registry.lookup_adapters(
            node.domain, node.op_type, from_version, up_conversion
        )
        if adapter is None:
            return None
        # The context records every node the adapter builds.
        context = RewriterContext()
        output = adapter(node, context)
        if output is not None:
            # Normalize a single-value return to a sequence.
            if isinstance(output, ir.Value):
                output = [output]
            return Replacement(output, context.nodes)
        return None

    def replace_node(self, node: ir.Node, replacement: Replacement, root: ir.Graph | ir.Function) -> None:
        """Splice `replacement` into `root` in place of `node`, rewiring uses of its outputs."""
        logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)

        ir_convenience.replace_nodes_and_values(
            root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
        )

    def visit_attribute(self, attr: ir.Attr) -> None:
        """Recurse into subgraph-valued attributes; reference attributes are skipped."""
        if attr.is_ref():
            return
        if attr.type == ir.AttributeType.GRAPH:
            self.visit_graph(attr.as_graph())
        elif attr.type == ir.AttributeType.GRAPHS:
            for graph in attr.as_graphs():
                self.visit_graph(graph)

    def visit_node(
        self,
        node: ir.Node,
        root: ir.Graph | ir.Function,
        from_version: int,
        up_conversion: bool = True,
    ) -> None:
        """Convert `node` a single opset step (from_version -> from_version +/- 1)."""
        if up_conversion:
            to_version = from_version + 1
        else:
            to_version = from_version - 1
        replacement = self.process_node(node, from_version, up_conversion)
        if replacement is None:
            # No change. Process attributes (subgraphs) and bump the version tag.
            for attr in node.attributes.values():
                self.visit_attribute(attr)
            node.version = to_version
        else:
            for new_node in replacement.new_nodes:
                # TODO: control-flow
                new_node.version = to_version
            self.replace_node(node, replacement, root)

    def visit_graph(self, graph: ir.Graph) -> None:
        """Step every default-domain node in `graph` up to the target version."""
        # NOTE(review): replace_node mutates the graph while we iterate it;
        # presumably ir.Graph iteration tolerates in-place replacement — confirm.
        for node in graph:
            if node.domain != "":
                continue
            # Fall back to the model's default opset when the node carries no
            # version (note: `or` also treats an explicit version of 0 as unset).
            node_version = node.version or self._default_onnx_opset
            if node_version is None:
                raise VersionConverterError(f"Node {node} has no version.")
            # Iterate each node from current node version -> target version
            # and updating node based on the correct adapter
            # Up-conversion [ver->ver+1] or down-conversion [ver->ver-1]
            # TODO(shubhambhokare1): Remove once down-conversion adapters are supported
            if self._target_version < node_version:
                raise VersionConverterError(
                    f"Target opset: {self._target_version} less than node version: {node.version}, "
                    "downstream version conversion not currently handled."
                )
            for from_version in range(node_version, self._target_version):
                try:
                    self.visit_node(node, graph, from_version, up_conversion=True)
                except VersionConverterError as e:
                    # Best-effort: leave the node at its current version and continue.
                    logger.warning(
                        "Skipping version conversion for node %s due to exception: %s",
                        node.op_type,
                        e,
                    )

    def visit_model(self, model: ir.Model) -> None:
        """Convert the whole model and record the target opset in its imports."""
        # Used by visit_graph as the default for version-less nodes.
        self._default_onnx_opset = _get_onnx_opset_version(model)
        self.visit_graph(model.graph)
        _set_onnx_opset_version(model, self._target_version)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def convert_version(model: ir.Model, target_version: int) -> None:
    """Convert the model to the specified ONNX opset version."""
    # Reject targets outside the tested conversion window up front.
    if not SUPPORTED_MIN_ONNX_OPSET <= target_version <= SUPPORTED_MAX_ONNX_OPSET:
        raise ValueError(
            f"Target opset version {target_version} is not supported. "
            f"Supported range: {SUPPORTED_MIN_ONNX_OPSET} to {SUPPORTED_MAX_ONNX_OPSET}."
        )
    converter = _VersionConverter(target_version=target_version)
    converter.visit_model(model)
|