csukuangfj commited on
Commit
f2fcebd
·
1 Parent(s): 17019b5

add 1.17.1

Browse files
v1.17.1/headers/cpu_provider_factory.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #include "onnxruntime_c_api.h"
5
+
6
+ #ifdef __cplusplus
7
+ extern "C" {
8
+ #endif
9
+
10
+ /**
11
+ * \param use_arena zero: false. non-zero: true.
12
+ */
13
+ ORT_EXPORT
14
+ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
15
+ ORT_ALL_ARGS_NONNULL;
16
+
17
+ #ifdef __cplusplus
18
+ }
19
+ #endif
v1.17.1/headers/nnapi_provider_factory.h ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+ #pragma once
4
+
5
+ #include "onnxruntime_c_api.h"
6
+
7
+ // NNAPIFlags are bool options we want to set for NNAPI EP
8
+ // This enum is defined as bit flags and cannot have negative values
9
+ // To generate a uint32_t nnapi_flags value for use with OrtSessionOptionsAppendExecutionProvider_Nnapi below,
10
+ // uint32_t nnapi_flags = 0;
11
+ // nnapi_flags |= NNAPI_FLAG_USE_FP16;
12
+ enum NNAPIFlags {
13
+ NNAPI_FLAG_USE_NONE = 0x000,
14
+
15
+ // Using fp16 relaxation in NNAPI EP, this may improve perf but may also reduce precision
16
+ NNAPI_FLAG_USE_FP16 = 0x001,
17
+
18
+ // Use NCHW layout in NNAPI EP, this is only available after Android API level 29
19
+ // Please note that, for now, NNAPI performs worse using NCHW compared to using NHWC
20
+ NNAPI_FLAG_USE_NCHW = 0x002,
21
+
22
+ // Prevent NNAPI from using CPU devices.
23
+ //
24
+ // NNAPI is more efficient using GPU or NPU for execution, and NNAPI might fall back to its own CPU implementation
25
+ // for operations not supported by GPU/NPU. The CPU implementation of NNAPI (which is called nnapi-reference)
26
+ // might be less efficient than the optimized versions of the operation of ORT. It might be advantageous to disable
27
+ // the NNAPI CPU fallback and handle execution using ORT kernels.
28
+ //
29
+ // For some models, if NNAPI would use CPU to execute an operation, and this flag is set, the execution of the
30
+ // model may fall back to ORT kernels.
31
+ //
32
+ // This option is only available on Android API level 29 and later, and will be ignored on Android API level 28 and earlier
33
+ //
34
+ // For NNAPI device assignments, see https://developer.android.com/ndk/guides/neuralnetworks#device-assignment
35
+ // For NNAPI CPU fallback, see https://developer.android.com/ndk/guides/neuralnetworks#cpu-fallback
36
+ //
37
+ // Please note, the NNAPI EP will return error status if both NNAPI_FLAG_CPU_DISABLED
38
+ // and NNAPI_FLAG_CPU_ONLY flags are set
39
+ NNAPI_FLAG_CPU_DISABLED = 0x004,
40
+
41
+ // Using CPU only in NNAPI EP, this may decrease the perf but will provide
42
+ // reference output value without precision loss, which is useful for validation
43
+ //
44
+ // Please note, the NNAPI EP will return error status if both NNAPI_FLAG_CPU_DISABLED
45
+ // and NNAPI_FLAG_CPU_ONLY flags are set
46
+ NNAPI_FLAG_CPU_ONLY = 0x008,
47
+
48
+ // Keep NNAPI_FLAG_LAST at the end of the enum definition
49
+ // And assign the last NNAPIFlag to it
50
+ NNAPI_FLAG_LAST = NNAPI_FLAG_CPU_ONLY,
51
+ };
52
+
53
+ #ifdef __cplusplus
54
+ extern "C" {
55
+ #endif
56
+
57
+ ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi,
58
+ _In_ OrtSessionOptions* options, uint32_t nnapi_flags);
59
+
60
+ #ifdef __cplusplus
61
+ }
62
+ #endif
v1.17.1/headers/onnxruntime_c_api.h ADDED
The diff for this file is too large to render. See raw diff
 
v1.17.1/headers/onnxruntime_cxx_api.h ADDED
The diff for this file is too large to render. See raw diff
 
