# xiaoanyu123's picture
# Add files using upload-large-folder tool
# 6a22ec9 verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Any
import numpy as np
import onnx
import onnxruntime as ort
from onnxscript import ir
def generate_random_inputs(model: onnx.ModelProto) -> dict[str, Any]:
    """Generate random float32 feeds for every graph input of *model*.

    Args:
        model: The ONNX model whose graph inputs are inspected.

    Returns:
        A mapping from input name to a random ``float32`` numpy array matching
        the input's static shape (a single-element array for scalar inputs).

    Raises:
        ValueError: If an input has a dynamic (symbolic or unknown) dimension,
            or an element type other than float32.
    """
    feeds: dict[str, Any] = {}
    for graph_input in model.graph.input:
        tensor_type = graph_input.type.tensor_type
        dims = tuple(tensor_type.shape.dim)
        # BUG FIX: a protobuf `Dimension` message *always* has a `dim_value`
        # attribute (it defaults to 0 when unset), so `hasattr(d, "dim_value")`
        # could never detect dynamic dims. Use explicit field presence: a
        # dynamic dim carries `dim_param` (or neither field) instead.
        if not all(d.HasField("dim_value") for d in dims):
            raise ValueError(f"Input {graph_input.name} has dynamic shape dimensions.")
        shape = tuple(d.dim_value for d in dims)
        if tensor_type.elem_type == onnx.TensorProto.FLOAT:
            if shape:
                feeds[graph_input.name] = np.random.randn(*shape).astype(np.float32)
            else:
                # Scalar input: produce a one-element array.
                feeds[graph_input.name] = np.random.randn(1).astype(np.float32)
        else:
            raise ValueError(f"Not implemented for input type {tensor_type.elem_type}")
    return feeds
def assert_numerically_equal(
    original_model_proto: onnx.ModelProto | ir.Model,
    rewritten_model_proto: onnx.ModelProto | ir.Model,
    args: tuple[Any, ...] | dict[str, Any],
    ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
    rtol: float = 1,
    atol: float = 1e-3,
):
    """Assert that the two models are numerically equal.

    Args:
        original_model_proto: The original model proto or ir.Model.
        rewritten_model_proto: The rewritten by the rules model proto or ir.Model.
        args: The positional arguments to pass to the model.
        ort_optimization_level: Onnxruntime optimization level.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
    """
    # Normalize both models to ModelProto before serializing for ORT.
    if isinstance(original_model_proto, ir.Model):
        original_model_proto = ir.serde.serialize_model(original_model_proto)
    if isinstance(rewritten_model_proto, ir.Model):
        rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto)

    # Build per-model feed dictionaries: a dict of args is used verbatim for
    # both models; a tuple is matched positionally against each graph's inputs.
    if isinstance(args, dict):
        original_feeds: dict[str, Any] = args
        rewritten_feeds: dict[str, Any] = args
    else:
        original_feeds = dict(
            zip((inp.name for inp in original_model_proto.graph.input), args)
        )
        rewritten_feeds = dict(
            zip((inp.name for inp in rewritten_model_proto.graph.input), args)
        )

    run_options = ort.RunOptions()
    run_options.log_severity_level = 3  # 3: Error

    def _run(model_proto, feeds):
        # One fresh session per model, then a single inference pass.
        session = _ort_session_initializer(
            model_proto.SerializeToString(), ort_optimization_level
        )
        return session.run(None, feeds, run_options=run_options)

    original_outputs = _run(original_model_proto, original_feeds)
    rewritten_outputs = _run(rewritten_model_proto, rewritten_feeds)

    np.testing.assert_allclose(
        original_outputs, rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True
    )
def _ort_session_initializer(
    model: str | bytes, ort_optimization_level: ort.GraphOptimizationLevel
) -> ort.InferenceSession:
    """Initialize an ONNX Runtime inference session with the specified model.

    Args:
        model: Path to a model file, or the serialized model bytes.
        ort_optimization_level: Graph optimization level for the session.

    Returns:
        An ``InferenceSession`` preferring CUDA when available, else CPU.
    """
    # Relies on the module-level `onnxruntime as ort` import; the original's
    # redundant function-local re-import (which shadowed it) is removed.
    session_options = ort.SessionOptions()
    session_options.log_severity_level = 3  # 3: Error
    session_options.graph_optimization_level = ort_optimization_level
    # Prefer CUDA over CPU, keeping only providers present in this ORT build.
    possible_providers = (
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    )
    available_providers = set(ort.get_available_providers())
    providers = [
        provider for provider in possible_providers if provider in available_providers
    ]
    return ort.InferenceSession(model, providers=providers, sess_options=session_options)