xiaoanyu123's picture
Add files using upload-large-folder tool
6a22ec9 verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import numbers
from typing import Optional, Sequence
import numpy as np
import onnx
from onnxscript import tensor
def external_tensor(
    name: str,
    data_type: int,
    dims: Sequence[int],
    location: str,
    offset: Optional[int] = None,
    length: Optional[int] = None,
    checksum: Optional[str] = None,
    basepath: Optional[str] = None,
) -> onnx.TensorProto:
    """Build a TensorProto whose data lives in an external file.

    Args:
        name: name of the tensor
        data_type: ONNX data type of the tensor elements
        dims: shape of the tensor
        location: location of the external file (relative path)
        offset: byte offset in the file where the tensor-data starts
        length: number of bytes containing the data
        checksum: SHA1 digest of the file
        basepath: basepath combined with location to form the full path

    Returns:
        A TensorProto with ``data_location`` set to EXTERNAL and one
        ``external_data`` key/value entry per supplied field. ``location`` is
        always written; the optional fields are written only when not None.

    See https://github.com/onnx/onnx/blob/main/docs/ExternalData.md for more details.
    """
    proto = onnx.TensorProto()
    proto.name = name
    proto.data_type = data_type
    proto.dims.extend(dims)
    proto.data_location = onnx.TensorProto.EXTERNAL

    # (key, value, converter) for every optional entry, in the order they are
    # emitted. Numeric fields are normalized through int() before being
    # serialized, so e.g. offset=3.0 is stored as "3".
    optional_fields = (
        ("offset", offset, int),
        ("length", length, int),
        ("checksum", checksum, None),
        ("basepath", basepath, None),
    )
    pairs = [("location", location)]
    for key, raw, convert in optional_fields:
        if raw is None:
            continue
        pairs.append((key, raw if convert is None else convert(raw)))

    # ONNX stores every external_data value as a string.
    for key, value in pairs:
        item = proto.external_data.add()
        item.key = key
        item.value = str(value)
    return proto
def value_to_type_proto(val):
    """Return the ONNX type of a python-value.

    Args:
        val: a numpy array or onnxscript ``tensor.Tensor`` (typed from its
            ``dtype`` and ``shape``), a Python/numpy scalar (typed as a
            rank-0 tensor), or a list of such values (typed as a sequence).

    Returns:
        An ``onnx.TypeProto`` describing ``val``.

    Raises:
        ValueError: if ``val`` has no ONNX type mapping.
    """
    if isinstance(val, (np.ndarray, tensor.Tensor)):
        elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype)  # noqa: TID251
        shape = val.shape
        return onnx.helper.make_tensor_type_proto(elem_type, shape)  # noqa: TID251
    # bool is a subclass of int, so it must be checked before the int branch;
    # otherwise True/False would be typed as INT32 scalars. np.bool_ is not a
    # numbers.Number, so without this branch it would raise ValueError below.
    if isinstance(val, (bool, np.bool_)):
        return onnx.helper.make_tensor_type_proto(onnx.TensorProto.BOOL, [])  # noqa: TID251
    if isinstance(val, int):
        return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, [])  # noqa: TID251
    if isinstance(val, (float, np.float32)):
        return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [])  # noqa: TID251
    if isinstance(val, list):
        if len(val) > 0:
            # The element type is taken from the first element only; the list
            # is assumed to be homogeneous.
            return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0]))  # noqa: TID251
        # Edge-case. Cannot determine a suitable ONNX type for an empty list.
        # Should be using a typed-value instead.
        # Treated as a sequence of float tensors of unknown rank (shape=None).
        return onnx.helper.make_sequence_type_proto(  # noqa: TID251
            onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None)  # noqa: TID251
        )
    if isinstance(val, numbers.Number):
        # Remaining numeric scalars (e.g. numpy integer/float scalars) are
        # typed via the dtype numpy infers for them.
        nparray = np.array(val)
        elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype)  # noqa: TID251
        return onnx.helper.make_tensor_type_proto(elem_type, [])  # noqa: TID251
    raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")
def values_to_value_infos(name_values):
    """Convert (name, value) pairs into ONNX ValueInfoProto objects.

    Each pair's type is derived with ``value_to_type_proto``. Pairs whose
    value is None are dropped, and the relative order of the rest is kept.
    """
    infos = []
    for name, val in name_values:
        if val is None:
            continue
        type_proto = value_to_type_proto(val)
        infos.append(onnx.helper.make_value_info(name, type_proto))  # noqa: TID251
    return infos