xiaoanyu123's picture
Add files using upload-large-folder tool
08157a5 verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: TID251
from __future__ import annotations
import dataclasses
import io
import logging
import warnings
from typing import Any, Optional, Protocol, Sequence, Union
import onnx
from onnx import ValueInfoProto, helper
from onnx.defs import onnx_opset_version
import onnxscript
from onnxscript import type_annotation as ta
from onnxscript import values
from onnxscript._internal import version_utils
from onnxscript.onnx_types import ONNXType
from onnxscript.sourceinfo import SourceInfo
# A simple IR (Function, Stmt, Attr, Var):
logger = logging.getLogger("onnxscript")
def _format(seq: Sequence[Any], prefix: str, sep: str, suffix: str, formatter=str):
"""Formats a sequence of objects into a string."""
return prefix + sep.join([formatter(x) for x in seq]) + suffix
def select_ir_version(version: int, domain: str = "") -> int:
    """Selects a suitable ONNX ir_version for a given opset version.

    The empty domain is treated as ``"ai.onnx"``. The pair ``(domain, version)``
    is looked up in ``helper.OP_SET_ID_VERSION_MAP``; when the pair is not
    listed there, the largest ir_version recorded for any ``"ai.onnx"`` opset
    is returned instead.
    """
    known_versions = helper.OP_SET_ID_VERSION_MAP
    lookup_key = ("ai.onnx" if domain == "" else domain, version)
    if lookup_key in known_versions:
        return known_versions[lookup_key]
    # Unknown (domain, version): fall back to the newest ir_version of ai.onnx.
    return max(
        ir_version
        for op_key, ir_version in known_versions.items()
        if op_key[0] == "ai.onnx"
    )
class IRType:
    """Base IR type: holds an ``onnx.TypeProto`` created empty at construction."""

    def __init__(self):
        self.onnx_type = onnx.TypeProto()

    def __repr__(self) -> str:
        return "IRType()"

    def to_type_proto(self):
        """Return the ``onnx.TypeProto`` stored on this instance."""
        return self.onnx_type
class IRTensorType(IRType):
    """An :class:`IRType` whose proto has its tensor element type set."""

    def __init__(self, elem_type: onnx.TensorProto.DataType) -> None:
        super().__init__()
        # Record the element type on the tensor_type field of the wrapped proto.
        self.onnx_type.tensor_type.elem_type = elem_type

    def __repr__(self) -> str:
        elem = self.onnx_type.tensor_type.elem_type
        return f"IRTensorType({elem})"
class IRTypeLike(Protocol):
    """Structural interface for objects that expose ``to_type_proto()``."""

    def to_type_proto(self) -> onnx.TypeProto:
        """Converts IR type representation to onnx.TypeProto"""
        ...
class IRVar:
    """A variable (representing a formal parameter).

    Attributes:
        name: The variable name (always a ``str``).
        info: The source information passed at construction.
        typeinfo: The type information passed at construction (may be None).
    """

    def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -> None:
        if not isinstance(varname, str):
            raise TypeError(f"varname must be a string not {type(varname)!r}.")
        self.name = varname
        self.info = sourceinfo
        self.typeinfo = typeinfo

    def __str__(self):
        return self.name

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.name!r}, {self.typeinfo!r})"

    def typed_str(self):
        """Return the variable rendered as ``name : type``."""
        return "{} : {}".format(self.name, self.typeinfo)

    def to_value_info(self, use_default_type: bool = True):
        """Converts the content of this class into :class:`onnx.ValueInfoProto`.

        Args:
            use_default_type: if True, use a default type if an explicit type
                is not known. Otherwise, returns a ValueInfoProto without type.

        Returns:
            an instance of :class:`onnx.ValueInfoProto`

        Raises:
            ValueError: if ``self.name`` is None.
        """
        if self.name is None:
            raise ValueError(self.info.msg("name cannot be None."))
        proto = ValueInfoProto()
        proto.name = self.name
        # Pick the type provider: the explicit type if any, else the default
        # IRType when the caller asked for one, else leave the type unset.
        type_source = self.typeinfo
        if type_source is None and use_default_type:
            type_source = IRType()
        if type_source is not None:
            proto.type.CopyFrom(type_source.to_type_proto())
        return proto