v1.17.1/headers/onnxruntime_cxx_inline.h ADDED
@@ -0,0 +1,2075 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
5
+ // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
6
+ //
7
+ // These are the inline implementations of the C++ header APIs. They're in this separate file so as not to clutter
8
+ // the main C++ file with implementation details.
9
+
10
+ #include <cstring>
11
+ #include <functional>
12
+
13
+ #define RETURN_ON_API_FAIL(expression) \
14
+ { \
15
+ auto err = (expression); \
16
+ if (err) { \
17
+ return Status(err); \
18
+ } \
19
+ }
20
+
21
+ namespace Ort {
22
+
23
+ namespace detail {
24
+ inline void ThrowStatus(const Status& st) {
25
+ std::string error_message = st.GetErrorMessage();
26
+ OrtErrorCode error_code = st.GetErrorCode();
27
+ ORT_CXX_API_THROW(std::move(error_message), error_code);
28
+ }
29
+ } // namespace detail
30
+
31
+ inline void ThrowOnError(OrtStatus* ort_status) {
32
+ if (ort_status) {
33
+ Ort::Status st(ort_status);
34
+ detail::ThrowStatus(st);
35
+ }
36
+ }
37
+
38
+ inline void ThrowOnError(const Status& st) {
39
+ if (st) {
40
+ detail::ThrowStatus(st);
41
+ }
42
+ }
43
+
44
+ inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
45
+ }
46
+
47
+ inline Status::Status(const std::exception& e) noexcept {
48
+ p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
49
+ }
50
+
51
+ inline Status::Status(const Exception& e) noexcept {
52
+ p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
53
+ }
54
+
55
+ inline Status::Status(const char* message, OrtErrorCode code) noexcept {
56
+ p_ = GetApi().CreateStatus(code, message);
57
+ }
58
+
59
+ inline std::string Status::GetErrorMessage() const {
60
+ std::string message(GetApi().GetErrorMessage(p_));
61
+ return message;
62
+ }
63
+
64
+ inline OrtErrorCode Status::GetErrorCode() const {
65
+ return GetApi().GetErrorCode(p_);
66
+ }
67
+
68
+ inline bool Status::IsOK() const noexcept {
69
+ return (p_ == nullptr);
70
+ }
71
+
72
+ // This template converts a C++ type into its ONNXTensorElementDataType
73
+ template <typename T>
74
+ struct TypeToTensorType;
75
+ template <>
76
+ struct TypeToTensorType<float> {
77
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
78
+ };
79
+ template <>
80
+ struct TypeToTensorType<Float16_t> {
81
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
82
+ };
83
+ template <>
84
+ struct TypeToTensorType<BFloat16_t> {
85
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
86
+ };
87
+ template <>
88
+ struct TypeToTensorType<double> {
89
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
90
+ };
91
+ template <>
92
+ struct TypeToTensorType<int8_t> {
93
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
94
+ };
95
+ template <>
96
+ struct TypeToTensorType<int16_t> {
97
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
98
+ };
99
+ template <>
100
+ struct TypeToTensorType<int32_t> {
101
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
102
+ };
103
+ template <>
104
+ struct TypeToTensorType<int64_t> {
105
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
106
+ };
107
+ template <>
108
+ struct TypeToTensorType<uint8_t> {
109
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
110
+ };
111
+ template <>
112
+ struct TypeToTensorType<uint16_t> {
113
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
114
+ };
115
+ template <>
116
+ struct TypeToTensorType<uint32_t> {
117
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
118
+ };
119
+ template <>
120
+ struct TypeToTensorType<uint64_t> {
121
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
122
+ };
123
+ template <>
124
+ struct TypeToTensorType<bool> {
125
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
126
+ };
127
+
128
+ template <>
129
+ struct TypeToTensorType<Float8E4M3FN_t> {
130
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN;
131
+ };
132
+ template <>
133
+ struct TypeToTensorType<Float8E4M3FNUZ_t> {
134
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ;
135
+ };
136
+ template <>
137
+ struct TypeToTensorType<Float8E5M2_t> {
138
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2;
139
+ };
140
+ template <>
141
+ struct TypeToTensorType<Float8E5M2FNUZ_t> {
142
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
143
+ };
144
+
145
+ inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept {
146
+ if (IsNaN() || rhs.IsNaN()) {
147
+ // IEEE defines that NaN is not equal to anything, including itself.
148
+ return false;
149
+ }
150
+ return val == rhs.val;
151
+ }
152
+
153
+ inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept {
154
+ if (IsNaN() || rhs.IsNaN()) {
155
+ // IEEE defines that NaN is unordered with respect to everything, including itself.
156
+ return false;
157
+ }
158
+
159
+ const bool left_is_negative = IsNegative();
160
+ if (left_is_negative != rhs.IsNegative()) {
161
+ // When the signs of left and right differ, we know that left is less than right if it is
162
+ // the negative value. The exception to this is if both values are zero, in which case IEEE
163
+ // says they should be equal, even if the signs differ.
164
+ return left_is_negative && !AreZero(*this, rhs);
165
+ }
166
+ return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
167
+ }
168
+
169
+ inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
170
+ : allocator_(allocator), p_(p), size_(size) {
171
+ }
172
+
173
+ inline MemoryAllocation::~MemoryAllocation() {
174
+ if (p_ != nullptr) {
175
+ // We do not throw out of destructor
176
+ auto ret = GetApi().AllocatorFree(allocator_, p_);
177
+ static_cast<void>(ret);
178
+ }
179
+ }
180
+
181
+ inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
182
+ *this = std::move(o);
183
+ }
184
+
185
+ inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
186
+ OrtAllocator* alloc = nullptr;
187
+ void* p = nullptr;
188
+ size_t sz = 0;
189
+
190
+ // Swap out this
191
+ std::swap(alloc, allocator_);
192
+ std::swap(p, p_);
193
+ std::swap(sz, size_);
194
+
195
+ // Swap with incoming
196
+ std::swap(allocator_, o.allocator_);
197
+ std::swap(p_, o.p_);
198
+ std::swap(size_, o.size_);
199
+
200
+ // Destroy this instance if needed
201
+ MemoryAllocation this_alloc(alloc, p, sz);
202
+ return *this;
203
+ }
204
+
205
+ namespace detail {
206
+
207
+ template <typename T>
208
+ inline void* AllocatorImpl<T>::Alloc(size_t size) {
209
+ void* out;
210
+ ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
211
+ return out;
212
+ }
213
+
214
+ template <typename T>
215
+ inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
216
+ void* out;
217
+ ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
218
+ MemoryAllocation result(this->p_, out, size);
219
+ return result;
220
+ }
221
+
222
+ template <typename T>
223
+ inline void AllocatorImpl<T>::Free(void* p) {
224
+ ThrowOnError(GetApi().AllocatorFree(this->p_, p));
225
+ }
226
+
227
+ template <typename T>
228
+ inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
229
+ const OrtMemoryInfo* out;
230
+ ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
231
+ return ConstMemoryInfo{out};
232
+ }
233
+
234
+ } // namespace detail
235
+
236
+ inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
237
+ ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
238
+ }
239
+
240
+ inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
241
+ ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
242
+ }
243
+
244
+ namespace detail {
245
+
246
+ template <typename T>
247
+ inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
248
+ const char* name = nullptr;
249
+ ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
250
+ return std::string(name);
251
+ }
252
+
253
+ template <typename T>
254
+ inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
255
+ OrtAllocatorType type;
256
+ ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
257
+ return type;
258
+ }
259
+
260
+ template <typename T>
261
+ inline int MemoryInfoImpl<T>::GetDeviceId() const {
262
+ int id = 0;
263
+ ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
264
+ return id;
265
+ }
266
+
267
+ template <typename T>
268
+ inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
269
+ OrtMemoryInfoDeviceType type;
270
+ GetApi().MemoryInfoGetDeviceType(this->p_, &type);
271
+ return type;
272
+ }
273
+
274
+ template <typename T>
275
+ inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
276
+ OrtMemType type;
277
+ ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
278
+ return type;
279
+ }
280
+
281
+ template <typename T>
282
+ template <typename U>
283
+ inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
284
+ int comp_result = 0;
285
+ ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
286
+ return comp_result == 0;
287
+ }
288
+
289
+ } // namespace detail
290
+
291
+ inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
292
+ OrtMemoryInfo* p;
293
+ ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
294
+ return MemoryInfo(p);
295
+ }
296
+
297
+ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
298
+ ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
299
+ }
300
+
301
+ namespace detail {
302
+ template <typename T>
303
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
304
+ AllocatorWithDefaultOptions allocator;
305
+ return binding_utils::GetOutputNamesHelper(this->p_, allocator);
306
+ }
307
+
308
+ template <typename T>
309
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
310
+ return binding_utils::GetOutputNamesHelper(this->p_, allocator);
311
+ }
312
+
313
+ template <typename T>
314
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
315
+ AllocatorWithDefaultOptions allocator;
316
+ return binding_utils::GetOutputValuesHelper(this->p_, allocator);
317
+ }
318
+
319
+ template <typename T>
320
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
321
+ return binding_utils::GetOutputValuesHelper(this->p_, allocator);
322
+ }
323
+
324
+ template <typename T>
325
+ inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
326
+ ThrowOnError(GetApi().BindInput(this->p_, name, value));
327
+ }
328
+
329
+ template <typename T>
330
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
331
+ ThrowOnError(GetApi().BindOutput(this->p_, name, value));
332
+ }
333
+
334
+ template <typename T>
335
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
336
+ ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
337
+ }
338
+
339
+ template <typename T>
340
+ inline void IoBindingImpl<T>::ClearBoundInputs() {
341
+ GetApi().ClearBoundInputs(this->p_);
342
+ }
343
+
344
+ template <typename T>
345
+ inline void IoBindingImpl<T>::ClearBoundOutputs() {
346
+ GetApi().ClearBoundOutputs(this->p_);
347
+ }
348
+
349
+ template <typename T>
350
+ inline void IoBindingImpl<T>::SynchronizeInputs() {
351
+ ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
352
+ }
353
+
354
+ template <typename T>
355
+ inline void IoBindingImpl<T>::SynchronizeOutputs() {
356
+ ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
357
+ }
358
+
359
+ namespace binding_utils {
360
+ inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
361
+ std::vector<std::string> result;
362
+ auto free_fn = detail::AllocatedFree(allocator);
363
+ using Ptr = std::unique_ptr<void, decltype(free_fn)>;
364
+
365
+ char* buffer = nullptr;
366
+ size_t* lengths = nullptr;
367
+ size_t count = 0;
368
+ ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
369
+
370
+ if (count == 0) {
371
+ return result;
372
+ }
373
+
374
+ Ptr buffer_g(buffer, free_fn);
375
+ Ptr lengths_g(lengths, free_fn);
376
+
377
+ result.reserve(count);
378
+ for (size_t i = 0; i < count; ++i) {
379
+ auto sz = *lengths;
380
+ result.emplace_back(buffer, sz);
381
+ buffer += sz;
382
+ ++lengths;
383
+ }
384
+ return result;
385
+ }
386
+
387
+ inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
388
+ std::vector<Value> result;
389
+ size_t owned = 0;
390
+ size_t output_count = 0;
391
+ // Lambda to release the buffer when no longer needed and
392
+ // make sure that we destroy all instances on exception
393
+ auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
394
+ if (buffer) {
395
+ while (owned < output_count) {
396
+ auto* p = buffer + owned++;
397
+ GetApi().ReleaseValue(*p);
398
+ }
399
+ allocator->Free(allocator, buffer);
400
+ }
401
+ };
402
+ using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
403
+
404
+ OrtValue** output_buffer = nullptr;
405
+ ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
406
+ if (output_count == 0) {
407
+ return result;
408
+ }
409
+
410
+ Ptr buffer_g(output_buffer, free_fn);
411
+
412
+ result.reserve(output_count);
413
+ for (size_t i = 0; i < output_count; ++i) {
414
+ result.emplace_back(output_buffer[i]);
415
+ ++owned;
416
+ }
417
+ return result;
418
+ }
419
+
420
+ } // namespace binding_utils
421
+ } // namespace detail
422
+
423
+ inline IoBinding::IoBinding(Session& session) {
424
+ ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
425
+ }
426
+
427
+ inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
428
+ ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
429
+ }
430
+
431
+ inline ThreadingOptions::ThreadingOptions() {
432
+ ThrowOnError(GetApi().CreateThreadingOptions(&p_));
433
+ }
434
+
435
+ inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
436
+ ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
437
+ return *this;
438
+ }
439
+
440
+ inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
441
+ ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
442
+ return *this;
443
+ }
444
+
445
+ inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
446
+ ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
447
+ return *this;
448
+ }
449
+
450
+ inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
451
+ ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
452
+ return *this;
453
+ }
454
+
455
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
456
+ ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
457
+ return *this;
458
+ }
459
+
460
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
461
+ ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
462
+ return *this;
463
+ }
464
+
465
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
466
+ ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
467
+ return *this;
468
+ }
469
+
470
+ inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
471
+ ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
472
+ if (strcmp(logid, "onnxruntime-node") == 0) {
473
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
474
+ } else {
475
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
476
+ }
477
+ }
478
+
479
+ inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
480
+ ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
481
+ if (strcmp(logid, "onnxruntime-node") == 0) {
482
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
483
+ } else {
484
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
485
+ }
486
+ }
487
+
488
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
489
+ ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
490
+ if (strcmp(logid, "onnxruntime-node") == 0) {
491
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
492
+ } else {
493
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
494
+ }
495
+ }
496
+
497
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
498
+ OrtLoggingLevel logging_level, _In_ const char* logid) {
499
+ ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
500
+ if (strcmp(logid, "onnxruntime-node") == 0) {
501
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
502
+ } else {
503
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
504
+ }
505
+ }
506
+
507
+ inline Env& Env::EnableTelemetryEvents() {
508
+ ThrowOnError(GetApi().EnableTelemetryEvents(p_));
509
+ return *this;
510
+ }
511
+
512
+ inline Env& Env::DisableTelemetryEvents() {
513
+ ThrowOnError(GetApi().DisableTelemetryEvents(p_));
514
+ return *this;
515
+ }
516
+
517
+ inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
518
+ ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
519
+ return *this;
520
+ }
521
+
522
+ inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
523
+ ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
524
+ return *this;
525
+ }
526
+
527
+ inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg) {
528
+ std::vector<const char*> keys, values;
529
+ auto num_entries = options.size();
530
+ if (num_entries > 0) {
531
+ keys.reserve(num_entries);
532
+ values.reserve(num_entries);
533
+ for (const auto& entry : options) {
534
+ keys.push_back(entry.first.c_str());
535
+ values.push_back(entry.second.c_str());
536
+ }
537
+ }
538
+ ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg, keys.data(), values.data(), num_entries));
539
+ return *this;
540
+ }
541
+
542
+ inline CustomOpDomain::CustomOpDomain(const char* domain) {
543
+ ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
544
+ }
545
+
546
+ inline void CustomOpDomain::Add(const OrtCustomOp* op) {
547
+ ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
548
+ }
549
+
550
+ inline RunOptions::RunOptions() {
551
+ ThrowOnError(GetApi().CreateRunOptions(&p_));
552
+ }
553
+
554
+ inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
555
+ ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
556
+ return *this;
557
+ }
558
+
559
+ inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
560
+ ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
561
+ return *this;
562
+ }
563
+
564
+ inline int RunOptions::GetRunLogVerbosityLevel() const {
565
+ int out;
566
+ ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
567
+ return out;
568
+ }
569
+
570
+ inline int RunOptions::GetRunLogSeverityLevel() const {
571
+ int out;
572
+ ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
573
+ return out;
574
+ }
575
+
576
+ inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
577
+ ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
578
+ return *this;
579
+ }
580
+
581
+ inline const char* RunOptions::GetRunTag() const {
582
+ const char* out;
583
+ ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
584
+ return out;
585
+ }
586
+
587
+ inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
588
+ ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
589
+ return *this;
590
+ }
591
+
592
+ inline RunOptions& RunOptions::SetTerminate() {
593
+ ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
594
+ return *this;
595
+ }
596
+
597
+ inline RunOptions& RunOptions::UnsetTerminate() {
598
+ ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
599
+ return *this;
600
+ }
601
+
602
+ namespace detail {
603
+
604
+ template <typename T>
605
+ inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
606
+ OrtSessionOptions* out;
607
+ ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
608
+ return SessionOptions{out};
609
+ }
610
+
611
+ template <typename T>
612
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
613
+ size_t size = 0;
614
+ // Feed nullptr for the data buffer to query the true size of the string value
615
+ Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
616
+
617
+ std::string out;
618
+ out.resize(size);
619
+ Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
620
+ out.resize(size - 1); // remove the terminating character '\0'
621
+
622
+ return out;
623
+ }
624
+
625
+ template <typename T>
626
+ inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
627
+ int out = 0;
628
+ Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
629
+ return static_cast<bool>(out);
630
+ }
631
+
632
+ template <typename T>
633
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
634
+ if (!this->HasConfigEntry(config_key)) {
635
+ return def;
636
+ }
637
+
638
+ return this->GetConfigEntry(config_key);
639
+ }
640
+
641
+ template <typename T>
642
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
643
+ ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
644
+ return *this;
645
+ }
646
+
647
+ template <typename T>
648
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
649
+ ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
650
+ return *this;
651
+ }
652
+
653
+ template <typename T>
654
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
655
+ ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
656
+ return *this;
657
+ }
658
+
659
+ template <typename T>
660
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetDeterministicCompute(bool value) {
661
+ ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value));
662
+ return *this;
663
+ }
664
+
665
+ template <typename T>
666
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
667
+ ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
668
+ return *this;
669
+ }
670
+
671
+ template <typename T>
672
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
673
+ ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
674
+ return *this;
675
+ }
676
+
677
+ template <typename T>
678
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
679
+ ThrowOnError(GetApi().DisableProfiling(this->p_));
680
+ return *this;
681
+ }
682
+
683
+ template <typename T>
684
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
685
+ ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
686
+ return *this;
687
+ }
688
+
689
+ template <typename T>
690
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
691
+ ThrowOnError(GetApi().EnableMemPattern(this->p_));
692
+ return *this;
693
+ }
694
+
695
+ template <typename T>
696
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
697
+ ThrowOnError(GetApi().DisableMemPattern(this->p_));
698
+ return *this;
699
+ }
700
+
701
+ template <typename T>
702
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
703
+ ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
704
+ return *this;
705
+ }
706
+
707
+ template <typename T>
708
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
709
+ ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
710
+ return *this;
711
+ }
712
+
713
+ template <typename T>
714
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
715
+ ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
716
+ return *this;
717
+ }
718
+
719
+ template <typename T>
720
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
721
+ ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
722
+ return *this;
723
+ }
724
+
725
+ template <typename T>
726
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
727
+ ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
728
+ return *this;
729
+ }
730
+
731
+ template <typename T>
732
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
733
+ ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
734
+ return *this;
735
+ }
736
+
737
+ template <typename T>
738
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
739
+ ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
740
+ return *this;
741
+ }
742
+
743
+ template <typename T>
744
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
745
+ ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
746
+ return *this;
747
+ }
748
+
749
+ template <typename T>
750
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
751
+ ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
752
+ return *this;
753
+ }
754
+
755
+ template <typename T>
756
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
757
+ const std::vector<Value>& ort_values) {
758
+ const size_t inputs_num = names.size();
759
+ if (inputs_num != ort_values.size()) {
760
+ ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
761
+ }
762
+ std::vector<const char*> names_ptr;
763
+ std::vector<const OrtValue*> ort_values_ptrs;
764
+ names_ptr.reserve(inputs_num);
765
+ ort_values_ptrs.reserve(inputs_num);
766
+ for (size_t i = 0; i < inputs_num; ++i) {
767
+ names_ptr.push_back(names[i].c_str());
768
+ ort_values_ptrs.push_back(ort_values[i]);
769
+ }
770
+ ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
771
+ return *this;
772
+ }
773
+
774
+ template <typename T>
775
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
776
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
777
+ return *this;
778
+ }
779
+
780
+ template <typename T>
781
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
782
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
783
+ return *this;
784
+ }
785
+
786
+ template <typename T>
787
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
788
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
789
+ return *this;
790
+ }
791
+
792
+ template <typename T>
793
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
794
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
795
+ return *this;
796
+ }
797
+
798
+ template <typename T>
799
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
800
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
801
+ return *this;
802
+ }
803
+
804
+ template <typename T>
805
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
806
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
807
+ return *this;
808
+ }
809
+
810
+ template <typename T>
811
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
812
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
813
+ return *this;
814
+ }
815
+
816
+ template <typename T>
817
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) {
818
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
819
+ return *this;
820
+ }
821
+
822
+ template <typename T>
823
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
824
+ const std::string& provider_name,
825
+ const std::unordered_map<std::string, std::string>& provider_options) {
826
+ auto num_entries = provider_options.size();
827
+ std::vector<const char*> keys, values;
828
+ if (num_entries > 0) {
829
+ keys.reserve(num_entries);
830
+ values.reserve(num_entries);
831
+
832
+ for (const auto& entry : provider_options) {
833
+ keys.push_back(entry.first.c_str());
834
+ values.push_back(entry.second.c_str());
835
+ }
836
+ }
837
+
838
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
839
+ keys.data(), values.data(), num_entries));
840
+
841
+ return *this;
842
+ }
843
+
844
+ template <typename T>
845
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
846
+ ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
847
+ return *this;
848
+ }
849
+
850
+ template <typename T>
851
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
852
+ ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
853
+ return *this;
854
+ }
855
+
856
+ template <typename T>
857
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
858
+ ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
859
+ return *this;
860
+ }
861
+
862
+ template <typename T>
863
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
864
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
865
+ return *this;
866
+ }
867
+
868
+ template <typename T>
869
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options) {
870
+ auto num_entries = provider_options.size();
871
+ std::vector<const char*> keys, values;
872
+ if (num_entries > 0) {
873
+ keys.reserve(num_entries);
874
+ values.reserve(num_entries);
875
+
876
+ for (const auto& entry : provider_options) {
877
+ keys.push_back(entry.first.c_str());
878
+ values.push_back(entry.second.c_str());
879
+ }
880
+ }
881
+
882
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO_V2(this->p_,
883
+ keys.data(), values.data(), num_entries));
884
+
885
+ return *this;
886
+ }
887
+
888
+ template <typename T>
889
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
890
+ const CustomOpConfigs& custom_op_configs) {
891
+ // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
892
+ // the custom op library.
893
+ for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
894
+ AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
895
+ }
896
+
897
+ ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
898
+ return *this;
899
+ }
900
+
901
+ template <typename T>
902
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
903
+ ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
904
+ return *this;
905
+ }
906
+
907
+ /// Session
908
+ template <typename T>
909
+ inline size_t ConstSessionImpl<T>::GetInputCount() const {
910
+ size_t out;
911
+ ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
912
+ return out;
913
+ }
914
+
915
+ template <typename T>
916
+ inline size_t ConstSessionImpl<T>::GetOutputCount() const {
917
+ size_t out;
918
+ ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
919
+ return out;
920
+ }
921
+
922
+ template <typename T>
923
+ inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
924
+ size_t out;
925
+ ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
926
+ return out;
927
+ }
928
+
929
+ template <typename T>
930
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
931
+ char* out;
932
+ ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
933
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
934
+ }
935
+
936
+ template <typename T>
937
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
938
+ char* out;
939
+ ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
940
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
941
+ }
942
+
943
+ template <typename T>
944
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
945
+ char* out;
946
+ ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
947
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
948
+ }
949
+
950
+ template <typename T>
951
+ inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
952
+ uint64_t out;
953
+ ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
954
+ return out;
955
+ }
956
+
957
+ template <typename T>
958
+ inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
959
+ OrtModelMetadata* out;
960
+ ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
961
+ return ModelMetadata{out};
962
+ }
963
+
964
+ template <typename T>
965
+ inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
966
+ OrtTypeInfo* out;
967
+ ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
968
+ return TypeInfo{out};
969
+ }
970
+
971
+ template <typename T>
972
+ inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
973
+ OrtTypeInfo* out;
974
+ ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
975
+ return TypeInfo{out};
976
+ }
977
+
978
+ template <typename T>
979
+ inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
980
+ OrtTypeInfo* out;
981
+ ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
982
+ return TypeInfo{out};
983
+ }
984
+
985
+ template <typename T>
986
+ inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
987
+ const char* const* output_names, size_t output_count) {
988
+ std::vector<Value> output_values;
989
+ output_values.reserve(output_count);
990
+ for (size_t i = 0; i < output_count; i++)
991
+ output_values.emplace_back(nullptr);
992
+ Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
993
+ return output_values;
994
+ }
995
+
996
+ template <typename T>
997
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
998
+ const char* const* output_names, Value* output_values, size_t output_count) {
999
+ static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1000
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1001
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1002
+ ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
1003
+ }
1004
+
1005
+ template <typename T>
1006
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
1007
+ ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
1008
+ }
1009
+
1010
+ template <typename T>
1011
+ inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1012
+ const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
1013
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1014
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1015
+ ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
1016
+ ort_input_values, input_count, output_names, output_count,
1017
+ ort_output_values, callback, user_data));
1018
+ }
1019
+
1020
+ template <typename T>
1021
+ inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
1022
+ char* out = nullptr;
1023
+ ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
1024
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1025
+ }
1026
+
1027
+ } // namespace detail
1028
+
1029
+ inline SessionOptions::SessionOptions() {
1030
+ ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
1031
+ }
1032
+
1033
+ /// CustomOpConfigs
1034
+ inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
1035
+ std::string config_key = "custom_op.";
1036
+
1037
+ config_key += custom_op_name;
1038
+ config_key += ".";
1039
+ config_key += config;
1040
+
1041
+ return config_key;
1042
+ }
1043
+
1044
+ inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
1045
+ const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
1046
+ flat_configs_[full_flat_key] = config_value;
1047
+ return *this;
1048
+ }
1049
+
1050
+ inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
1051
+ return flat_configs_;
1052
+ }
1053
+
1054
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
1055
+ ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
1056
+ }
1057
+
1058
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1059
+ OrtPrepackedWeightsContainer* prepacked_weights_container) {
1060
+ ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
1061
+ }
1062
+
1063
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
1064
+ ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
1065
+ }
1066
+
1067
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
1068
+ const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
1069
+ ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
1070
+ prepacked_weights_container, &this->p_));
1071
+ }
1072
+
1073
+ inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
1074
+ char* out;
1075
+ ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
1076
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1077
+ }
1078
+
1079
+ inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
1080
+ char* out;
1081
+ ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
1082
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1083
+ }
1084
+
1085
+ inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
1086
+ char* out;
1087
+ ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
1088
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1089
+ }
1090
+
1091
+ inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
1092
+ char* out;
1093
+ ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
1094
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1095
+ }
1096
+
1097
+ inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
1098
+ char* out;
1099
+ ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
1100
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1101
+ }
1102
+
1103
+ inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
1104
+ char* out;
1105
+ ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
1106
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1107
+ }
1108
+
1109
+ inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
1110
+ auto deletor = detail::AllocatedFree(allocator);
1111
+ std::vector<AllocatedStringPtr> result;
1112
+
1113
+ char** out = nullptr;
1114
+ int64_t num_keys = 0;
1115
+ ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
1116
+ if (num_keys <= 0) {
1117
+ return result;
1118
+ }
1119
+
1120
+ // array of pointers will be freed
1121
+ std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
1122
+ // reserve may throw
1123
+ auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
1124
+ std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
1125
+ result.reserve(static_cast<size_t>(num_keys));
1126
+ strings_guard.release();
1127
+ for (int64_t i = 0; i < num_keys; ++i) {
1128
+ result.push_back(AllocatedStringPtr(out[i], deletor));
1129
+ }
1130
+
1131
+ return result;
1132
+ }
1133
+
1134
+ inline int64_t ModelMetadata::GetVersion() const {
1135
+ int64_t out;
1136
+ ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
1137
+ return out;
1138
+ }
1139
+
1140
+ namespace detail {
1141
+
1142
+ template <typename T>
1143
+ inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
1144
+ ONNXTensorElementDataType out;
1145
+ ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
1146
+ return out;
1147
+ }
1148
+
1149
+ template <typename T>
1150
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
1151
+ size_t out;
1152
+ ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
1153
+ return static_cast<size_t>(out);
1154
+ }
1155
+
1156
+ template <typename T>
1157
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
1158
+ size_t out;
1159
+ ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
1160
+ return out;
1161
+ }
1162
+
1163
+ template <typename T>
1164
+ inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
1165
+ ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
1166
+ }
1167
+
1168
+ template <typename T>
1169
+ inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
1170
+ ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
1171
+ }
1172
+
1173
+ template <typename T>
1174
+ inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
1175
+ std::vector<int64_t> out(GetDimensionsCount(), 0);
1176
+ ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
1177
+ return out;
1178
+ }
1179
+
1180
+ template <typename T>
1181
+ inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
1182
+ const OrtTensorTypeAndShapeInfo* out;
1183
+ ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
1184
+ return ConstTensorTypeAndShapeInfo{out};
1185
+ }
1186
+
1187
+ template <typename T>
1188
+ inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
1189
+ const OrtSequenceTypeInfo* out;
1190
+ ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
1191
+ return ConstSequenceTypeInfo{out};
1192
+ }
1193
+
1194
+ template <typename T>
1195
+ inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
1196
+ const OrtMapTypeInfo* out;
1197
+ ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
1198
+ return ConstMapTypeInfo{out};
1199
+ }
1200
+
1201
+ template <typename T>
1202
+ inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
1203
+ ONNXType out;
1204
+ ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
1205
+ return out;
1206
+ }
1207
+
1208
+ template <typename T>
1209
+ inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
1210
+ OrtTypeInfo* output;
1211
+ ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
1212
+ return TypeInfo{output};
1213
+ }
1214
+
1215
+ template <typename T>
1216
+ inline TypeInfo OptionalTypeInfoImpl<T>::GetOptionalElementType() const {
1217
+ OrtTypeInfo* info;
1218
+ ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info));
1219
+ return TypeInfo{info};
1220
+ }
1221
+
1222
+ template <typename T>
1223
+ inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
1224
+ ONNXTensorElementDataType out;
1225
+ ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
1226
+ return out;
1227
+ }
1228
+
1229
+ template <typename T>
1230
+ inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
1231
+ OrtTypeInfo* output;
1232
+ ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
1233
+ return TypeInfo{output};
1234
+ }
1235
+
1236
+ template <typename T>
1237
+ inline ConstOptionalTypeInfo TypeInfoImpl<T>::GetOptionalTypeInfo() const {
1238
+ const OrtOptionalTypeInfo* info;
1239
+ ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info));
1240
+ return ConstOptionalTypeInfo{info};
1241
+ }
1242
+
1243
+ } // namespace detail
1244
+
1245
+ namespace detail {
1246
+
1247
+ template <typename T>
1248
+ template <typename R>
1249
+ inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
1250
+ ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
1251
+ }
1252
+
1253
+ template <typename T>
1254
+ inline bool ConstValueImpl<T>::IsTensor() const {
1255
+ int out;
1256
+ ThrowOnError(GetApi().IsTensor(this->p_, &out));
1257
+ return out != 0;
1258
+ }
1259
+
1260
+ template <typename T>
1261
+ inline bool ConstValueImpl<T>::HasValue() const {
1262
+ int out;
1263
+ ThrowOnError(GetApi().HasValue(this->p_, &out));
1264
+ return out != 0;
1265
+ }
1266
+
1267
+ template <typename T>
1268
+ inline size_t ConstValueImpl<T>::GetCount() const {
1269
+ size_t out;
1270
+ ThrowOnError(GetApi().GetValueCount(this->p_, &out));
1271
+ return out;
1272
+ }
1273
+
1274
+ template <typename T>
1275
+ inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
1276
+ OrtValue* out;
1277
+ ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
1278
+ return Value{out};
1279
+ }
1280
+
1281
+ template <typename T>
1282
+ inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
1283
+ size_t out;
1284
+ ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
1285
+ return out;
1286
+ }
1287
+
1288
+ template <typename T>
1289
+ inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
1290
+ size_t out;
1291
+ ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
1292
+ return out;
1293
+ }
1294
+
1295
+ template <typename T>
1296
+ template <typename R>
1297
+ inline const R* ConstValueImpl<T>::GetTensorData() const {
1298
+ R* out;
1299
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
1300
+ return out;
1301
+ }
1302
+
1303
+ template <typename T>
1304
+ inline const void* ConstValueImpl<T>::GetTensorRawData() const {
1305
+ void* out;
1306
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
1307
+ return out;
1308
+ }
1309
+
1310
+ template <typename T>
1311
+ inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
1312
+ OrtTypeInfo* output;
1313
+ ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
1314
+ return TypeInfo{output};
1315
+ }
1316
+
1317
+ template <typename T>
1318
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
1319
+ OrtTensorTypeAndShapeInfo* output;
1320
+ ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
1321
+ return TensorTypeAndShapeInfo{output};
1322
+ }
1323
+
1324
+ template <typename T>
1325
+ inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
1326
+ const OrtMemoryInfo* mem_info;
1327
+ ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
1328
+ return ConstMemoryInfo(mem_info);
1329
+ }
1330
+
1331
+ template <typename T>
1332
+ inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
1333
+ ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
1334
+ }
1335
+
1336
+ template <typename T>
1337
+ inline std::string ConstValueImpl<T>::GetStringTensorElement(size_t element_index) const {
1338
+ size_t buffer_length;
1339
+ ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length));
1340
+
1341
+ std::string s;
1342
+ s.resize(buffer_length);
1343
+ ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0]));
1344
+ return s;
1345
+ }
1346
+
1347
+ template <typename T>
1348
+ inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
1349
+ ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
1350
+ }
1351
+
1352
+ #if !defined(DISABLE_SPARSE_TENSORS)
1353
+ template <typename T>
1354
+ inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
1355
+ OrtSparseFormat format;
1356
+ ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
1357
+ return format;
1358
+ }
1359
+
1360
+ template <typename T>
1361
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
1362
+ OrtTensorTypeAndShapeInfo* output;
1363
+ ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
1364
+ return TensorTypeAndShapeInfo{output};
1365
+ }
1366
+
1367
+ template <typename T>
1368
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
1369
+ OrtTensorTypeAndShapeInfo* output;
1370
+ ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
1371
+ return TensorTypeAndShapeInfo{output};
1372
+ }
1373
+
1374
+ template <typename T>
1375
+ template <typename R>
1376
+ inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
1377
+ const void* out;
1378
+ ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
1379
+ return reinterpret_cast<const R*>(out);
1380
+ }
1381
+
1382
+ template <typename T>
1383
+ inline bool ConstValueImpl<T>::IsSparseTensor() const {
1384
+ int out;
1385
+ ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
1386
+ return out != 0;
1387
+ }
1388
+
1389
+ template <typename T>
1390
+ template <typename R>
1391
+ inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
1392
+ const void* out;
1393
+ ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
1394
+ return reinterpret_cast<const R*>(out);
1395
+ }
1396
+
1397
+ #endif
1398
+
1399
+ template <typename T>
1400
+ void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
1401
+ ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
1402
+ }
1403
+
1404
+ template <typename T>
1405
+ void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
1406
+ ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
1407
+ }
1408
+
1409
+ template <typename T>
1410
+ inline char* ValueImpl<T>::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) {
1411
+ char* result;
1412
+ ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result));
1413
+ return result;
1414
+ }
1415
+
1416
+ template <typename T>
1417
+ void* ValueImpl<T>::GetTensorMutableRawData() {
1418
+ void* out;
1419
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
1420
+ return out;
1421
+ }
1422
+
1423
+ template <typename T>
1424
+ template <typename R>
1425
+ R* ValueImpl<T>::GetTensorMutableData() {
1426
+ R* out;
1427
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
1428
+ return out;
1429
+ }
1430
+
1431
+ template <typename T>
1432
+ template <typename R>
1433
+ R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
1434
+ static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
1435
+ R* out;
1436
+ ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
1437
+ return *out;
1438
+ }
1439
+
1440
+ #if !defined(DISABLE_SPARSE_TENSORS)
1441
+ template <typename T>
1442
+ void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
1443
+ ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
1444
+ }
1445
+
1446
+ template <typename T>
1447
+ void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
1448
+ ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
1449
+ }
1450
+
1451
+ template <typename T>
1452
+ void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
1453
+ ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
1454
+ }
1455
+
1456
+ template <typename T>
1457
+ void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
1458
+ const int64_t* indices_data, size_t indices_num) {
1459
+ ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
1460
+ values_param.values_shape_len, values_param.data.p_data,
1461
+ indices_data, indices_num));
1462
+ }
1463
+
1464
+ template <typename T>
1465
+ void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1466
+ const OrtSparseValuesParam& values,
1467
+ const int64_t* inner_indices_data, size_t inner_indices_num,
1468
+ const int64_t* outer_indices_data, size_t outer_indices_num) {
1469
+ ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1470
+ inner_indices_data, inner_indices_num,
1471
+ outer_indices_data, outer_indices_num));
1472
+ }
1473
+
1474
+ template <typename T>
1475
+ void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1476
+ const OrtSparseValuesParam& values,
1477
+ const Shape& indices_shape,
1478
+ const int32_t* indices_data) {
1479
+ ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1480
+ indices_shape.shape, indices_shape.shape_len,
1481
+ indices_data));
1482
+ }
1483
+
1484
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1485
+
1486
+ } // namespace detail
1487
+
1488
+ template <typename T>
1489
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
1490
+ return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
1491
+ }
1492
+
1493
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1494
+ ONNXTensorElementDataType type) {
1495
+ OrtValue* out;
1496
+ ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1497
+ return Value{out};
1498
+ }
1499
+
1500
+ template <typename T>
1501
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
1502
+ return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
1503
+ }
1504
+
1505
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
1506
+ OrtValue* out;
1507
+ ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1508
+ return Value{out};
1509
+ }
1510
+
1511
+ #if !defined(DISABLE_SPARSE_TENSORS)
1512
+
1513
+ template <typename T>
1514
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1515
+ const Shape& values_shape) {
1516
+ return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
1517
+ }
1518
+
1519
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1520
+ const Shape& values_shape, ONNXTensorElementDataType type) {
1521
+ OrtValue* out;
1522
+ ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1523
+ values_shape.shape, values_shape.shape_len, type, &out));
1524
+ return Value{out};
1525
+ }
1526
+
1527
+ template <typename T>
1528
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
1529
+ return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
1530
+ }
1531
+
1532
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
1533
+ ONNXTensorElementDataType type) {
1534
+ OrtValue* out;
1535
+ ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
1536
+ return Value{out};
1537
+ }
1538
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1539
+
1540
+ inline Value Value::CreateMap(const Value& keys, const Value& values) {
1541
+ OrtValue* out;
1542
+ const OrtValue* inputs[2] = {keys, values};
1543
+ ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
1544
+ return Value{out};
1545
+ }
1546
+
1547
+ inline Value Value::CreateSequence(const std::vector<Value>& values) {
1548
+ OrtValue* out;
1549
+ std::vector<const OrtValue*> values_ort{values.data(), values.data() + values.size()};
1550
+ ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1551
+ return Value{out};
1552
+ }
1553
+
1554
+ template <typename T>
1555
+ inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
1556
+ OrtValue* out;
1557
+ ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
1558
+ return Value{out};
1559
+ }
1560
+
1561
+ //
1562
+ // Custom OP Inlines
1563
+ //
1564
+ inline Logger::Logger(const OrtLogger* logger) : logger_(logger) {
1565
+ Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_));
1566
+ }
1567
+
1568
+ inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept {
1569
+ return cached_severity_level_;
1570
+ }
1571
+
1572
+ inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
1573
+ const char* func_name, const char* message) const noexcept {
1574
+ OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number,
1575
+ func_name);
1576
+ return Status{status};
1577
+ }
1578
+
1579
+ // Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security)
1580
+ // for gcc and clang. The alternative is to use actual C-style variadic parameters and apply
1581
+ // __attribute__(format(printf...)), which does not work with variadic templates.
1582
+ #if defined(__GNUC__)
1583
+ #pragma GCC diagnostic push
1584
+ #pragma GCC diagnostic ignored "-Wformat-nonliteral"
1585
+ #pragma GCC diagnostic ignored "-Wformat-security"
1586
+ #elif defined(__clang__)
1587
+ #pragma clang diagnostic push
1588
+ #pragma clang diagnostic ignored "-Wformat-nonliteral"
1589
+ #pragma clang diagnostic ignored "-Wformat-security"
1590
+ #endif
1591
+ template <typename... Args>
1592
+ inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
1593
+ int line_number, const char* func_name, const char* format,
1594
+ Args&&... args) const noexcept {
1595
+ int msg_len = std::snprintf(nullptr, 0U, format, std::forward<Args>(args)...);
1596
+
1597
+ if (msg_len < 0) { // Formatting error
1598
+ return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL);
1599
+ }
1600
+
1601
+ OrtStatus* status = nullptr;
1602
+ const size_t buffer_size = static_cast<size_t>(msg_len) + 1U;
1603
+
1604
+ constexpr size_t kStackBufferSize = 1024;
1605
+
1606
+ if (buffer_size < kStackBufferSize) {
1607
+ char buffer[kStackBufferSize];
1608
+ snprintf(buffer, kStackBufferSize, format, std::forward<Args>(args)...);
1609
+ status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name);
1610
+ } else {
1611
+ // std::make_unique is only supported starting at C++14.
1612
+ #if (__cplusplus >= 201402L) || (_MSC_VER >= 1900)
1613
+ auto buffer = std::make_unique<char[]>(buffer_size);
1614
+ #else
1615
+ std::unique_ptr<char[]> buffer(new char[buffer_size]);
1616
+ #endif
1617
+ std::snprintf(buffer.get(), buffer_size, format, std::forward<Args>(args)...);
1618
+ status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name);
1619
+ }
1620
+
1621
+ return Status{status};
1622
+ }
1623
+ // Re-enable -Wformat-nonliteral and -Wformat-security
1624
+ #if defined(__GNUC__)
1625
+ #pragma GCC diagnostic pop
1626
+ #elif defined(__clang__)
1627
+ #pragma clang diagnostic pop
1628
+ #endif
1629
+
1630
+ inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
1631
+ }
1632
+
1633
+ inline size_t KernelContext::GetInputCount() const {
1634
+ size_t out = 0;
1635
+ Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
1636
+ return out;
1637
+ }
1638
+
1639
+ inline size_t KernelContext::GetOutputCount() const {
1640
+ size_t out = 0;
1641
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
1642
+ return out;
1643
+ }
1644
+
1645
+ inline ConstValue KernelContext::GetInput(size_t index) const {
1646
+ const OrtValue* out = nullptr;
1647
+ Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
1648
+ return ConstValue{out};
1649
+ }
1650
+
1651
+ inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
1652
+ OrtValue* out = nullptr;
1653
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
1654
+ return UnownedValue(out);
1655
+ }
1656
+
1657
+ inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
1658
+ OrtValue* out = nullptr;
1659
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
1660
+ return UnownedValue(out);
1661
+ }
1662
+
1663
+ inline void* KernelContext::GetGPUComputeStream() const {
1664
+ void* out = nullptr;
1665
+ Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
1666
+ return out;
1667
+ }
1668
+
1669
+ inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const {
1670
+ OrtAllocator* out = nullptr;
1671
+ Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out));
1672
+ return out;
1673
+ }
1674
+
1675
+ inline Logger KernelContext::GetLogger() const {
1676
+ const OrtLogger* out = nullptr;
1677
+ ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out));
1678
+ return Logger{out};
1679
+ }
1680
+
1681
+ inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const {
1682
+ ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data));
1683
+ }
1684
+
1685
+ inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
1686
+ Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
1687
+ }
1688
+
1689
+ namespace detail {
1690
+ template <typename T>
1691
+ inline KernelInfo KernelInfoImpl<T>::Copy() const {
1692
+ OrtKernelInfo* info_copy = nullptr;
1693
+ Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
1694
+ return KernelInfo{info_copy};
1695
+ }
1696
+
1697
+ template <typename T>
1698
+ inline size_t KernelInfoImpl<T>::GetInputCount() const {
1699
+ size_t out = 0;
1700
+ ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
1701
+ return out;
1702
+ }
1703
+
1704
+ template <typename T>
1705
+ inline size_t KernelInfoImpl<T>::GetOutputCount() const {
1706
+ size_t out = 0;
1707
+ ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
1708
+ return out;
1709
+ }
1710
+
1711
+ template <typename T>
1712
+ inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
1713
+ size_t size = 0;
1714
+
1715
+ // Feed nullptr for the data buffer to query the true size of the string value
1716
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
1717
+
1718
+ std::string out;
1719
+ out.resize(size);
1720
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
1721
+ out.resize(size - 1); // remove the terminating character '\0'
1722
+
1723
+ return out;
1724
+ }
1725
+
1726
+ template <typename T>
1727
+ inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
1728
+ size_t size = 0;
1729
+
1730
+ // Feed nullptr for the data buffer to query the true size of the string value
1731
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
1732
+
1733
+ std::string out;
1734
+ out.resize(size);
1735
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
1736
+ out.resize(size - 1); // remove the terminating character '\0'
1737
+
1738
+ return out;
1739
+ }
1740
+
1741
+ template <typename T>
1742
+ inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
1743
+ OrtTypeInfo* out = nullptr;
1744
+ ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
1745
+ return TypeInfo{out};
1746
+ }
1747
+
1748
+ template <typename T>
1749
+ inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
1750
+ OrtTypeInfo* out = nullptr;
1751
+ ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
1752
+ return TypeInfo{out};
1753
+ }
1754
+
1755
+ template <typename T>
1756
+ inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
1757
+ OrtValue* out = nullptr;
1758
+ ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
1759
+ return Value{out};
1760
+ }
1761
+
1762
+ template <typename T>
1763
+ inline ConstValue KernelInfoImpl<T>::GetTensorConstantInput(size_t index, int* is_constant) const {
1764
+ const OrtValue* out = nullptr;
1765
+ ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out));
1766
+ return ConstValue{out};
1767
+ }
1768
+
1769
+ template <typename T>
1770
+ inline std::string KernelInfoImpl<T>::GetNodeName() const {
1771
+ size_t size = 0;
1772
+
1773
+ // Feed nullptr for the data buffer to query the true size of the string value
1774
+ Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size));
1775
+
1776
+ std::string out;
1777
+ out.resize(size);
1778
+ Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size));
1779
+ out.resize(size - 1); // remove the terminating character '\0'
1780
+
1781
+ return out;
1782
+ }
1783
+
1784
+ template <typename T>
1785
+ inline Logger KernelInfoImpl<T>::GetLogger() const {
1786
+ const OrtLogger* out = nullptr;
1787
+ ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out));
1788
+ return Logger{out};
1789
+ }
1790
+
1791
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
1792
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
1793
+ }
1794
+
1795
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
1796
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
1797
+ }
1798
+
1799
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
1800
+ size_t size = 0;
1801
+ // Feed nullptr for the data buffer to query the true size of the string attribute
1802
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
1803
+
1804
+ std::string out;
1805
+ out.resize(size);
1806
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
1807
+ out.resize(size - 1); // remove the terminating character '\0'
1808
+ out.swap(result);
1809
+ }
1810
+
1811
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
1812
+ size_t size = 0;
1813
+ // Feed nullptr for the data buffer to query the true size of the attribute
1814
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
1815
+
1816
+ std::vector<float> out;
1817
+ out.resize(size);
1818
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
1819
+ out.swap(result);
1820
+ }
1821
+
1822
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
1823
+ size_t size = 0;
1824
+
1825
+ // Feed nullptr for the data buffer to query the true size of the attribute
1826
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
1827
+
1828
+ std::vector<int64_t> out;
1829
+ out.resize(size);
1830
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
1831
+ out.swap(result);
1832
+ }
1833
+ } // namespace detail
1834
+
1835
+ inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
1836
+
1837
+ inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
1838
+
1839
+ inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
1840
+ const char** type_constraint_names,
1841
+ const ONNXTensorElementDataType* type_constraint_values,
1842
+ size_t type_constraint_count,
1843
+ const OpAttr* attr_values, size_t attr_count,
1844
+ size_t input_count, size_t output_count) {
1845
+ static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
1846
+ "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
1847
+ auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
1848
+ OrtOp* op;
1849
+ Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1850
+ static_cast<int>(type_constraint_count),
1851
+ attr_input_values,
1852
+ static_cast<int>(attr_count),
1853
+ static_cast<int>(input_count),
1854
+ static_cast<int>(output_count), &op));
1855
+ return Op{op};
1856
+ }
1857
+
1858
+ inline void Op::Invoke(const OrtKernelContext* context,
1859
+ const Value* input_values,
1860
+ size_t input_count,
1861
+ Value* output_values,
1862
+ size_t output_count) {
1863
+ static_assert(sizeof(Value) == sizeof(OrtValue*),
1864
+ "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1865
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1866
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1867
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
1868
+ ort_output_values, static_cast<int>(output_count)));
1869
+ }
1870
+
1871
+ inline void Op::Invoke(const OrtKernelContext* context,
1872
+ const OrtValue* const* input_values,
1873
+ size_t input_count,
1874
+ OrtValue* const* output_values,
1875
+ size_t output_count) {
1876
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
1877
+ output_values, static_cast<int>(output_count)));
1878
+ }
1879
+
1880
+ inline std::string GetVersionString() {
1881
+ return OrtGetApiBase()->GetVersionString();
1882
+ }
1883
+
1884
+ inline std::string GetBuildInfoString() {
1885
+ return GetApi().GetBuildInfoString();
1886
+ }
1887
+
1888
+ inline std::vector<std::string> GetAvailableProviders() {
1889
+ char** providers;
1890
+ int len;
1891
+
1892
+ auto release_fn = [&len](char** providers) {
1893
+ // This should always return nullptr.
1894
+ ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
1895
+ };
1896
+
1897
+ ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
1898
+ std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
1899
+ std::vector<std::string> available_providers;
1900
+ available_providers.reserve(static_cast<size_t>(len));
1901
+ for (int i = 0; i < len; ++i) {
1902
+ available_providers.emplace_back(providers[i]);
1903
+ }
1904
+ return available_providers;
1905
+ }
1906
+
1907
+ template <typename TOp, typename TKernel, bool WithStatus>
1908
+ void CustomOpBase<TOp, TKernel, WithStatus>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
1909
+ ConstSessionOptions options) const {
1910
+ const TOp* derived = static_cast<const TOp*>(this);
1911
+ std::vector<std::string> keys = derived->GetSessionConfigKeys();
1912
+
1913
+ out.reserve(keys.size());
1914
+
1915
+ std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
1916
+ const size_t prefix_size = config_entry_key.length();
1917
+
1918
+ for (const auto& key : keys) {
1919
+ config_entry_key.resize(prefix_size);
1920
+ config_entry_key.append(key);
1921
+ out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
1922
+ }
1923
+ }
1924
+
1925
+ inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api,
1926
+ OrtShapeInferContext* ctx) : ort_api_(ort_api), ctx_(ctx) {
1927
+ size_t input_count = 0;
1928
+ Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputCount(ctx_, &input_count));
1929
+ for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
1930
+ OrtTensorTypeAndShapeInfo* info{};
1931
+ Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info));
1932
+ TensorTypeAndShapeInfo type_shape_info(info);
1933
+ auto integer_shape = type_shape_info.GetShape();
1934
+ std::vector<const char*> symbolic_shape(integer_shape.size(), {});
1935
+ type_shape_info.GetSymbolicDimensions(&symbolic_shape[0], integer_shape.size());
1936
+ Shape shape;
1937
+ for (size_t ith = 0; ith < integer_shape.size(); ++ith) {
1938
+ if (symbolic_shape[ith] && std::string{symbolic_shape[ith]}.size() > 0) {
1939
+ shape.emplace_back(symbolic_shape[ith]);
1940
+ } else {
1941
+ shape.emplace_back(integer_shape[ith]);
1942
+ }
1943
+ }
1944
+ input_shapes_.push_back(std::move(shape));
1945
+ type_shape_info.release();
1946
+ }
1947
+ }
1948
+
1949
+ inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) {
1950
+ OrtTensorTypeAndShapeInfo* info = {};
1951
+ RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info));
1952
+
1953
+ using InfoPtr = std::unique_ptr<OrtTensorTypeAndShapeInfo, std::function<void(OrtTensorTypeAndShapeInfo*)>>;
1954
+
1955
+ InfoPtr info_ptr(info, [this](OrtTensorTypeAndShapeInfo* obj) {
1956
+ ort_api_->ReleaseTensorTypeAndShapeInfo(obj);
1957
+ });
1958
+
1959
+ std::vector<int64_t> integer_dims;
1960
+ std::vector<const char*> symbolic_dims;
1961
+
1962
+ for (const auto dim : shape) {
1963
+ if (dim.IsInt()) {
1964
+ integer_dims.push_back(dim.IsInt());
1965
+ symbolic_dims.push_back("");
1966
+ } else {
1967
+ if (!dim.AsSym() || std::string{dim.AsSym()}.empty()) {
1968
+ ORT_CXX_API_THROW("Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT);
1969
+ }
1970
+ integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM);
1971
+ symbolic_dims.push_back(dim.AsSym());
1972
+ }
1973
+ }
1974
+
1975
+ RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size()));
1976
+ RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size()));
1977
+ RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info));
1978
+ return Status{nullptr};
1979
+ }
1980
+
1981
+ inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) {
1982
+ const auto* attr = GetAttrHdl(attr_name);
1983
+ int64_t i = {};
1984
+ size_t out = {};
1985
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out));
1986
+ return i;
1987
+ }
1988
+
1989
+ inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) {
1990
+ const auto* attr = GetAttrHdl(attr_name);
1991
+ int64_t i = {};
1992
+ size_t out = {};
1993
+ // first call to get the bytes needed
1994
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out);
1995
+ if (status) {
1996
+ size_t num_i = out / sizeof(int64_t);
1997
+ ShapeInferContext::Ints ints(num_i, 0);
1998
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out));
1999
+ return ints;
2000
+ } else {
2001
+ return {i};
2002
+ }
2003
+ }
2004
+
2005
+ inline float ShapeInferContext::GetAttrFloat(const char* attr_name) {
2006
+ const auto* attr = GetAttrHdl(attr_name);
2007
+ float f = {};
2008
+ size_t out = {};
2009
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out));
2010
+ return f;
2011
+ }
2012
+
2013
+ inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) {
2014
+ const auto* attr = GetAttrHdl(attr_name);
2015
+ float f = {};
2016
+ size_t out = {};
2017
+ // first call to get the bytes needed
2018
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out);
2019
+ if (status) {
2020
+ size_t num_f = out / sizeof(float);
2021
+ ShapeInferContext::Floats floats(num_f, 0);
2022
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out));
2023
+ return floats;
2024
+ } else {
2025
+ return {f};
2026
+ }
2027
+ }
2028
+
2029
+ inline std::string ShapeInferContext::GetAttrString(const char* attr_name) {
2030
+ const auto* attr = GetAttrHdl(attr_name);
2031
+ char c = {};
2032
+ size_t out = {};
2033
+ // first call to get the bytes needed
2034
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out);
2035
+ if (status) {
2036
+ std::vector<char> chars(out, '\0');
2037
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out));
2038
+ return {chars.data()};
2039
+ } else {
2040
+ return {c};
2041
+ }
2042
+ }
2043
+
2044
+ inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) {
2045
+ const auto* attr = GetAttrHdl(attr_name);
2046
+ char c = {};
2047
+ size_t out = {};
2048
+ // first call to get the bytes needed
2049
+ auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out);
2050
+ if (status) {
2051
+ std::vector<char> chars(out, '\0');
2052
+ Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out));
2053
+ ShapeInferContext::Strings strings;
2054
+ char* char_st = chars.data();
2055
+ char* char_ed = char_st + out;
2056
+ while (char_st < char_ed) {
2057
+ strings.emplace_back(char_st);
2058
+ while (*char_st != '\0') {
2059
+ char_st++;
2060
+ }
2061
+ char_st++;
2062
+ }
2063
+ return strings;
2064
+ } else {
2065
+ return {std::string{c}};
2066
+ }
2067
+ }
2068
+
2069
+ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const {
2070
+ const OrtOpAttr* attr_hdl = {};
2071
+ Ort::ThrowOnError(ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &attr_hdl));
2072
+ return attr_hdl;
2073
+ }
2074
+
2075
+ } // namespace Ort
v1.17.1/headers/onnxruntime_float16.h ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ #include <stdint.h>
7
+ #include <cmath>
8
+ #include <cstring>
9
+ #include <limits>
10
+
11
+ namespace onnxruntime_float16 {
12
+
13
namespace detail {

/// <summary>
/// Byte-order tags plus `native`, the byte order of the build target.
/// On _WIN32 builds `native` is fixed to `little`; on GCC/Clang builds all three values come
/// from the compiler's predefined byte-order macros. Any other environment stops the build.
/// </summary>
enum class endian {
#if defined(_WIN32)
  little = 0,
  big = 1,
  native = little,
#elif defined(__clang__) || defined(__GNUC__)
  little = __ORDER_LITTLE_ENDIAN__,
  big = __ORDER_BIG_ENDIAN__,
  native = __BYTE_ORDER__,
#else
#error onnxruntime_float16::detail::endian is not implemented in this environment.
#endif
};

// Reject any native byte order that is neither of the two supported ones.
static_assert(
    endian::little == endian::native || endian::big == endian::native,
    "Only little-endian or big-endian native byte orders are supported.");

}  // namespace detail
34
+
35
/// <summary>
/// Shared implementation between public and internal classes. CRTP pattern.
/// Derived is expected to provide a static FromBits(uint16_t) factory; Abs() and Negate()
/// use it to build their results.
/// </summary>
template <class Derived>
struct Float16Impl {
 protected:
  /// <summary>
  /// Converts from float to uint16_t float16 representation
  /// </summary>
  /// <param name="v">float value to convert</param>
  /// <returns>float16 bit pattern</returns>
  constexpr static uint16_t ToUint16Impl(float v) noexcept;

  /// <summary>
  /// Converts float16 to float
  /// </summary>
  /// <returns>float representation of float16 value</returns>
  float ToFloatImpl() const noexcept;

  /// <summary>
  /// Computes the bit pattern of the absolute value (sign bit cleared).
  /// </summary>
  /// <returns>Bits of the absolute value</returns>
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// <summary>
  /// Computes the bit pattern with the sign flipped. NaN values are returned unchanged.
  /// </summary>
  /// <returns>Bits of the negated value</returns>
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // uint16_t special values
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
  // NOTE(review): 0x4170 decodes to 2.71875, not the binary16 machine epsilon
  // (2^-10 == 0x1400). Value kept as-is for compatibility; confirm the intent.
  static constexpr uint16_t kEpsilonBits = 0x4170U;
  static constexpr uint16_t kMinValueBits = 0xFBFFU;  // Lowest (most negative) finite value, -65504
  static constexpr uint16_t kMaxValueBits = 0x7BFFU;  // Largest finite value, 65504
  static constexpr uint16_t kOneBits = 0x3C00U;
  static constexpr uint16_t kMinusOneBits = 0xBC00U;

  // Raw float16 bit pattern.
  uint16_t val{0};

  Float16Impl() = default;

  /// <summary>
  /// Checks if the value is negative (sign bit set; this includes -0 and negative NaN patterns)
  /// </summary>
  /// <returns>true if negative</returns>
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// <summary>
  /// Tests if the value is NaN
  /// </summary>
  /// <returns>true if NaN</returns>
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is finite
  /// </summary>
  /// <returns>true if finite</returns>
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents positive infinity.
  /// </summary>
  /// <returns>true if positive infinity</returns>
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents negative infinity
  /// </summary>
  /// <returns>true if negative infinity</returns>
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// <summary>
  /// Tests if the value is either positive or negative infinity.
  /// </summary>
  /// <returns>True if absolute value is infinity</returns>
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is NaN or zero. Useful for comparisons.
  /// </summary>
  /// <returns>True if NaN or zero.</returns>
  bool IsNaNOrZero() const noexcept {
    auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// <summary>
  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsNormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
  }

  /// <summary>
  /// Tests if the value is subnormal (denormal).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsSubnormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
  }

  /// <summary>
  /// Creates an instance that represents absolute value.
  /// </summary>
  /// <returns>Absolute value</returns>
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// <summary>
  /// Creates a new instance with the sign flipped.
  /// </summary>
  /// <returns>Flipped sign instance</returns>
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// <summary>
  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
  /// and therefore equivalent, if the resulting value is still zero.
  /// </summary>
  /// <param name="lhs">first value</param>
  /// <param name="rhs">second value</param>
  /// <returns>True if both arguments represent zero</returns>
  static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept {
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }

  /// <summary>
  /// IEEE equality: NaN never compares equal (not even to itself), and +0 equals -0.
  /// </summary>
  /// <param name="rhs">value to compare with</param>
  /// <returns>True if both values are equal</returns>
  bool operator==(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE defines that NaN is not equal to anything, including itself.
      return false;
    }
    // +0 and -0 have different bit patterns but must compare equal; this keeps
    // operator== consistent with operator<, which already treats them as equal.
    return val == rhs.val || AreZero(*this, rhs);
  }

  bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }

  /// <summary>
  /// IEEE ordering: any comparison involving NaN is false, and -0 is not less than +0.
  /// </summary>
  /// <param name="rhs">value to compare with</param>
  /// <returns>True if this value is strictly less than rhs</returns>
  bool operator<(const Float16Impl& rhs) const noexcept {
    if (IsNaN() || rhs.IsNaN()) {
      // IEEE defines that NaN is unordered with respect to everything, including itself.
      return false;
    }

    const bool left_is_negative = IsNegative();
    if (left_is_negative != rhs.IsNegative()) {
      // When the signs of left and right differ, we know that left is less than right if it is
      // the negative value. The exception to this is if both values are zero, in which case IEEE
      // says they should be equal, even if the signs differ.
      return left_is_negative && !AreZero(*this, rhs);
    }
    // Same sign: for positives the raw bits order like the values; for negatives the order is reversed.
    return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
  }
};
217
+
218
+ // The following Float16_t conversions are based on the code from
219
+ // Eigen library.
220
+
221
+ // The conversion routines are Copyright (c) Fabian Giesen, 2016.
222
+ // The original license follows:
223
+ //
224
+ // Copyright (c) Fabian Giesen, 2016
225
+ // All rights reserved.
226
+ // Redistribution and use in source and binary forms, with or without
227
+ // modification, are permitted.
228
+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
229
+ // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
230
+ // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
231
+ // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
232
+ // HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
233
+ // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
234
+ // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
235
+ // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
236
+ // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
237
+ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
238
+ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
239
+
240
namespace detail {
// Overlays a float with an unsigned integer so the conversion routines can
// manipulate the float's bit pattern directly.
// `u` must stay the first member: brace-initialisation such as {255 << 23}
// initialises the first member of a union.
union float32_bits {
  unsigned u;  // raw bits
  float f;     // same storage viewed as a float
};
}  // namespace detail
246
+
247
+ template <class Derived>
248
+ inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept {
249
+ detail::float32_bits f{};
250
+ f.f = v;
251
+
252
+ constexpr detail::float32_bits f32infty = {255 << 23};
253
+ constexpr detail::float32_bits f16max = {(127 + 16) << 23};
254
+ constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
255
+ constexpr unsigned int sign_mask = 0x80000000u;
256
+ uint16_t val = static_cast<uint16_t>(0x0u);
257
+
258
+ unsigned int sign = f.u & sign_mask;
259
+ f.u ^= sign;
260
+
261
+ // NOTE all the integer compares in this function can be safely
262
+ // compiled into signed compares since all operands are below
263
+ // 0x80000000. Important if you want fast straight SSE2 code
264
+ // (since there's no unsigned PCMPGTD).
265
+
266
+ if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
267
+ val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
268
+ } else { // (De)normalized number or zero
269
+ if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
270
+ // use a magic value to align our 10 mantissa bits at the bottom of
271
+ // the float. as long as FP addition is round-to-nearest-even this
272
+ // just works.
273
+ f.f += denorm_magic.f;
274
+
275
+ // and one integer subtract of the bias later, we have our final float!
276
+ val = static_cast<uint16_t>(f.u - denorm_magic.u);
277
+ } else {
278
+ unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
279
+
280
+ // update exponent, rounding bias part 1
281
+ // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
282
+ // without arithmetic overflow.
283
+ f.u += 0xc8000fffU;
284
+ // rounding bias part 2
285
+ f.u += mant_odd;
286
+ // take the bits!
287
+ val = static_cast<uint16_t>(f.u >> 13);
288
+ }
289
+ }
290
+
291
+ val |= static_cast<uint16_t>(sign >> 16);
292
+ return val;
293
+ }
294
+
295
+ template <class Derived>
296
+ inline float Float16Impl<Derived>::ToFloatImpl() const noexcept {
297
+ constexpr detail::float32_bits magic = {113 << 23};
298
+ constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
299
+ detail::float32_bits o{};
300
+
301
+ o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
302
+ unsigned int exp = shifted_exp & o.u; // just the exponent
303
+ o.u += (127 - 15) << 23; // exponent adjust
304
+
305
+ // handle exponent special cases
306
+ if (exp == shifted_exp) { // Inf/NaN?
307
+ o.u += (128 - 16) << 23; // extra exp adjust
308
+ } else if (exp == 0) { // Zero/Denormal?
309
+ o.u += 1 << 23; // extra exp adjust
310
+ o.f -= magic.f; // re-normalize
311
+ }
312
+
313
+ // Attempt to workaround the Internal Compiler Error on ARM64
314
+ // for bitwise | operator, including std::bitset
315
+ #if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
316
+ if (IsNegative()) {
317
+ return -o.f;
318
+ }
319
+ #else
320
+ // original code:
321
+ o.u |= (val & 0x8000U) << 16U; // sign bit
322
+ #endif
323
+ return o.f;
324
+ }
325
+
326
/// <summary>
/// Shared implementation between public and internal classes. CRTP pattern.
/// Derived is expected to provide a static FromBits(uint16_t) factory; Abs() and Negate()
/// use it to build their results.
/// </summary>
template <class Derived>
struct BFloat16Impl {
 protected:
  /// <summary>
  /// Converts from float to uint16_t bfloat16 representation
  /// </summary>
  /// <param name="v">float value to convert</param>
  /// <returns>bfloat16 bit pattern</returns>
  static uint16_t ToUint16Impl(float v) noexcept;

