# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
import numbers
from typing import Optional, Sequence
import numpy as np
import onnx
from onnxscript import tensor
def external_tensor(
    name: str,
    data_type: int,
    dims: Sequence[int],
    location: str,
    offset: Optional[int] = None,
    length: Optional[int] = None,
    checksum: Optional[str] = None,
    basepath: Optional[str] = None,
) -> onnx.TensorProto:
    """Build a TensorProto whose data is stored outside the model, in an external file.

    Args:
        name: name of the tensor
        data_type: data type of tensor element
        dims: shape of the tensor
        location: location of the external file (relative path)
        offset: offset in the file where the tensor-data starts
        length: number of bytes containing the data
        checksum: SHA1 digest of the file
        basepath: basepath combined with location to form the full path

    Returns:
        TensorProto with data_location set to EXTERNAL and one external_data
        entry for "location" plus one for each optional argument that is not None.

    See https://github.com/onnx/onnx/blob/main/docs/ExternalData.md for more details.
    """
    proto = onnx.TensorProto()
    proto.name = name
    proto.data_type = data_type
    proto.dims.extend(dims)
    proto.data_location = onnx.TensorProto.EXTERNAL

    def _record(key, value):
        # Every external_data entry stores its value as a string.
        item = proto.external_data.add()
        item.key = key
        item.value = str(value)

    # "location" is always written, even if the caller passed None.
    _record("location", location)

    # Optional entries are appended in this fixed order, and only when given.
    # Entries flagged as numeric are normalised through int() before stringifying.
    for key, value, numeric in (
        ("offset", offset, True),
        ("length", length, True),
        ("checksum", checksum, False),
        ("basepath", basepath, False),
    ):
        if value is None:
            continue
        _record(key, int(value) if numeric else value)
    return proto
def value_to_type_proto(val):
    """Return the ONNX type of a python-value.

    Mapping, checked in this order:
      * ``np.ndarray`` / ``tensor.Tensor``: tensor type with the value's dtype and shape.
      * ``bool`` / ``np.bool_``: rank-0 BOOL tensor.
      * ``int``: rank-0 INT32 tensor.
      * ``float`` / ``np.float32``: rank-0 FLOAT tensor.
      * ``list``: sequence type. The element type comes from the first element
        only. An empty list maps to a sequence of FLOAT tensors of unknown shape.
      * any other ``numbers.Number``: rank-0 tensor of the dtype numpy infers for it.

    Args:
        val: the Python value to describe.

    Returns:
        An ONNX TypeProto describing ``val``.

    Raises:
        ValueError: if ``val`` matches none of the types above.
    """
    if isinstance(val, (np.ndarray, tensor.Tensor)):
        elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype)  # noqa: TID251
        shape = val.shape
        return onnx.helper.make_tensor_type_proto(elem_type, shape)  # noqa: TID251
    # bool is a subclass of int, so it must be checked before the int branch.
    # Otherwise True/False would be typed as INT32 instead of BOOL.
    # np.bool_ is not a numbers.Number, so without this branch it would be rejected.
    if isinstance(val, (bool, np.bool_)):
        return onnx.helper.make_tensor_type_proto(onnx.TensorProto.BOOL, [])  # noqa: TID251
    if isinstance(val, int):
        return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, [])  # noqa: TID251
    # np.float64 subclasses float, so it also lands in this branch.
    if isinstance(val, (float, np.float32)):
        return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [])  # noqa: TID251
    if isinstance(val, list):
        if len(val) > 0:
            # Only the first element is inspected; the list is assumed homogeneous.
            return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0]))  # noqa: TID251
        # Edge-case. Cannot determine a suitable ONNX type for an empty list.
        # Should be using a typed-value instead.
        # Treated as a sequence of tensors of float-type.
        return onnx.helper.make_sequence_type_proto(  # noqa: TID251
            onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None)  # noqa: TID251
        )
    if isinstance(val, numbers.Number):
        # Covers the remaining numeric scalars, e.g. other numpy scalar types or complex.
        nparray = np.array(val)
        elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype)  # noqa: TID251
        return onnx.helper.make_tensor_type_proto(elem_type, [])  # noqa: TID251
    raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")
def values_to_value_infos(name_values):
    """Build a ValueInfoProto for each (name, value) pair, dropping pairs whose value is None.

    Args:
        name_values: iterable of (name, value) pairs. Each non-None value is
            typed with value_to_type_proto.

    Returns:
        list of ValueInfoProto, in the same order as the input pairs.
    """
    infos = []
    for name, val in name_values:
        if val is None:
            continue
        type_proto = value_to_type_proto(val)
        infos.append(onnx.helper.make_value_info(name, type_proto))  # noqa: TID251
    return infos
|