def _opt_var_to_str(x):
return "" if x is None else str(x)
class IRAttributeValue:
    """An attribute value (representing an actual parameter).

    Attributes:
        name: The name of the attribute.
        type: The type of the attribute.
        attr_proto: The attribute proto.
    """

    def __init__(self, attrproto: onnx.AttributeProto) -> None:
        self.attr_proto = attrproto

    @property
    def name(self) -> str:
        return self.attr_proto.name

    @property
    def type(self) -> onnx.AttributeProto.AttributeType:
        return self.attr_proto.type

    def __str__(self):
        proto = self.attr_proto
        if not proto.HasField("ref_attr_name"):
            # Plain value: defer to ONNX's printer (name = value).
            return helper.printable_attribute(proto)
        # Reference to an attribute of the enclosing function.
        return f"{proto.name} = @{proto.ref_attr_name}"
@dataclasses.dataclass(frozen=True)
class IRAttributeParameter:
    """An attribute parameter (representing a formal parameter).

    It may or may not carry a default value.

    Attributes:
        name: The name of the attribute.
        type: The type of the attribute.
        default_value: The default value of the attribute.
        has_default: Whether the attribute has a default value.
        attr_proto: The attribute proto.
    """

    name: str
    type: onnx.AttributeProto.AttributeType
    default_value: str | int | float | None = None

    # TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType.

    @property
    def has_default(self):
        """True when ``default_value`` is anything other than None."""
        return self.default_value is not None

    @property
    def attr_proto(self) -> onnx.AttributeProto:
        """Build an ``onnx.AttributeProto`` carrying the default value.

        Raises:
            ValueError: if the parameter has no default value.
        """
        if self.default_value is None:
            raise ValueError(
                "Attribute has no default value. Only attributes with default "
                "values can be converted to AttributeProto."
            )
        # TODO(after 1.14 is deprecated): Always pass attr_type.
        # Argument 'attr_type' was added after version 1.14, so it is only
        # forwarded when the installed onnx is recent enough.
        optional_kwargs = {}
        if not version_utils.onnx_older_than("1.15"):
            optional_kwargs["attr_type"] = self.type
        return helper.make_attribute(self.name, self.default_value, **optional_kwargs)

    def __str__(self):
        # TODO(justinchuby): Include a readable type name.
        if not self.has_default:
            return self.name
        return helper.printable_attribute(self.attr_proto)
class IRStmt:
    """A single IR statement: ``result = callee<attrs>(args)``.

    Attributes:
        result: Names assigned by the statement.
        callee: The operator (``values.Op``) being invoked.
        args: Argument names; ``None`` entries are rendered as empty strings.
        attrs: Attribute values passed to the callee.
        functions: The ``sub_functions`` given at construction, or an empty
            dict when none (or a falsy value) was given.
    """

    def __init__(
        self,
        result: Sequence[str],
        callee: values.Op,
        args: Sequence[Optional[str]],
        attrs: Sequence[IRAttributeValue],
        sub_functions=None,
    ) -> None:
        if not isinstance(callee, values.Op):
            raise TypeError(f"Unexpected type {type(callee)} for callee.")
        self.result = result
        self.callee = callee
        self.args = args
        self.attrs = attrs
        self.functions = sub_functions or {}

    def __str__(self):
        if isinstance(self.result, str):
            # A bare string would be joined character by character below.
            logger.debug("unexpected str type for self.result where type(self)=%r", type(self))
        targets = ", ".join(self.result)
        attr_text = _format(self.attrs, "<", ", ", ">") if self.attrs else ""
        arg_text = _format(self.args, "(", ", ", ")", _opt_var_to_str)
        opset_domain = self.callee.opset.domain
        op_name = self.callee.name
        # Operators from the default (empty) domain are printed unqualified.
        qualified = op_name if opset_domain == "" else f"{opset_domain}.{op_name}"
        return f"{targets} = {qualified} {attr_text}{arg_text}"

    def debug_print(self):
        """Log this statement at DEBUG level on the module logger."""
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug("%s: %s", type(self), str(self))

    def to_node_proto(self, node_name: str) -> onnx.NodeProto:
        """Convert this statement into an ``onnx.NodeProto`` named ``node_name``."""
        input_names = [_opt_var_to_str(arg) for arg in self.args]
        result_names = [str(res) for res in self.result]
        node = helper.make_node(
            self.callee.name,
            input_names,
            result_names,
            domain=self.callee.opset.domain,
            name=node_name,
        )
        node.attribute.extend([attr.attr_proto for attr in self.attrs])
        return node

    @property
    def output_names(self) -> Sequence[str]:
        """Returns the list of variables assigned to by this statement."""
        return list(map(str, self.result))