  /// <summary>
  /// Converts bfloat16 to float
  /// </summary>
  /// <returns>float representation of bfloat16 value</returns>
  float ToFloatImpl() const noexcept;

  /// <summary>
  /// Computes the bit pattern of the absolute value (sign bit cleared).
  /// </summary>
  /// <returns>Bits of the absolute value</returns>
  uint16_t AbsImpl() const noexcept {
    return static_cast<uint16_t>(val & ~kSignMask);
  }

  /// <summary>
  /// Computes the bit pattern with the sign flipped. NaN values are returned unchanged.
  /// </summary>
  /// <returns>Bits of the negated value</returns>
  uint16_t NegateImpl() const noexcept {
    return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
  }

 public:
  // uint16_t special values
  static constexpr uint16_t kSignMask = 0x8000U;
  static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
  static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
  static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
  static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
  static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
  // All exponent bits set, quiet bit clear, non-zero mantissa. This was previously 0x7F80,
  // which is the +Infinity pattern (IsNaN() returned false for it).
  static constexpr uint16_t kSignaling_NaNBits = 0x7F81U;
  // NOTE(review): 0x0080 decodes to 2^-126 (the smallest positive normal value), not the
  // bfloat16 machine epsilon (2^-7 == 0x3C00). Value kept as-is for compatibility; confirm the intent.
  static constexpr uint16_t kEpsilonBits = 0x0080U;
  static constexpr uint16_t kMinValueBits = 0xFF7FU;  // Lowest (most negative) finite value
  static constexpr uint16_t kMaxValueBits = 0x7F7FU;  // Largest finite value
  // Rounding bias added to the float32 bits before the low half is dropped (see ToUint16Impl).
  static constexpr uint16_t kRoundToNearest = 0x7FFFU;
  static constexpr uint16_t kOneBits = 0x3F80U;
  static constexpr uint16_t kMinusOneBits = 0xBF80U;

