|
|
|
|
|
|
|
|
"""Functions for manipulating the input parameters of an Op or an OnnxFunction."""
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import collections |
|
|
from typing import Any, OrderedDict, Sequence |
|
|
|
|
|
from onnxscript import values |
|
|
|
|
|
|
|
|
def separate_input_attributes_from_arguments(
    param_schemas: Sequence[values.ParamSchema],
    args,
    kwargs,
    fill_defaults: bool = True,
    allow_extra_kwargs: bool = False,
) -> tuple[list[Any], OrderedDict[str, Any]]:
    """Split Python args and kwargs into ONNX inputs and ONNX attributes.

    Each schema in ``param_schemas`` is bound, in order, to the positional
    argument at the same index or, failing that, to the keyword argument with
    the same name. The bound value goes to the inputs list when the schema
    describes an input, and is stored under the schema's name otherwise.

    Args:
        param_schemas: The parameter schemas of an Op or an OnnxFunction, in
            signature order.
        args: The Python positional arguments supplied by the caller.
        kwargs: The Python keyword arguments supplied by the caller.
        fill_defaults: Whether attributes that were not supplied but declare a
            default value should receive that default.
        allow_extra_kwargs: Whether keyword arguments matching no schema are
            tolerated. When True they are silently ignored.

    Returns:
        A tuple of two elements:
            - A list of ONNX inputs.
            - An ordered dictionary of ONNX attribute names and values.

    Raises:
        TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
        TypeError: When a required parameter is supplied neither positionally
            nor by keyword and has no attribute default to fall back on.
    """
    extra_kwargs = set(kwargs) - {schema.name for schema in param_schemas}
    if extra_kwargs and not allow_extra_kwargs:
        raise TypeError(f"Unexpected keyword arguments '{extra_kwargs}'")

    inputs: list[Any] = []
    attributes: OrderedDict[str, Any] = collections.OrderedDict()
    # Rebound to an empty tuple once a variadic input has swallowed the rest of
    # the positional arguments, so no later schema can bind positionally.
    positional = args

    for index, schema in enumerate(param_schemas):
        if schema.is_variadic_input:
            inputs.extend(positional[index:])
            positional = ()
            continue

        if index < len(positional):
            value = positional[index]
        elif schema.name in kwargs:
            value = kwargs[schema.name]
        else:
            # Nothing was supplied for this parameter: fall back to the
            # attribute default, or complain if the parameter is required.
            if schema.is_attribute and schema.default is not values._EmptyDefault:  # pylint: disable=protected-access
                if fill_defaults:
                    attributes[schema.name] = schema.default
            elif schema.required:
                raise TypeError(f"Required input '{schema}' was not provided")
            continue

        if schema.is_input:
            inputs.append(value)
        else:
            attributes[schema.name] = value

    return inputs, attributes
|
|
|
|
|
|
|
|
def tag_arguments_with_param_schemas(
    param_schemas: Sequence[values.ParamSchema],
    args,
    kwargs,
    fill_defaults: bool = True,
    allow_extra_kwargs: bool = False,
) -> tuple[list[tuple[Any, values.ParamSchema]], dict[str, tuple[Any, values.ParamSchema]]]:
    """Pair each Python argument with the ParamSchema it binds to.

    Schemas are walked in signature order. A schema binds to the positional
    argument at the same index if there is one, otherwise to the keyword
    argument of the same name, otherwise (optionally) to its declared default.
    A variadic input schema is paired with every remaining positional argument.

    Args:
        param_schemas: The parameter schemas of an Op or an OnnxFunction, in
            signature order.
        args: The Python positional arguments supplied by the caller.
        kwargs: The Python keyword arguments supplied by the caller.
        fill_defaults: Whether parameters that were not supplied but declare a
            default value should be tagged with that default.
        allow_extra_kwargs: Whether keyword arguments matching no schema are
            tolerated. When True they are silently ignored.

    Returns:
        A tuple of two elements:
            - A list of ``(positional argument, ParamSchema)`` pairs.
            - An insertion-ordered dict mapping keyword names to
              ``(argument value, ParamSchema)`` pairs.

    Raises:
        TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
        TypeError: When a required input/attribute is not provided and has no
            default.
    """
    extra_kwargs = set(kwargs) - {schema.name for schema in param_schemas}
    if extra_kwargs and not allow_extra_kwargs:
        raise TypeError(f"Unexpected keyword arguments '{extra_kwargs}'")

    positional_pairs: list[tuple[Any, values.ParamSchema]] = []
    keyword_pairs: dict[str, tuple[Any, values.ParamSchema]] = {}
    # Emptied once a variadic input has consumed the remaining positionals.
    positional = args

    for index, schema in enumerate(param_schemas):
        if schema.is_variadic_input:
            positional_pairs.extend((value, schema) for value in positional[index:])
            positional = ()
            continue
        if index < len(positional):
            positional_pairs.append((positional[index], schema))
            continue
        if schema.name in kwargs:
            keyword_pairs[schema.name] = (kwargs[schema.name], schema)
            continue
        if schema.default is not values._EmptyDefault:  # pylint: disable=protected-access
            # Unlike the separate_* helper, defaults apply to inputs too.
            if fill_defaults:
                keyword_pairs[schema.name] = (schema.default, schema)
            continue
        if schema.required:
            raise TypeError(f"Required input/attribute '{schema}' was not provided")

    return positional_pairs, keyword_pairs
|
|
|
|
|
|
|
|
def turn_to_kwargs_to_avoid_ordering(
    param_schemas: Sequence[values.ParamSchema],
    inputs: list[Any],
    attributes: dict[str, Any],
) -> dict[str, Any]:
    """Bind positional ONNX inputs to parameter names so a call can use only kwargs.

    Walks ``param_schemas`` in signature order. Every input parameter whose
    name is not already a key of ``attributes`` is assigned the next
    unconsumed value from ``inputs``. A variadic input parameter receives all
    remaining unconsumed inputs as a list. Passing the result as ``**kwargs``
    makes the call independent of argument ordering.

    ``inputs`` is matched positionally against the input parameters that are
    not already in ``attributes``. If an optional input in the middle was
    omitted from ``inputs``, later inputs shift onto earlier parameter names.

    Args:
        param_schemas: The parameter schemas of an Op or an OnnxFunction, in
            signature order.
        inputs: Positional ONNX input values, in input-parameter order. This
            list is not modified.
        attributes: Mapping of attribute names to values. It is updated in
            place with the bound inputs.

    Returns:
        ``attributes`` (the same object), now also holding the inputs keyed by
        parameter name.
    """
    # Index of the first input not yet bound to a parameter. Slicing from this
    # cursor (rather than from the parameter's position in the signature) keeps
    # variadic binding correct after earlier inputs have been consumed.
    next_input = 0
    for param in param_schemas:
        if param.name in attributes:
            continue
        if param.is_variadic_input:
            attributes[param.name] = inputs[next_input:]
            # A variadic input consumes everything that is left.
            next_input = len(inputs)
        elif param.is_input and next_input < len(inputs):
            # Attribute parameters are never fed from ``inputs``; doing so
            # would shift every later input onto the wrong name.
            attributes[param.name] = inputs[next_input]
            next_input += 1
    return attributes
|
|
|