File size: 3,490 Bytes
6a22ec9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import numbers
from typing import Optional, Sequence

import numpy as np
import onnx

from onnxscript import tensor


def external_tensor(
    name: str,
    data_type: int,
    dims: Sequence[int],
    location: str,
    offset: Optional[int] = None,
    length: Optional[int] = None,
    checksum: Optional[str] = None,
    basepath: Optional[str] = None,
) -> onnx.TensorProto:
    """Build a TensorProto whose payload lives in an external file.

    The returned proto carries only metadata (name, element type, shape) and
    an ``external_data`` key/value list pointing at the file holding the raw
    bytes; ``data_location`` is set to ``EXTERNAL``.

    Args:
        name: name of the tensor
        data_type: data type of tensor element
        dims: shape of the tensor
        location: location of the external file (relative path); always
            recorded
        offset: offset in the file where the tensor-data starts (coerced
            with ``int`` before being stored)
        length: number of bytes containing the data (coerced with ``int``
            before being stored)
        checksum: SHA1 digest of the file
        basepath: basepath combined with location to form the full path

    Returns:
        TensorProto

    Optional arguments left as ``None`` are not written. Every stored value
    is converted to ``str``, because ``external_data`` entries are string
    key/value pairs. Entries are written in the order location, offset,
    length, checksum, basepath.

    See https://github.com/onnx/onnx/blob/main/docs/ExternalData.md for more details.
    """
    proto = onnx.TensorProto()
    proto.name = name
    proto.data_type = data_type
    proto.dims.extend(dims)
    proto.data_location = onnx.TensorProto.EXTERNAL

    # Optional fields, in the order they must be emitted (dicts keep
    # insertion order). Numeric fields are normalized through int() first.
    optional_fields = {
        "offset": None if offset is None else int(offset),
        "length": None if length is None else int(length),
        "checksum": checksum,
        "basepath": basepath,
    }

    # "location" is mandatory, so it is always the first entry.
    pairs = [("location", location)]
    pairs.extend(
        (key, value) for key, value in optional_fields.items() if value is not None
    )

    for key, value in pairs:
        item = proto.external_data.add()
        item.key = key
        item.value = str(value)
    return proto


def value_to_type_proto(val):
    """Infer the ONNX ``TypeProto`` that describes a Python value.

    Dispatch, checked in this order:
      * ``np.ndarray`` / ``tensor.Tensor``: tensor type using the value's own
        dtype and shape.
      * ``int``: rank-0 INT32 tensor. ``bool`` is a subclass of ``int`` and
        therefore also maps here.
      * ``float`` / ``np.float32``: rank-0 FLOAT tensor.
      * ``list``: sequence type whose element type is inferred recursively
        from the first item. An empty list falls back to a sequence of FLOAT
        tensors with no shape information.
      * any other ``numbers.Number``: rank-0 tensor whose element type is the
        dtype numpy picks for ``np.array(val)``.

    Raises:
        ValueError: if ``val`` matches none of the kinds above.
    """
    # Tensor-like values already carry element type and shape.
    if isinstance(val, (np.ndarray, tensor.Tensor)):
        return onnx.helper.make_tensor_type_proto(  # noqa: TID251
            onnx.helper.np_dtype_to_tensor_dtype(val.dtype),  # noqa: TID251
            val.shape,
        )

    # Built-in Python scalars map to a fixed element type, scalar shape.
    scalar_elem_type = None
    if isinstance(val, int):
        scalar_elem_type = onnx.TensorProto.INT32
    elif isinstance(val, (float, np.float32)):
        scalar_elem_type = onnx.TensorProto.FLOAT
    if scalar_elem_type is not None:
        return onnx.helper.make_tensor_type_proto(scalar_elem_type, [])  # noqa: TID251

    if isinstance(val, list):
        if len(val) == 0:
            # No element to inspect, so the ONNX type is undeterminable.
            # Default to a sequence of float tensors; callers needing another
            # type should pass a typed value instead.
            item_type = onnx.helper.make_tensor_type_proto(  # noqa: TID251
                onnx.TensorProto.FLOAT, None
            )
        else:
            item_type = value_to_type_proto(val[0])
        return onnx.helper.make_sequence_type_proto(item_type)  # noqa: TID251

    # Remaining numeric scalars (e.g. numpy integer scalars, complex numbers):
    # let numpy choose the dtype and translate it to an ONNX element type.
    if isinstance(val, numbers.Number):
        inferred_dtype = np.array(val).dtype
        return onnx.helper.make_tensor_type_proto(  # noqa: TID251
            onnx.helper.np_dtype_to_tensor_dtype(inferred_dtype),  # noqa: TID251
            [],
        )

    raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")


def values_to_value_infos(name_values):
    """Build ``ValueInfoProto`` objects from ``(name, value)`` pairs.

    Each pair's ONNX type is inferred with ``value_to_type_proto``. Pairs whose
    value is ``None`` are skipped entirely.

    Returns:
        A list of ``ValueInfoProto`` in the same order as the input pairs.
    """
    infos = []
    for name, val in name_values:
        if val is None:
            continue
        type_proto = value_to_type_proto(val)
        infos.append(onnx.helper.make_value_info(name, type_proto))  # noqa: TID251
    return infos