  // Raw bfloat16 bit pattern.
  uint16_t val{0};

  BFloat16Impl() = default;

  /// <summary>
  /// Checks if the value is negative (sign bit set; this includes -0 and negative NaN patterns)
  /// </summary>
  /// <returns>true if negative</returns>
  bool IsNegative() const noexcept {
    return static_cast<int16_t>(val) < 0;
  }

  /// <summary>
  /// Tests if the value is NaN
  /// </summary>
  /// <returns>true if NaN</returns>
  bool IsNaN() const noexcept {
    return AbsImpl() > kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is finite
  /// </summary>
  /// <returns>true if finite</returns>
  bool IsFinite() const noexcept {
    return AbsImpl() < kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents positive infinity.
  /// </summary>
  /// <returns>true if positive infinity</returns>
  bool IsPositiveInfinity() const noexcept {
    return val == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value represents negative infinity
  /// </summary>
  /// <returns>true if negative infinity</returns>
  bool IsNegativeInfinity() const noexcept {
    return val == kNegativeInfinityBits;
  }

  /// <summary>
  /// Tests if the value is either positive or negative infinity.
  /// </summary>
  /// <returns>True if absolute value is infinity</returns>
  bool IsInfinity() const noexcept {
    return AbsImpl() == kPositiveInfinityBits;
  }

  /// <summary>
  /// Tests if the value is NaN or zero. Useful for comparisons.
  /// </summary>
  /// <returns>True if NaN or zero.</returns>
  bool IsNaNOrZero() const noexcept {
    auto abs = AbsImpl();
    return (abs == 0 || abs > kPositiveInfinityBits);
  }

  /// <summary>
  /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsNormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) != 0);  // is not subnormal (has a non-zero exponent)
  }