class IRFunction:
    """Represents a function in the IR.

    Attributes:
        domain: The domain the function belongs to.
        name: The function name.
        outputs: The output variables, in declaration order.
        stmts: The statements making up the function body.
        called_functions: Function protos of the functions registered through
            :meth:`add_called_function`, keyed by name.
        docstring: The accumulated documentation string.
        nested_functions: Nested function definitions, keyed by name.
        outer_scope_variables: A dictionary initialised empty; it is not
            populated by this class.
        ordered_inputs_and_attrs: Inputs and attribute parameters in the
            order in which they were declared.
    """

    def __init__(self, name: str, domain: str = "") -> None:
        self.domain = domain
        self.name = name
        self.outputs: list[IRVar] = []
        self.stmts: list[IRStmt] = []
        self.called_functions: dict[str, onnx.FunctionProto] = {}
        self.docstring: str = ""
        # a dictionary of nested function-definitions
        self.nested_functions: dict[str, IRFunction] = {}
        self.outer_scope_variables: dict[Any, Any] = {}
        self.ordered_inputs_and_attrs: list[Union[IRVar, IRAttributeParameter]] = []

    @property
    def assigned_names(self) -> Sequence[str]:
        """Returns the list of variables assigned to by this function."""
        return [v for stmt in self.stmts for v in stmt.output_names]

    @property
    def inputs(self) -> Sequence[IRVar]:
        """Returns the declared inputs (the :class:`IRVar` entries), in order."""
        return [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)]

    @property
    def attrs(self) -> Sequence[IRAttributeParameter]:
        """Returns the declared attribute parameters, in order."""
        return [
            attr
            for attr in self.ordered_inputs_and_attrs
            if isinstance(attr, IRAttributeParameter)
        ]

    def __str__(self):
        attr_params = self.attrs
        attrs = _format(attr_params, "<", ", ", ">") if attr_params else ""
        inputs = _format([x.typed_str() for x in self.inputs], "(", ", ", ")")
        outputs = _format([x.typed_str() for x in self.outputs], "(", ", ", ")")
        stmts = _format(self.stmts, "\n{\n ", "\n ", "\n}\n")
        return f"{self.name} {attrs}{inputs} => {outputs}{stmts}"

    def append_docstring(self, docstring):
        """Append ``docstring`` to the function's documentation string."""
        self.docstring += docstring

    def append_stmt(self, stmt: IRStmt) -> None:
        """Append a statement to the function body."""
        self.stmts.append(stmt)

    def append_input(self, name: IRVar) -> None:
        """Declare a new input variable."""
        self.ordered_inputs_and_attrs.append(name)

    def append_output(self, name: IRVar) -> None:
        """Declare a new output variable."""
        self.outputs.append(name)

    def add_attr_parameter(self, attr: IRAttributeParameter) -> None:
        """Declare a new attribute parameter."""
        self.ordered_inputs_and_attrs.append(attr)

    def debug_print(self):
        """Log, at DEBUG level, every graph-valued attribute used in the body.

        Nothing is computed unless DEBUG logging is enabled on the module
        logger, and nothing is emitted when no statement carries a graph
        attribute.
        """
        if not logger.isEnabledFor(logging.DEBUG):
            return
        st = io.StringIO()
        for s in self.stmts:
            for attr in s.attrs:
                if attr.attr_proto.HasField("g"):
                    st.write(helper.printable_graph(attr.attr_proto.g))
                    st.write("\n")
        text = st.getvalue()
        if text:
            # Emit the collected text; previously the buffer was built and
            # then dropped without ever being logged.
            logger.debug("%s", text)

    def add_called_function(self, fun: values.OnnxFunction) -> None:
        """Register ``fun`` and the functions it calls as called functions.

        Functions already registered under the same name are left untouched.

        Raises:
            TypeError: if converting ``fun`` to a function proto fails with a
                ``TypeError`` or an ``AttributeError``.
        """
        # First pull in the functions that `fun` itself depends on.
        for name, fct in fun.function_ir.called_functions.items():
            if name in self.called_functions:
                continue
            self.called_functions[name] = fct
        if fun.name in self.called_functions:
            # Already added.
            return
        try:
            proto = fun.to_function_proto()
        except (TypeError, AttributeError) as e:
            raise TypeError(f"Issue with type {type(fun)}.") from e
        self.called_functions[fun.name] = proto

    def add_nested_function(self, fun: IRFunction) -> None:
        """Register ``fun`` as a nested function, keyed by its name."""
        self.nested_functions[fun.name] = fun

    def to_model_proto(
        self,
        functions=None,
        io_types: Optional[ONNXType] = None,
        input_types: Optional[Sequence[ONNXType]] = None,
        output_types: Optional[Sequence[ONNXType]] = None,
        value_infos: dict[str, ONNXType] | None = None,
        **kwargs,
    ) -> onnx.ModelProto:
        """Converts this instance into a `onnx.ModelProto`.

        Args:
            functions: A list of functions to include in the model.
                By default, all functions called at least once are included.
            io_types: When specified, all the inputs/outputs of the model
                that have no type yet are set to be of this type.
            input_types: When specified, all the inputs of the model
                are set to be of the corresponding type in this list.
            output_types: When specified, all the outputs of the model
                are set to be of the corresponding type in this list.
            value_infos: A dictionary mapping intermediate variable names to ONNX types.
                Used to set value_info for intermediate variables.
            kwargs: Additional parameters given to function :func:`onnx.helper.make_model`.

        Returns:
            An instance of :class:`onnx.ModelProto`.

        Raises:
            TypeError: if an entry of ``functions`` is neither a
                ``FunctionProto`` nor an ``OnnxFunction``.
        """
        value_info_protos = (
            [
                onnx.helper.make_value_info(name, onnx_type.to_type_proto())
                for name, onnx_type in value_infos.items()
            ]
            if value_infos
            else None
        )
        graph, sub_functions = self.to_graph_and_functions(
            use_default_type=False, value_infos=value_info_protos
        )
        if io_types is not None:
            # Only fill in the types that are still missing.
            for graph_input in graph.input:
                if not graph_input.HasField("type"):
                    graph_input.type.CopyFrom(io_types.to_type_proto())
            for graph_output in graph.output:
                if not graph_output.HasField("type"):
                    graph_output.type.CopyFrom(io_types.to_type_proto())
        if input_types is not None:
            # Explicit per-input types override whatever was set before.
            for graph_input, onnx_type in zip(graph.input, input_types):
                graph_input.type.CopyFrom(onnx_type.to_type_proto())
        if output_types is not None:
            for graph_output, onnx_type in zip(graph.output, output_types):
                graph_output.type.CopyFrom(onnx_type.to_type_proto())
        if functions is None:
            functions = sub_functions.values()
        else:

            def to_proto(f):
                if isinstance(f, onnx.FunctionProto):
                    return f
                if isinstance(f, onnxscript.OnnxFunction):
                    return f.to_function_proto()
                raise TypeError("Expected a value of type FunctionProto or OnnxFunction")

            functions = [to_proto(f) for f in functions]

        # Collect opset imports; the first version seen for a domain wins.
        opsets = {}
        for n in self.stmts:
            if n.callee.opset.domain not in opsets:
                opsets[n.callee.opset.domain] = n.callee.opset.version
        for proto in functions:
            if proto.domain not in opsets:
                opsets[proto.domain] = 1
            # TODO(rama): Handle conflicts with appropriate error/warning message.
            for opset in proto.opset_import:
                if opset.domain not in opsets:
                    opsets[opset.domain] = opset.version
        if "" not in opsets:
            # No operator is using the standard opset.
            # A default value is given.
            opsets[""] = onnx_opset_version()
        if "ir_version" not in kwargs:
            kwargs["ir_version"] = select_ir_version(opsets[""])
        opset_imports = [
            onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
        ]
        return helper.make_model(
            graph, opset_imports=opset_imports, functions=functions, **kwargs
        )

    def to_graph_and_functions(
        self,
        use_default_type: bool = True,
        value_infos: Sequence[ValueInfoProto] | None = None,
    ) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]:
        """Converts this instance into a `onnx.GraphProto` and a map from
        function-name to `onnx.FunctionProto`.

        Args:
            use_default_type: if True, the function uses a default type
                for inputs and outputs that do not have a type
            value_infos: a sequence of :class:`onnx.ValueInfoProto` to be added
                to the graph.

        Returns:
            a pair of a :class:`onnx.GraphProto` and a dictionary mapping
            function names to :class:`onnx.FunctionProto`
        """
        called_functions: dict[str, onnx.FunctionProto] = {}
        for s in self.stmts:
            called_functions.update(s.functions)
        # Functions registered on this IRFunction take precedence over the
        # per-statement ones when names collide.
        called_functions.update(self.called_functions)
        graph = helper.make_graph(
            [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)],
            self.name,
            [x.to_value_info(use_default_type) for x in self.inputs],
            [y.to_value_info(use_default_type) for y in self.outputs],
            value_info=value_infos,
        )
        return graph, called_functions

    def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto:
        """Converts this instance into a `onnx.GraphProto`.

        Args:
            use_default_type: if True, the function uses a default type
                for inputs and outputs that do not have a type

        Returns:
            an instance of :class:`onnx.GraphProto`
        """
        graph, _ = self.to_graph_and_functions(use_default_type=use_default_type)
        return graph

    def get_opset_import(self) -> dict[str, int]:
        """Return a map from opset domain to version, built from the statements.

        The first version seen for a domain is kept; a ``UserWarning`` is
        issued for every later statement using a different version of the
        same domain.
        """
        func_opset_imports = {}
        for s in self.stmts:
            opset = s.callee.opset
            if opset.domain not in func_opset_imports:
                func_opset_imports[opset.domain] = opset.version
            elif func_opset_imports[opset.domain] != opset.version:
                warnings.warn(
                    f"There is a version conflict in domain: {opset.domain!r}, "
                    f"with {self.name!r}.",
                    category=UserWarning,
                    stacklevel=1,
                )
        return func_opset_imports

    def to_function_proto(self) -> onnx.FunctionProto:
        """Converts this instance into a `onnx.FunctionProto`.

        Note: Default values for attributes are an experimental feature in ONNX.
        Conversion ignores default values for attributes if the ONNX version installed
        doesn't support it.
        """
        opsets = self.get_opset_import()
        nodes = [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)]
        for n in nodes:
            if n.domain not in opsets:
                opsets[n.domain] = 1  # TODO: how to get n.version?
        opset_imports = [
            onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items()
        ]
        # Attributes without a default are declared by name only; those with
        # a default are attached as full protos further below.
        attribute_names = [attr.name for attr in self.attrs if not attr.has_default]
        f = helper.make_function(
            self.domain,
            self.name,
            inputs=[x.name for x in self.inputs],
            outputs=[y.name for y in self.outputs],
            nodes=nodes,
            opset_imports=opset_imports,  # TODO
            attributes=attribute_names,
            doc_string=self.docstring,
        )
        # In protobuf 4.x fields aren't defined as class attribute so it should check instance attribute instead
        if hasattr(f, "attribute_proto"):
            f.attribute_proto.extend(
                [attr.attr_proto for attr in self.attrs if attr.has_default]
            )
        return f
