xiaoanyu123's picture
Add files using upload-large-folder tool
6a22ec9 verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Function for manipulating input parameters of an Op or a OnnxFunction."""
from __future__ import annotations
import collections
from typing import Any, OrderedDict, Sequence
from onnxscript import values
def separate_input_attributes_from_arguments(
    param_schemas: Sequence[values.ParamSchema],
    args,
    kwargs,
    fill_defaults: bool = True,
    allow_extra_kwargs: bool = False,
) -> tuple[list[Any], OrderedDict[str, Any]]:
    """Split caller-supplied args/kwargs into ONNX inputs and ONNX attributes.

    Schemas are walked in order. A schema is bound to the positional argument
    at its index if there is one, otherwise to the keyword argument of the
    same name. A variadic input schema absorbs every positional argument from
    its index onward; after that, no positional argument binds to any later
    schema, so later schemas can only be filled by keyword.

    Args:
        param_schemas: The parameter schemas of an Op or a OnnxFunction.
        args: The Python positional arguments supplied by the caller.
        kwargs: The Python keyword arguments supplied by the caller.
        fill_defaults: Whether to fill omitted attributes with their schema
            default (only attributes whose default is not ``_EmptyDefault``).
        allow_extra_kwargs: Whether to allow extra keyword arguments.
            When set to True, extra/unknown arguments will be ignored.

    Returns:
        A tuple of two elements:
        - A list of ONNX inputs, in schema order of the bound values.
        - An ordered dictionary of ONNX attribute names and values.

    Raises:
        TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
        TypeError: When a required parameter is not provided.
    """
    known_names = {schema.name for schema in param_schemas}
    unknown_names = set(kwargs).difference(known_names)
    if unknown_names and not allow_extra_kwargs:
        raise TypeError(f"Unexpected keyword arguments '{unknown_names}'")

    inputs: list[Any] = []
    attributes: OrderedDict[str, Any] = collections.OrderedDict()
    # Positional values still available for index binding; emptied once a
    # variadic input has taken the positional tail.
    positional = args
    for position, schema in enumerate(param_schemas):
        if schema.is_variadic_input:
            inputs.extend(positional[position:])
            positional = ()
            continue

        if position < len(positional):
            supplied = positional[position]
        elif schema.name in kwargs:
            supplied = kwargs[schema.name]
        else:
            # The caller gave no value for this schema.
            has_default = (
                schema.is_attribute
                and schema.default is not values._EmptyDefault  # pylint: disable=protected-access
            )
            if has_default:
                if fill_defaults:
                    attributes[schema.name] = schema.default
            elif schema.required:
                raise TypeError(f"Required input '{schema}' was not provided")
            continue

        # Route the bound value by the schema's role.
        if schema.is_input:
            inputs.append(supplied)
        else:
            attributes[schema.name] = supplied

    return inputs, attributes
def tag_arguments_with_param_schemas(
    param_schemas: Sequence[values.ParamSchema],
    args,
    kwargs,
    fill_defaults: bool = True,
    allow_extra_kwargs: bool = False,
) -> tuple[list[tuple[Any, values.ParamSchema]], dict[str, tuple[Any, values.ParamSchema]]]:
    """Pair each Python argument with the ONNX ParamSchema it binds to.

    Schemas are walked in order. Positional arguments bind by index and stay
    positional in the result. Keyword arguments bind by name and stay keyed in
    the result. A variadic input schema absorbs every positional argument from
    its index onward. Unlike ``separate_input_attributes_from_arguments``,
    inputs and attributes are not separated, and schema defaults apply to any
    parameter kind.

    Args:
        param_schemas: The parameter schemas of an Op or a OnnxFunction.
        args: The Python positional arguments supplied by the caller.
        kwargs: The Python keyword arguments supplied by the caller.
        fill_defaults: Whether to add omitted parameters that have a schema
            default (one that is not ``_EmptyDefault``) to the keyword result.
        allow_extra_kwargs: Whether to allow extra keyword arguments.
            When set to True, extra/unknown arguments will be ignored.

    Returns:
        A tuple of two elements:
        - A list of ``(value, schema)`` pairs for the positional arguments.
        - A dict (insertion-ordered) mapping keyword names to
          ``(value, schema)`` pairs.

    Raises:
        TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
        TypeError: When a required input/attribute is not provided.
    """
    known_names = {schema.name for schema in param_schemas}
    unknown_names = set(kwargs).difference(known_names)
    if unknown_names and not allow_extra_kwargs:
        raise TypeError(f"Unexpected keyword arguments '{unknown_names}'")

    tagged_args: list[tuple[Any, values.ParamSchema]] = []
    tagged_kwargs: dict[str, tuple[Any, values.ParamSchema]] = {}
    # Positional values still available for index binding; emptied once a
    # variadic input has taken the positional tail.
    positional = args
    for position, schema in enumerate(param_schemas):
        if schema.is_variadic_input:
            tagged_args += [(value, schema) for value in positional[position:]]
            positional = ()
            continue
        if position < len(positional):
            tagged_args.append((positional[position], schema))
            continue
        if schema.name in kwargs:
            tagged_kwargs[schema.name] = (kwargs[schema.name], schema)
            continue
        if schema.default is not values._EmptyDefault:  # pylint: disable=protected-access
            # Omitted by the caller, but the schema supplies a default.
            if fill_defaults:
                tagged_kwargs[schema.name] = (schema.default, schema)
            continue
        if schema.required:
            raise TypeError(f"Required input/attribute '{schema}' was not provided")
    return tagged_args, tagged_kwargs
def turn_to_kwargs_to_avoid_ordering(
    param_schemas: Sequence[values.ParamSchema],
    inputs: list[Any],
    attributes: dict[str, Any],
) -> dict[str, Any]:
    """Merge positional ONNX inputs into ``attributes`` keyed by parameter name.

    The returned mapping can be passed as ``**kwargs`` to the function whose
    signature ``param_schemas`` describes, so argument order no longer matters.

    ``inputs`` is read front to back and only input schemas take values from it.
    Attribute schemas never consume an input. A variadic input schema receives
    every input that has not been bound yet. A schema whose name is already a
    key of ``attributes`` is left as-is and does not consume an input. If
    ``inputs`` runs out, the remaining non-variadic input schemas are left
    unset.

    Args:
        param_schemas: The parameter schemas of an Op or a OnnxFunction, in
            signature order.
        inputs: Positional ONNX input values in signature order. Not modified.
        attributes: Mapping of already-bound names to values. Updated in place.

    Returns:
        ``attributes`` (the same dict object) with the input values added.
    """
    next_input = 0  # index of the first input not yet bound to a name
    for param in param_schemas:
        if param.name in attributes:
            continue
        if param.is_variadic_input:
            # Slice from the read cursor rather than the schema index. Earlier
            # inputs are already bound, and attribute schemas occupy indices
            # that have no counterpart in ``inputs``.
            attributes[param.name] = inputs[next_input:]
            next_input = len(inputs)
        elif param.is_input and next_input < len(inputs):
            attributes[param.name] = inputs[next_input]
            next_input += 1
    return attributes