  /// <summary>
  /// Tests if the value is subnormal (denormal).
  /// </summary>
  /// <returns>True if so</returns>
  bool IsSubnormal() const noexcept {
    auto abs = AbsImpl();
    return (abs < kPositiveInfinityBits)           // is finite
           && (abs != 0)                           // is not zero
           && ((abs & kBiasedExponentMask) == 0);  // is subnormal (has a zero exponent)
  }

  /// <summary>
  /// Creates an instance that represents absolute value.
  /// </summary>
  /// <returns>Absolute value</returns>
  Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }

  /// <summary>
  /// Creates a new instance with the sign flipped.
  /// </summary>
  /// <returns>Flipped sign instance</returns>
  Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }

  /// <summary>
  /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
  /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
  /// and therefore equivalent, if the resulting value is still zero.
  /// </summary>
  /// <param name="lhs">first value</param>
  /// <param name="rhs">second value</param>
  /// <returns>True if both arguments represent zero</returns>
  static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept {
    return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
  }
};
485
+
486
+ template <class Derived>
487
+ inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept {
488
+ uint16_t result;
489
+ if (std::isnan(v)) {
490
+ result = kPositiveQNaNBits;
491
+ } else {
492
+ auto get_msb_half = [](float fl) {
493
+ uint16_t result;
494
+ #ifdef __cpp_if_constexpr
495
+ if constexpr (detail::endian::native == detail::endian::little) {
496
+ #else
497
+ if (detail::endian::native == detail::endian::little) {
498
+ #endif
499
+ std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
500
+ } else {
501
+ std::memcpy(&result, &fl, sizeof(uint16_t));
502
+ }
503
+ return result;
504
+ };
505
+
506
+ uint16_t upper_bits = get_msb_half(v);
507
+ union {
508
+ uint32_t U32;
509
+ float F32;
510
+ };
511
+ F32 = v;
512
+ U32 += (upper_bits & 1) + kRoundToNearest;
513
+ result = get_msb_half(F32);
514
+ }
515
+ return result;
516
+ }
517
+
518
+ template <class Derived>
519
+ inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept {
520
+ if (IsNaN()) {
521
+ return std::numeric_limits<float>::quiet_NaN();
522
+ }
523
+ float result;
524
+ char* const first = reinterpret_cast<char*>(&result);
525
+ char* const second = first + sizeof(uint16_t);
526
+ #ifdef __cpp_if_constexpr
527
+ if constexpr (detail::endian::native == detail::endian::little) {
528
+ #else
529
+ if (detail::endian::native == detail::endian::little) {
530
+ #endif
531
+ std::memset(first, 0, sizeof(uint16_t));
532
+ std::memcpy(second, &val, sizeof(uint16_t));
533
+ } else {
534
+ std::memcpy(first, &val, sizeof(uint16_t));
535
+ std::memset(second, 0, sizeof(uint16_t));
536
+ }
537
+ return result;
538
+ }
539
+
540
+ } // namespace onnxruntime_float16
v1.17.1/headers/onnxruntime_run_options_config_keys.h ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ /*
7
+ * This file defines RunOptions Config Keys and format of the Config Values.
8
+ *
9
+ * The Naming Convention for a RunOptions Config Key,
10
+ * "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
11
+ * Such as "ep.cuda.use_arena"
12
+ * The Config Key cannot be empty
13
+ * The maximum length of the Config Key is 128
14
+ *
15
+ * The string format of a RunOptions Config Value is defined individually for each Config.
16
+ * The maximum length of the Config Value is 1024
17
+ */
18
+
19
+ // Key for enabling shrinkages of user listed device memory arenas.
20
+ // Expects a list of semi-colon separated key value pairs separated by colon in the following format:
21
+ // "device_0:device_id_0;device_1:device_id_1"
22
+ // No white-spaces allowed in the provided list string.
23
+ // Currently, the only supported devices are : "cpu", "gpu" (case sensitive).
24
+ // If "cpu" is included in the list, DisableCpuMemArena() API must not be called (i.e.) arena for cpu should be enabled.
25
+ // Example usage: "cpu:0;gpu:0" (or) "gpu:0"
26
+ // By default, the value for this key is empty (i.e.) no memory arenas are shrunk
27
+ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memory.enable_memory_arena_shrinkage";
28
+
29
+ // Set to '1' to not synchronize execution providers with CPU at the end of session run.
30
+ // Per default it will be set to '0'
31
+ // Taking the CUDA EP as an example, it omits triggering cudaStreamSynchronize on the compute stream.
32
+ static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";
v1.17.1/headers/onnxruntime_session_options_config_keys.h ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #pragma once
5
+
6
+ /*
7
+ * This file defines SessionOptions Config Keys and format of the Config Values.
8
+ *
9
+ * The Naming Convention for a SessionOptions Config Key,
10
+ * "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
11
+ * Such as "ep.cuda.use_arena"
12
+ * The Config Key cannot be empty
13
+ * The maximum length of the Config Key is 128
14
+ *
15
+ * The string format of a SessionOptions Config Value is defined individually for each Config.
16
+ * The maximum length of the Config Value is 1024
17
+ */
18
+
19
+ // Key for disable PrePacking,
20
+ // If the config value is set to "1" then the prepacking is disabled, otherwise prepacking is enabled (default value)
21
+ static const char* const kOrtSessionOptionsConfigDisablePrepacking = "session.disable_prepacking";
22
+
23
+ // A value of "1" means allocators registered in the env will be used. "0" means the allocators created in the session
24
+ // will be used. Use this to override the usage of env allocators on a per session level.
25
+ static const char* const kOrtSessionOptionsConfigUseEnvAllocators = "session.use_env_allocators";
26
+
27
+ // Set to 'ORT' (case sensitive) to load an ORT format model.
28
+ // If unset, model type will default to ONNX unless inferred from filename ('.ort' == ORT format) or bytes to be ORT
29
+ static const char* const kOrtSessionOptionsConfigLoadModelFormat = "session.load_model_format";
30
+
31
+ // Set to 'ORT' (case sensitive) to save optimized model in ORT format when SessionOptions.optimized_model_path is set.
32
+ // If unset, format will default to ONNX unless optimized_model_filepath ends in '.ort'.
33
+ static const char* const kOrtSessionOptionsConfigSaveModelFormat = "session.save_model_format";
34
+
35
+ // If a value is "1", flush-to-zero and denormal-as-zero are applied. The default is "0".
36
+ // When multiple sessions are created, a main thread doesn't override changes from succeeding session options,
37
+ // but threads in session thread pools follow option changes.
38
+ // When ORT runs with OpenMP, the same rule is applied, i.e. the first session option to flush-to-zero and
39
+ // denormal-as-zero is only applied to global OpenMP thread pool, which doesn't support per-session thread pool.
40
+ // Note that an alternative way not using this option at runtime is to train and export a model without denormals
41
+ // and that's recommended because turning this option on may hurt model accuracy.
42
+ static const char* const kOrtSessionOptionsConfigSetDenormalAsZero = "session.set_denormal_as_zero";
43
+
44
+ // Controls whether to run a quantized model in QDQ (QuantizeLinear/DequantizeLinear) format.
45
+ // "0": enable. ORT does fusion logic for QDQ format.
46
+ // "1": disable. ORT doesn't do fusion logic for QDQ format.
47
+ // Its default value is "0" unless the DirectML execution provider is registered, in which case it defaults to "1".
48
+ static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_quant_qdq";
49
+
50
+ // It controls whether to enable Double QDQ remover and Identical Children Consolidation
51
+ // "0": not to disable. ORT does remove the middle 2 Nodes from a Q->(DQ->Q)->DQ sequence
52
+ // "1": disable. ORT doesn't remove the middle 2 Nodes from a Q->(DQ->Q)->DQ sequence
53
+ // Its default value is "0"
54
+ static const char* const kOrtSessionOptionsDisableDoubleQDQRemover = "session.disable_double_qdq_remover";
55
+
56
+ // If set to "1", enables the removal of QuantizeLinear/DequantizeLinear node pairs once all QDQ handling has been
57
+ // completed. e.g. If after all QDQ handling has completed and we have -> FloatOp -> Q -> DQ -> FloatOp -> the
58
+ // Q -> DQ could potentially be removed. This will provide a performance benefit by avoiding going from float to
59
+ // 8-bit and back to float, but could impact accuracy. The impact on accuracy will be model specific and depend on
60
+ // other factors like whether the model was created using Quantization Aware Training or Post Training Quantization.
61
+ // As such, it's best to test to determine if enabling this works well for your scenario.
62
+ // The default value is "0"
63
+ // Available since version 1.11.
64
+ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup";
65
+
66
+ // Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
67
+ // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
68
+ static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
69
+
70
+ // This setting controls whether to enable AheadOfTime function inlining.
71
+ // AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
72
+ // as possible with the help of enabled execution providers.
73
+ // This can reduce the number of function calls and improve performance because it is done before
74
+ // Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available,
75
+ // one can disable the AOT inlining, produce an optimized model and postpone AOT until run time.
76
+ // "0": enable; "1": disable.
77
+ // Its default value is "0".
78
+ static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";
79
+
80
+ #ifdef ENABLE_TRAINING
81
+ // Specifies a list of op types for memory footprint reduction.
82
+ // The value should be a ","-delimited list of pair of
83
+ // <subgraph string: optimization strategy: number of subgraph to apply>.
84
+ // For example, "Gelu+Cast+:1:0,Dropout+:1:1".
85
+ // A valid "subgraph string" should be one subgraph representation output by ORT graph transformations.
86
+ // "optimization strategy" currently has valid values: 0 - disabled, 1 - recompute.
87
+ // "number of subgraph to apply" is used to control how many subgraphs to apply optimization, to avoid "oversaving"
88
+ // the memory.
89
+ static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.memory_optimizer_config";
90
+
91
+ // Specifies the config for detecting subgraphs for memory footprint reduction.
92
+ // The value should be a string contains int separated using commas. The default value is "0:0".
93
+ static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config";
94
+ #endif
95
+
96
+ // Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
97
+ // Using device allocators means the memory allocation is made using malloc/new.
98
+ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "session.use_device_allocator_for_initializers";
99
+
100
+ // Configure whether to allow the inter_op/intra_op threads spinning a number of times before blocking
101
+ // "0": thread will block if found no job to run
102
+ // "1": default, thread will spin a number of times before blocking
103
+ static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
104
+ static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";
105
+
106
+ // Key for using model bytes directly for ORT format
107
+ // If a session is created using an input byte array containing the ORT format model data,
108
+ // By default we will copy the model bytes at the time of session creation to ensure the model bytes
109
+ // buffer is valid.
110
+ // Setting this option to "1" will disable copying the model bytes, and use the model bytes directly. The caller
111
+ // has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed.
112
+ static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly";
113
+
114
+ /// <summary>
115
+ /// Key for using the ORT format model flatbuffer bytes directly for initializers.
116
+ /// This avoids copying the bytes and reduces peak memory usage during model loading and initialization.
117
+ /// Requires `session.use_ort_model_bytes_directly` to be true.
118
+ /// If set, the flatbuffer bytes provided when creating the InferenceSession MUST remain valid for the entire
119
+ /// duration of the InferenceSession.
120
+ /// </summary>
121
+ static const char* const kOrtSessionOptionsConfigUseORTModelBytesForInitializers =
122
+ "session.use_ort_model_bytes_for_initializers";
123
+
124
+ // This should only be specified when exporting an ORT format model for use on a different platform.
125
+ // If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0"
126
+ // Available since version 1.11.
127
+ static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed";
128
+
129
+ // x64 SSE4.1/AVX2/AVX512 (with no VNNI) has an overflow problem with quantized matrix multiplication with U8S8.
130
+ // To avoid this we need to use slower U8U8 matrix multiplication instead. This option, if
131
+ // turned on, uses the slower U8U8 matrix multiplications. Only effective with AVX2 or AVX512
132
+ // platforms.
133
+ static const char* const kOrtSessionOptionsAvx2PrecisionMode = "session.x64quantprecision";
134
+
135
+ // Specifies how minimal build graph optimizations are handled in a full build.
136
+ // These optimizations are at the extended level or higher.
137
+ // Possible values and their effects are:
138
+ // "save": Save runtime optimizations when saving an ORT format model.
139
+ // "apply": Only apply optimizations available in a minimal build.
140
+ // ""/<unspecified>: Apply optimizations available in a full build.
141
+ // Available since version 1.11.
142
+ static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations =
143
+ "optimization.minimal_build_optimizations";
144
+
145
+ // Note: The options specific to an EP should be specified prior to appending that EP to the session options object in
146
+ // order for them to take effect.
147
+
148
+ // Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be
149
+ // run by the NNAPI EP.
150
+ // The value should be a ","-delimited list of op types. For example, "Add,Sub".
151
+ // If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
152
+ // exclusion, set the value to "".
153
+ static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";
154
+
155
+ // Enabling dynamic block-sizing for multithreading.
156
+ // With a positive value, thread pool will split a task of N iterations to blocks of size starting from:
157
+ // N / (num_of_threads * dynamic_block_base)
158
+ // As execution progresses, the size will decrease according to the diminishing residual of N,
159
+ // meaning the task will be distributed in smaller granularity for better parallelism.
160
+ // For some models, it helps to reduce the variance of E2E inference latency and boost performance.
161
+ // The feature will not function by default, specify any positive integer, e.g. "4", to enable it.
162
+ // Available since version 1.11.
163
+ static const char* const kOrtSessionOptionsConfigDynamicBlockBase = "session.dynamic_block_base";
164
+
165
+ // This option allows to decrease CPU usage between infrequent
166
+ // requests and forces any TP threads spinning stop immediately when the last of
167
+ // concurrent Run() call returns.
168
+ // Spinning is restarted on the next Run() call.
169
+ // Applies only to internal thread-pools
170
+ static const char* const kOrtSessionOptionsConfigForceSpinningStop = "session.force_spinning_stop";
171
+
172
+ // "1": all inconsistencies encountered during shape and type inference
173
+ // will result in failures.
174
+ // "0": in some cases warnings will be logged but processing will continue. The default.
175
+ // May be useful to expose bugs in models.
176
+ static const char* const kOrtSessionOptionsConfigStrictShapeTypeInference = "session.strict_shape_type_inference";
177
+
178
+ // "1": every model using a more recent opset than the latest released one will fail
179
+ // "0": the model may or may not work if onnxruntime cannot find an implementation, this option
180
+ // is used for development purposes.
181
+ static const char* const kOrtSessionOptionsConfigStrictAllowReleasedOpsetsOnly = "session.allow_released_opsets_only";
182
+
183
+ // The file that saves the configuration for partitioning nodes among logic streams
184
+ static const char* const kNodePartitionConfigFile = "session.node_partition_config_file";
185
+
186
+ // This Option allows setting affinities for intra op threads.
187
+ // Affinity string follows format:
188
+ // logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id
189
+ // Semicolons separate the configurations of different threads, while commas separate the processors that the i-th thread is expected to attach to.
190
+ // e.g. 1,2,3;4,5
191
+ // specifies affinities for two threads, with the 1st thread attached to the 1st, 2nd, and 3rd processors, and the 2nd thread to the 4th and 5th.
192
+ // To ease the configuration, an "interval" is also allowed:
193
+ // e.g. 1-8;9-16;17-24
194
+ // specifies that the 1st thread runs on the first eight processors, the 2nd thread runs on the next eight processors, and so forth.
195
+ // Note:
196
+ // 1. Once set, the number of thread affinities must equal intra_op_num_threads - 1, since ort does not set affinity on the main thread which
197
+ // is started and managed by the calling app;
198
+ // 2. On Windows, ort will infer the group id from a logical processor id. For example, assuming there are two groups with 64 logical processors each,
199
+ // an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group.
200
+ // Hence 64-65 is an invalid configuration, because a Windows thread cannot be attached to processors across a group boundary.
201
+ static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "session.intra_op_thread_affinities";
202
+
203
+ // This option will dump out the model to assist debugging any issues with layout transformation,
204
+ // and is primarily intended for developer usage. It is only relevant if an execution provider that requests
205
+ // NHWC layout is enabled such as NNAPI, XNNPACK or QNN.
206
+ //
207
+ // Default is off. Set to "1" to enable.
208
+ //
209
+ // If modified by layout transformation the model will be dumped after these steps:
210
+ // 1) insertion of the layout transformation Transpose nodes
211
+ // 2) after those are optimized using the transpose optimizer,
212
+ // 3) after the L1 transformers are applied to the updated graph.
213
+ // The model will be saved to filename post_layout_transform_step_<step_number>.onnx.
214
+ static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation";
215
+
216
+ // Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are
217
+ // assigned (i.e., "fallback") to the CPU EP by default.
218
+ //
219
+ // This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP.
220
+ // If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot
221
+ // fully support all of the nodes in the graph.
222
+ //
223
+ // It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation
224
+ // will also fail with an error.
225
+ //
226
+ // Option values:
227
+ // - "0": CPU EP fallback is not disabled. [DEFAULT]
228
+ // - "1": CPU EP fallback is disabled.
229
+ static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback";
230
+
231
+ // Use this config when serializing a large model after optimization to specify an external initializers file
232
+ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFileName =
233
+ "session.optimized_model_external_initializers_file_name";
234
+
235
+ // Use this config to control the minimum size of the initializer when externalizing it during serialization
236
+ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
237
+ "session.optimized_model_external_initializers_min_size_in_bytes";
238
+
239
+ // Enable the EP context feature to dump the partitioned graph, which includes the EP context, into an ONNX file.
240
+ // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
241
+ // "0": disable. (default)
242
+ // "1": enable.
243
+ static const char* const kOrtSessionOptionEpContextEnable = "ep.context_enable";
244
+
245
+ // Specify the file path for the Onnx model which has EP context.
246
+ // Defaults to original_file_name_ctx.onnx if not specified.
247
+ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_path";
248
+
249
+ // Flag to specify whether to dump the EP context into the Onnx model.
250
+ // "0": dump the EP context into separate file, keep the file name in the Onnx model.
251
+ // "1": dump the EP context into the Onnx model. (default).
252
+ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
253
+
254
+ // Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul.
255
+ // Option values:
256
+ // - "0": Gemm FastMath mode is not enabled. [DEFAULT]
257
+ // - "1": Gemm FastMath mode is enabled.
258
+ static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
v1.17.1/jni/arm64-v8a/libonnxruntime.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f39b3c5a8f4d6a3bb4828676a5fdb3e727423fec7b643b94fe247a4bd8dffbbf
3
+ size 16035920
v1.17.1/jni/arm64-v8a/libonnxruntime4j_jni.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fc7e20fba98c5601582ae14c8c7fed194644856fc315d6184ed1776c50d9272
3
+ size 760104
v1.17.1/jni/armeabi-v7a/libonnxruntime.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54b0d427246341bc16e36b8ebce078c10464b11b40a6996e14ba57624ed8785e
3
+ size 10739204
v1.17.1/jni/armeabi-v7a/libonnxruntime4j_jni.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ea6c2eb44579712c810e47914dbc4808a1d4faf82226a612500fe9c1f82115d
3
+ size 620088
v1.17.1/jni/x86/libonnxruntime.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:522b39cce6cd34e3c76203bcdf5ae4b4b3628b53dde77a84baee9c34310d21de
3
+ size 18012000
v1.17.1/jni/x86/libonnxruntime4j_jni.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a33b3d363f2007f52f5c3d3f6f8f8ca3813be57270399148b1d98252d456e8e5
3
+ size 604204
v1.17.1/jni/x86_64/libonnxruntime.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed795517db637178ccf7dec5de45c201a68d67142d88bf269e02347cc485d327
3
+ size 17800912
v1.17.1/jni/x86_64/libonnxruntime4j_jni.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b15d098b51e44f0e3809b3ef4d72e51a2649f8fe8b06e373dafb58a31b72c8d
3
+ size 744104