# IRBuilder: abstracts out details of the IR in the python-to-IR converter
class IRBuilder:
    """Factory/facade used to create and populate :class:`IRFunction` objects.

    Attributes:
        functions: Registered functions, keyed by ``(domain, name)``.
    """

    def __init__(self):
        self.functions = {}

    def new_function(self, name: str, domain: str = "", register: bool = False) -> IRFunction:
        """Create an :class:`IRFunction`, optionally recording it in ``functions``.

        Raises:
            RuntimeError: if ``register`` is set and a function with the same
                domain and name has already been registered.
        """
        if not register:
            return IRFunction(name, domain)
        key = (domain, name)
        if key in self.functions:
            raise RuntimeError(f"Function '{name}' already exists in domain '{domain}'.")
        created = IRFunction(name, domain)
        self.functions[key] = created
        return created

    def add_docstring(self, fn: IRFunction, docstring: str):
        """Append ``docstring`` to the documentation of ``fn``."""
        fn.append_docstring(docstring)

    def add_stmt(
        self,
        fn: IRFunction,
        results: Sequence[str],
        callee: values.Op,
        args: Sequence[Optional[str]],
        attrs: Sequence[IRAttributeValue],
        sub_functions=None,
    ) -> None:
        """Build an :class:`IRStmt` from the arguments and append it to ``fn``."""
        fn.append_stmt(IRStmt(results, callee, args, attrs, sub_functions=sub_functions))

    def add_input(
        self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo
    ) -> None:
        """Declare a new input variable on ``fn``."""
        fn.append_input(IRVar(varname, type, info))

    def add_attr_parameter(
        self,
        fn: IRFunction,
        varname: str,
        attribute_type: onnx.AttributeProto.AttributeType,
        default_value: int | float | str | None,
    ) -> None:
        """Declare a new attribute parameter on ``fn``."""
        parameter = IRAttributeParameter(varname, attribute_type, default_value)
        fn.add_attr_parameter(parameter)

    def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None:
        """Declare a new output variable on ``fn``."""
        fn.append_output(IRVar(varname, typeinfo, sourceinfo))

    def make_attr(self, attrproto: onnx.AttributeProto) -> IRAttributeValue:
        """Wrap an attribute proto into an :class:`IRAttributeValue`."""
        return IRAttributeValue(attrproto)

    def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue:
        """Create an attribute named ``attrname`` that references ``refname``.

        The attribute type is derived from ``pytype`` through
        ``ta.pytype_to_attrtype``.
        """
        ref_proto = onnx.AttributeProto()
        ref_proto.name = attrname
        ref_proto.ref_attr_name = refname
        resolved_type = ta.pytype_to_attrtype(pytype)
        assert resolved_type is not None
        ref_proto.type = resolved_type
        return IRAttributeValue(ref_proto)