csukuangfj commited on
Commit
fa18c68
·
1 Parent(s): b29c13f

Add onnxruntime.xcframework 1.15.0

Browse files
1.15.0/onnxruntime.xcframework/Headers/coreml_provider_factory.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+ #pragma once
4
+
5
+ #include "onnxruntime_c_api.h"
6
+
7
+ // COREMLFlags are bool options we want to set for CoreML EP
8
+ // This enum is defined as bit flags, and cannot have negative value
9
+ // To generate an uint32_t coreml_flags for using with OrtSessionOptionsAppendExecutionProvider_CoreML below,
10
+ // uint32_t coreml_flags = 0;
11
+ // coreml_flags |= COREML_FLAG_USE_CPU_ONLY;
12
+ enum COREMLFlags {
13
+ COREML_FLAG_USE_NONE = 0x000,
14
+
15
+ // Using CPU only in CoreML EP, this may decrease the perf but will provide
16
+ // reference output value without precision loss, which is useful for validation
17
+ COREML_FLAG_USE_CPU_ONLY = 0x001,
18
+
19
+ // Enable CoreML EP on subgraph
20
+ COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002,
21
+
22
+ // By default CoreML Execution provider will be enabled for all compatible Apple devices
23
+ // Enable this option will only enable CoreML EP for Apple devices with ANE (Apple Neural Engine)
24
+ // Please note, enable this option does not guarantee the entire model to be executed using ANE only
25
+ COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004,
26
+
27
+ // Keep COREML_FLAG_MAX at the end of the enum definition
28
+ // And assign the last COREMLFlag to it
29
+ COREML_FLAG_LAST = COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE,
30
+ };
31
+
32
+ #ifdef __cplusplus
33
+ extern "C" {
34
+ #endif
35
+
36
+ ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CoreML,
37
+ _In_ OrtSessionOptions* options, uint32_t coreml_flags);
38
+
39
+ #ifdef __cplusplus
40
+ }
41
+ #endif
1.15.0/onnxruntime.xcframework/Headers/cpu_provider_factory.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #include "onnxruntime_c_api.h"
5
+
6
+ #ifdef __cplusplus
7
+ extern "C" {
8
+ #endif
9
+
10
+ /**
11
+ * \param use_arena zero: false. non-zero: true.
12
+ */
13
+ ORT_EXPORT
14
+ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
15
+ ORT_ALL_ARGS_NONNULL;
16
+
17
+ #ifdef __cplusplus
18
+ }
19
+ #endif
1.15.0/onnxruntime.xcframework/Headers/onnxruntime_c_api.h ADDED
The diff for this file is too large to render. See raw diff
 
1.15.0/onnxruntime.xcframework/Headers/onnxruntime_cxx_api.h ADDED
The diff for this file is too large to render. See raw diff
 
1.15.0/onnxruntime.xcframework/Headers/onnxruntime_cxx_inline.h ADDED
@@ -0,0 +1,2035 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
5
+ // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
6
+ //
7
+ // These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
8
+ // the main C++ file with implementation details.
9
+
10
+ namespace Ort {
11
+
12
+ namespace detail {
13
+ inline void ThrowStatus(const Status& st) {
14
+ std::string error_message = st.GetErrorMessage();
15
+ OrtErrorCode error_code = st.GetErrorCode();
16
+ ORT_CXX_API_THROW(std::move(error_message), error_code);
17
+ }
18
+ } // namespace detail
19
+
20
+ inline void ThrowOnError(OrtStatus* ort_status) {
21
+ if (ort_status) {
22
+ Ort::Status st(ort_status);
23
+ detail::ThrowStatus(st);
24
+ }
25
+ }
26
+
27
+ inline void ThrowOnError(const Status& st) {
28
+ if (st) {
29
+ detail::ThrowStatus(st);
30
+ }
31
+ }
32
+
33
+ inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
34
+ }
35
+
36
+ inline Status::Status(const std::exception& e) noexcept {
37
+ p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
38
+ }
39
+
40
+ inline Status::Status(const Exception& e) noexcept {
41
+ p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
42
+ }
43
+
44
+ inline Status::Status(const char* message, OrtErrorCode code) noexcept {
45
+ p_ = GetApi().CreateStatus(code, message);
46
+ }
47
+
48
+ inline std::string Status::GetErrorMessage() const {
49
+ std::string message(GetApi().GetErrorMessage(p_));
50
+ return message;
51
+ }
52
+
53
+ inline OrtErrorCode Status::GetErrorCode() const {
54
+ return GetApi().GetErrorCode(p_);
55
+ }
56
+
57
+ inline bool Status::IsOK() const noexcept {
58
+ return (p_ == nullptr);
59
+ }
60
+
61
+ // This template converts a C++ type into it's ONNXTensorElementDataType
62
+ template <typename T>
63
+ struct TypeToTensorType;
64
+ template <>
65
+ struct TypeToTensorType<float> {
66
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
67
+ };
68
+ template <>
69
+ struct TypeToTensorType<Float16_t> {
70
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
71
+ };
72
+ template <>
73
+ struct TypeToTensorType<BFloat16_t> {
74
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
75
+ };
76
+ template <>
77
+ struct TypeToTensorType<double> {
78
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
79
+ };
80
+ template <>
81
+ struct TypeToTensorType<int8_t> {
82
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
83
+ };
84
+ template <>
85
+ struct TypeToTensorType<int16_t> {
86
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
87
+ };
88
+ template <>
89
+ struct TypeToTensorType<int32_t> {
90
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
91
+ };
92
+ template <>
93
+ struct TypeToTensorType<int64_t> {
94
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
95
+ };
96
+ template <>
97
+ struct TypeToTensorType<uint8_t> {
98
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
99
+ };
100
+ template <>
101
+ struct TypeToTensorType<uint16_t> {
102
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
103
+ };
104
+ template <>
105
+ struct TypeToTensorType<uint32_t> {
106
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
107
+ };
108
+ template <>
109
+ struct TypeToTensorType<uint64_t> {
110
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
111
+ };
112
+ template <>
113
+ struct TypeToTensorType<bool> {
114
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
115
+ };
116
+
117
+ inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
118
+ : allocator_(allocator), p_(p), size_(size) {
119
+ }
120
+
121
+ inline MemoryAllocation::~MemoryAllocation() {
122
+ if (p_ != nullptr) {
123
+ // We do not throw out of destructor
124
+ auto ret = GetApi().AllocatorFree(allocator_, p_);
125
+ static_cast<void>(ret);
126
+ }
127
+ }
128
+
129
+ inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
130
+ *this = std::move(o);
131
+ }
132
+
133
+ inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
134
+ OrtAllocator* alloc = nullptr;
135
+ void* p = nullptr;
136
+ size_t sz = 0;
137
+
138
+ // Swap out this
139
+ std::swap(alloc, allocator_);
140
+ std::swap(p, p_);
141
+ std::swap(sz, size_);
142
+
143
+ // Swap with incoming
144
+ std::swap(allocator_, o.allocator_);
145
+ std::swap(p_, o.p_);
146
+ std::swap(size_, o.size_);
147
+
148
+ // Destroy this instance if needed
149
+ MemoryAllocation this_alloc(alloc, p, sz);
150
+ return *this;
151
+ }
152
+
153
+ namespace detail {
154
+
155
+ template <typename T>
156
+ inline void* AllocatorImpl<T>::Alloc(size_t size) {
157
+ void* out;
158
+ ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
159
+ return out;
160
+ }
161
+
162
+ template <typename T>
163
+ inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
164
+ void* out;
165
+ ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
166
+ MemoryAllocation result(this->p_, out, size);
167
+ return result;
168
+ }
169
+
170
+ template <typename T>
171
+ inline void AllocatorImpl<T>::Free(void* p) {
172
+ ThrowOnError(GetApi().AllocatorFree(this->p_, p));
173
+ }
174
+
175
+ template <typename T>
176
+ inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
177
+ const OrtMemoryInfo* out;
178
+ ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
179
+ return ConstMemoryInfo{out};
180
+ }
181
+
182
+ } // namespace detail
183
+
184
+ inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
185
+ ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
186
+ }
187
+
188
+ inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
189
+ ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
190
+ }
191
+
192
+ namespace detail {
193
+
194
+ template <typename T>
195
+ inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
196
+ const char* name = nullptr;
197
+ ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
198
+ return std::string(name);
199
+ }
200
+
201
+ template <typename T>
202
+ inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
203
+ OrtAllocatorType type;
204
+ ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
205
+ return type;
206
+ }
207
+
208
+ template <typename T>
209
+ inline int MemoryInfoImpl<T>::GetDeviceId() const {
210
+ int id = 0;
211
+ ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
212
+ return id;
213
+ }
214
+
215
+ template <typename T>
216
+ inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
217
+ OrtMemoryInfoDeviceType type;
218
+ GetApi().MemoryInfoGetDeviceType(this->p_, &type);
219
+ return type;
220
+ }
221
+
222
+ template <typename T>
223
+ inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
224
+ OrtMemType type;
225
+ ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
226
+ return type;
227
+ }
228
+
229
+ template <typename T>
230
+ template <typename U>
231
+ inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
232
+ int comp_result = 0;
233
+ ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
234
+ return comp_result == 0;
235
+ }
236
+
237
+ } // namespace detail
238
+
239
+ inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
240
+ OrtMemoryInfo* p;
241
+ ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
242
+ return MemoryInfo(p);
243
+ }
244
+
245
+ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
246
+ ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
247
+ }
248
+
249
+ namespace detail {
250
+ template <typename T>
251
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
252
+ AllocatorWithDefaultOptions allocator;
253
+ return binding_utils::GetOutputNamesHelper(this->p_, allocator);
254
+ }
255
+
256
+ template <typename T>
257
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
258
+ return binding_utils::GetOutputNamesHelper(this->p_, allocator);
259
+ }
260
+
261
+ template <typename T>
262
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
263
+ AllocatorWithDefaultOptions allocator;
264
+ return binding_utils::GetOutputValuesHelper(this->p_, allocator);
265
+ }
266
+
267
+ template <typename T>
268
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
269
+ return binding_utils::GetOutputValuesHelper(this->p_, allocator);
270
+ }
271
+
272
+ template <typename T>
273
+ inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
274
+ ThrowOnError(GetApi().BindInput(this->p_, name, value));
275
+ }
276
+
277
+ template <typename T>
278
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
279
+ ThrowOnError(GetApi().BindOutput(this->p_, name, value));
280
+ }
281
+
282
+ template <typename T>
283
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
284
+ ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
285
+ }
286
+
287
+ template <typename T>
288
+ inline void IoBindingImpl<T>::ClearBoundInputs() {
289
+ GetApi().ClearBoundInputs(this->p_);
290
+ }
291
+
292
+ template <typename T>
293
+ inline void IoBindingImpl<T>::ClearBoundOutputs() {
294
+ GetApi().ClearBoundOutputs(this->p_);
295
+ }
296
+
297
+ template <typename T>
298
+ inline void IoBindingImpl<T>::SynchronizeInputs() {
299
+ ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
300
+ }
301
+
302
+ template <typename T>
303
+ inline void IoBindingImpl<T>::SynchronizeOutputs() {
304
+ ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
305
+ }
306
+
307
+ namespace binding_utils {
308
+ inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
309
+ std::vector<std::string> result;
310
+ auto free_fn = detail::AllocatedFree(allocator);
311
+ using Ptr = std::unique_ptr<void, decltype(free_fn)>;
312
+
313
+ char* buffer = nullptr;
314
+ size_t* lengths = nullptr;
315
+ size_t count = 0;
316
+ ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
317
+
318
+ if (count == 0) {
319
+ return result;
320
+ }
321
+
322
+ Ptr buffer_g(buffer, free_fn);
323
+ Ptr lengths_g(lengths, free_fn);
324
+
325
+ result.reserve(count);
326
+ for (size_t i = 0; i < count; ++i) {
327
+ auto sz = *lengths;
328
+ result.emplace_back(buffer, sz);
329
+ buffer += sz;
330
+ ++lengths;
331
+ }
332
+ return result;
333
+ }
334
+
335
+ inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
336
+ std::vector<Value> result;
337
+ size_t owned = 0;
338
+ size_t output_count = 0;
339
+ // Lambda to release the buffer when no longer needed and
340
+ // make sure that we destroy all instances on exception
341
+ auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
342
+ if (buffer) {
343
+ while (owned < output_count) {
344
+ auto* p = buffer + owned++;
345
+ GetApi().ReleaseValue(*p);
346
+ }
347
+ allocator->Free(allocator, buffer);
348
+ }
349
+ };
350
+ using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
351
+
352
+ OrtValue** output_buffer = nullptr;
353
+ ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
354
+ if (output_count == 0) {
355
+ return result;
356
+ }
357
+
358
+ Ptr buffer_g(output_buffer, free_fn);
359
+
360
+ result.reserve(output_count);
361
+ for (size_t i = 0; i < output_count; ++i) {
362
+ result.emplace_back(output_buffer[i]);
363
+ ++owned;
364
+ }
365
+ return result;
366
+ }
367
+
368
+ } // namespace binding_utils
369
+ } // namespace detail
370
+
371
+ inline IoBinding::IoBinding(Session& session) {
372
+ ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
373
+ }
374
+
375
+ inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
376
+ ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
377
+ }
378
+
379
+ inline ThreadingOptions::ThreadingOptions() {
380
+ ThrowOnError(GetApi().CreateThreadingOptions(&p_));
381
+ }
382
+
383
+ inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
384
+ ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
385
+ return *this;
386
+ }
387
+
388
+ inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
389
+ ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
390
+ return *this;
391
+ }
392
+
393
+ inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
394
+ ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
395
+ return *this;
396
+ }
397
+
398
+ inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
399
+ ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
400
+ return *this;
401
+ }
402
+
403
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
404
+ ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
405
+ return *this;
406
+ }
407
+
408
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
409
+ ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
410
+ return *this;
411
+ }
412
+
413
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
414
+ ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
415
+ return *this;
416
+ }
417
+
418
+ inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
419
+ ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
420
+ if (strcmp(logid, "onnxruntime-node") == 0) {
421
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
422
+ } else {
423
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
424
+ }
425
+ }
426
+
427
+ inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
428
+ ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
429
+ if (strcmp(logid, "onnxruntime-node") == 0) {
430
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
431
+ } else {
432
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
433
+ }
434
+ }
435
+
436
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
437
+ ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
438
+ if (strcmp(logid, "onnxruntime-node") == 0) {
439
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
440
+ } else {
441
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
442
+ }
443
+ }
444
+
445
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
446
+ OrtLoggingLevel logging_level, _In_ const char* logid) {
447
+ ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
448
+ if (strcmp(logid, "onnxruntime-node") == 0) {
449
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
450
+ } else {
451
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
452
+ }
453
+ }
454
+
455
+ inline Env& Env::EnableTelemetryEvents() {
456
+ ThrowOnError(GetApi().EnableTelemetryEvents(p_));
457
+ return *this;
458
+ }
459
+
460
+ inline Env& Env::DisableTelemetryEvents() {
461
+ ThrowOnError(GetApi().DisableTelemetryEvents(p_));
462
+ return *this;
463
+ }
464
+
465
+ inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
466
+ ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
467
+ return *this;
468
+ }
469
+
470
+ inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
471
+ ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
472
+ return *this;
473
+ }
474
+
475
+ inline CustomOpDomain::CustomOpDomain(const char* domain) {
476
+ ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
477
+ }
478
+
479
+ inline void CustomOpDomain::Add(const OrtCustomOp* op) {
480
+ ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
481
+ }
482
+
483
+ inline RunOptions::RunOptions() {
484
+ ThrowOnError(GetApi().CreateRunOptions(&p_));
485
+ }
486
+
487
+ inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
488
+ ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
489
+ return *this;
490
+ }
491
+
492
+ inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
493
+ ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
494
+ return *this;
495
+ }
496
+
497
+ inline int RunOptions::GetRunLogVerbosityLevel() const {
498
+ int out;
499
+ ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
500
+ return out;
501
+ }
502
+
503
+ inline int RunOptions::GetRunLogSeverityLevel() const {
504
+ int out;
505
+ ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
506
+ return out;
507
+ }
508
+
509
+ inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
510
+ ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
511
+ return *this;
512
+ }
513
+
514
+ inline const char* RunOptions::GetRunTag() const {
515
+ const char* out;
516
+ ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
517
+ return out;
518
+ }
519
+
520
+ inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
521
+ ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
522
+ return *this;
523
+ }
524
+
525
+ inline RunOptions& RunOptions::SetTerminate() {
526
+ ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
527
+ return *this;
528
+ }
529
+
530
+ inline RunOptions& RunOptions::UnsetTerminate() {
531
+ ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
532
+ return *this;
533
+ }
534
+
535
+ namespace detail {
536
+
537
+ template <typename T>
538
+ inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
539
+ OrtSessionOptions* out;
540
+ ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
541
+ return SessionOptions{out};
542
+ }
543
+
544
+ template <typename T>
545
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
546
+ size_t size = 0;
547
+ // Feed nullptr for the data buffer to query the true size of the string value
548
+ Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
549
+
550
+ std::string out;
551
+ out.resize(size);
552
+ Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
553
+ out.resize(size - 1); // remove the terminating character '\0'
554
+
555
+ return out;
556
+ }
557
+
558
+ template <typename T>
559
+ inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
560
+ int out = 0;
561
+ Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
562
+ return static_cast<bool>(out);
563
+ }
564
+
565
+ template <typename T>
566
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
567
+ if (!this->HasConfigEntry(config_key)) {
568
+ return def;
569
+ }
570
+
571
+ return this->GetConfigEntry(config_key);
572
+ }
573
+
574
+ template <typename T>
575
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
576
+ ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
577
+ return *this;
578
+ }
579
+
580
+ template <typename T>
581
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
582
+ ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
583
+ return *this;
584
+ }
585
+
586
+ template <typename T>
587
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
588
+ ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
589
+ return *this;
590
+ }
591
+
592
+ template <typename T>
593
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
594
+ ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
595
+ return *this;
596
+ }
597
+
598
+ template <typename T>
599
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
600
+ ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
601
+ return *this;
602
+ }
603
+
604
+ template <typename T>
605
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
606
+ ThrowOnError(GetApi().DisableProfiling(this->p_));
607
+ return *this;
608
+ }
609
+
610
+ template <typename T>
611
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
612
+ ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
613
+ return *this;
614
+ }
615
+
616
+ template <typename T>
617
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
618
+ ThrowOnError(GetApi().EnableMemPattern(this->p_));
619
+ return *this;
620
+ }
621
+
622
+ template <typename T>
623
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
624
+ ThrowOnError(GetApi().DisableMemPattern(this->p_));
625
+ return *this;
626
+ }
627
+
628
+ template <typename T>
629
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
630
+ ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
631
+ return *this;
632
+ }
633
+
634
+ template <typename T>
635
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
636
+ ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
637
+ return *this;
638
+ }
639
+
640
+ template <typename T>
641
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
642
+ ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
643
+ return *this;
644
+ }
645
+
646
+ template <typename T>
647
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
648
+ ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
649
+ return *this;
650
+ }
651
+
652
+ template <typename T>
653
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
654
+ ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
655
+ return *this;
656
+ }
657
+
658
+ template <typename T>
659
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
660
+ ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
661
+ return *this;
662
+ }
663
+
664
+ template <typename T>
665
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
666
+ ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
667
+ return *this;
668
+ }
669
+
670
+ template <typename T>
671
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
672
+ ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
673
+ return *this;
674
+ }
675
+
676
+ template <typename T>
677
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
678
+ ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
679
+ return *this;
680
+ }
681
+
682
+ template <typename T>
683
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
684
+ const std::vector<Value>& ort_values) {
685
+ const size_t inputs_num = names.size();
686
+ if (inputs_num != ort_values.size()) {
687
+ ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
688
+ }
689
+ std::vector<const char*> names_ptr;
690
+ std::vector<const OrtValue*> ort_values_ptrs;
691
+ names_ptr.reserve(inputs_num);
692
+ ort_values_ptrs.reserve(inputs_num);
693
+ for (size_t i = 0; i < inputs_num; ++i) {
694
+ names_ptr.push_back(names[i].c_str());
695
+ ort_values_ptrs.push_back(ort_values[i]);
696
+ }
697
+ ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
698
+ return *this;
699
+ }
700
+
701
+ template <typename T>
702
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
703
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
704
+ return *this;
705
+ }
706
+
707
+ template <typename T>
708
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
709
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
710
+ return *this;
711
+ }
712
+
713
+ template <typename T>
714
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
715
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
716
+ return *this;
717
+ }
718
+
719
+ template <typename T>
720
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
721
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
722
+ return *this;
723
+ }
724
+
725
+ template <typename T>
726
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
727
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
728
+ return *this;
729
+ }
730
+
731
+ template <typename T>
732
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
733
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
734
+ return *this;
735
+ }
736
+
737
+ template <typename T>
738
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
739
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
740
+ return *this;
741
+ }
742
+
743
+ template <typename T>
744
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) {
745
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
746
+ return *this;
747
+ }
748
+
749
+ template <typename T>
750
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
751
+ const std::string& provider_name,
752
+ const std::unordered_map<std::string, std::string>& provider_options) {
753
+ auto num_entries = provider_options.size();
754
+ std::vector<const char*> keys, values;
755
+ if (num_entries > 0) {
756
+ keys.reserve(num_entries);
757
+ values.reserve(num_entries);
758
+
759
+ for (const auto& entry : provider_options) {
760
+ keys.push_back(entry.first.c_str());
761
+ values.push_back(entry.second.c_str());
762
+ }
763
+ }
764
+
765
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
766
+ keys.data(), values.data(), num_entries));
767
+
768
+ return *this;
769
+ }
770
+
771
+ template <typename T>
772
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
773
+ ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
774
+ return *this;
775
+ }
776
+
777
+ template <typename T>
778
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
779
+ ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
780
+ return *this;
781
+ }
782
+
783
+ template <typename T>
784
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
785
+ ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
786
+ return *this;
787
+ }
788
+
789
+ template <typename T>
790
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
791
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
792
+ return *this;
793
+ }
794
+
795
+ template <typename T>
796
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
797
+ const CustomOpConfigs& custom_op_configs) {
798
+ // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
799
+ // the custom op library.
800
+ for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
801
+ AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
802
+ }
803
+
804
+ ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
805
+ return *this;
806
+ }
807
+
808
+ template <typename T>
809
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
810
+ ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
811
+ return *this;
812
+ }
813
+
814
+ /// Session
815
+ template <typename T>
816
+ inline size_t ConstSessionImpl<T>::GetInputCount() const {
817
+ size_t out;
818
+ ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
819
+ return out;
820
+ }
821
+
822
+ template <typename T>
823
+ inline size_t ConstSessionImpl<T>::GetOutputCount() const {
824
+ size_t out;
825
+ ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
826
+ return out;
827
+ }
828
+
829
+ template <typename T>
830
+ inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
831
+ size_t out;
832
+ ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
833
+ return out;
834
+ }
835
+
836
+ template <typename T>
837
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
838
+ char* out;
839
+ ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
840
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
841
+ }
842
+
843
+ template <typename T>
844
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
845
+ char* out;
846
+ ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
847
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
848
+ }
849
+
850
+ template <typename T>
851
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
852
+ char* out;
853
+ ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
854
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
855
+ }
856
+
857
+ template <typename T>
858
+ inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
859
+ uint64_t out;
860
+ ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
861
+ return out;
862
+ }
863
+
864
+ template <typename T>
865
+ inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
866
+ OrtModelMetadata* out;
867
+ ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
868
+ return ModelMetadata{out};
869
+ }
870
+
871
+ template <typename T>
872
+ inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
873
+ OrtTypeInfo* out;
874
+ ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
875
+ return TypeInfo{out};
876
+ }
877
+
878
+ template <typename T>
879
+ inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
880
+ OrtTypeInfo* out;
881
+ ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
882
+ return TypeInfo{out};
883
+ }
884
+
885
+ template <typename T>
886
+ inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
887
+ OrtTypeInfo* out;
888
+ ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
889
+ return TypeInfo{out};
890
+ }
891
+
892
+ template <typename T>
893
+ inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
894
+ const char* const* output_names, size_t output_count) {
895
+ std::vector<Value> output_values;
896
+ output_values.reserve(output_count);
897
+ for (size_t i = 0; i < output_count; i++)
898
+ output_values.emplace_back(nullptr);
899
+ Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
900
+ return output_values;
901
+ }
902
+
903
+ template <typename T>
904
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
905
+ const char* const* output_names, Value* output_values, size_t output_count) {
906
+ static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
907
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
908
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
909
+ ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
910
+ }
911
+
912
+ template <typename T>
913
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
914
+ ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
915
+ }
916
+
917
+ template <typename T>
918
+ inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
919
+ char* out = nullptr;
920
+ ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
921
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
922
+ }
923
+
924
+ } // namespace detail
925
+
926
+ inline SessionOptions::SessionOptions() {
927
+ ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
928
+ }
929
+
930
+ /// CustomOpConfigs
931
+ inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
932
+ std::string config_key = "custom_op.";
933
+
934
+ config_key += custom_op_name;
935
+ config_key += ".";
936
+ config_key += config;
937
+
938
+ return config_key;
939
+ }
940
+
941
+ inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
942
+ const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
943
+ flat_configs_[full_flat_key] = config_value;
944
+ return *this;
945
+ }
946
+
947
+ inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
948
+ return flat_configs_;
949
+ }
950
+
951
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
952
+ ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
953
+ }
954
+
955
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
956
+ OrtPrepackedWeightsContainer* prepacked_weights_container) {
957
+ ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
958
+ }
959
+
960
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
961
+ ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
962
+ }
963
+
964
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
965
+ const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
966
+ ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
967
+ prepacked_weights_container, &this->p_));
968
+ }
969
+
970
+ inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
971
+ char* out;
972
+ ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
973
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
974
+ }
975
+
976
+ inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
977
+ char* out;
978
+ ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
979
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
980
+ }
981
+
982
+ inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
983
+ char* out;
984
+ ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
985
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
986
+ }
987
+
988
+ inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
989
+ char* out;
990
+ ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
991
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
992
+ }
993
+
994
+ inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
995
+ char* out;
996
+ ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
997
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
998
+ }
999
+
1000
+ inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
1001
+ char* out;
1002
+ ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
1003
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
1004
+ }
1005
+
1006
+ inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
1007
+ auto deletor = detail::AllocatedFree(allocator);
1008
+ std::vector<AllocatedStringPtr> result;
1009
+
1010
+ char** out = nullptr;
1011
+ int64_t num_keys = 0;
1012
+ ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
1013
+ if (num_keys <= 0) {
1014
+ return result;
1015
+ }
1016
+
1017
+ // array of pointers will be freed
1018
+ std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
1019
+ // reserve may throw
1020
+ auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
1021
+ std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
1022
+ result.reserve(static_cast<size_t>(num_keys));
1023
+ strings_guard.release();
1024
+ for (int64_t i = 0; i < num_keys; ++i) {
1025
+ result.push_back(AllocatedStringPtr(out[i], deletor));
1026
+ }
1027
+
1028
+ return result;
1029
+ }
1030
+
1031
+ inline int64_t ModelMetadata::GetVersion() const {
1032
+ int64_t out;
1033
+ ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
1034
+ return out;
1035
+ }
1036
+
1037
+ namespace detail {
1038
+
1039
+ template <typename T>
1040
+ inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
1041
+ ONNXTensorElementDataType out;
1042
+ ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
1043
+ return out;
1044
+ }
1045
+
1046
+ template <typename T>
1047
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
1048
+ size_t out;
1049
+ ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
1050
+ return static_cast<size_t>(out);
1051
+ }
1052
+
1053
+ template <typename T>
1054
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
1055
+ size_t out;
1056
+ ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
1057
+ return out;
1058
+ }
1059
+
1060
+ template <typename T>
1061
+ inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
1062
+ ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
1063
+ }
1064
+
1065
+ template <typename T>
1066
+ inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
1067
+ ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
1068
+ }
1069
+
1070
+ template <typename T>
1071
+ inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
1072
+ std::vector<int64_t> out(GetDimensionsCount(), 0);
1073
+ ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
1074
+ return out;
1075
+ }
1076
+
1077
+ template <typename T>
1078
+ inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
1079
+ const OrtTensorTypeAndShapeInfo* out;
1080
+ ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
1081
+ return ConstTensorTypeAndShapeInfo{out};
1082
+ }
1083
+
1084
+ template <typename T>
1085
+ inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
1086
+ const OrtSequenceTypeInfo* out;
1087
+ ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
1088
+ return ConstSequenceTypeInfo{out};
1089
+ }
1090
+
1091
+ template <typename T>
1092
+ inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
1093
+ const OrtMapTypeInfo* out;
1094
+ ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
1095
+ return ConstMapTypeInfo{out};
1096
+ }
1097
+
1098
+ template <typename T>
1099
+ inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
1100
+ ONNXType out;
1101
+ ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
1102
+ return out;
1103
+ }
1104
+
1105
+ template <typename T>
1106
+ inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
1107
+ OrtTypeInfo* output;
1108
+ ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
1109
+ return TypeInfo{output};
1110
+ }
1111
+
1112
+ template <typename T>
1113
+ inline TypeInfo OptionalTypeInfoImpl<T>::GetOptionalElementType() const {
1114
+ OrtTypeInfo* info;
1115
+ ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info));
1116
+ return TypeInfo{info};
1117
+ }
1118
+
1119
+ template <typename T>
1120
+ inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
1121
+ ONNXTensorElementDataType out;
1122
+ ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
1123
+ return out;
1124
+ }
1125
+
1126
+ template <typename T>
1127
+ inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
1128
+ OrtTypeInfo* output;
1129
+ ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
1130
+ return TypeInfo{output};
1131
+ }
1132
+
1133
+ template <typename T>
1134
+ inline ConstOptionalTypeInfo TypeInfoImpl<T>::GetOptionalTypeInfo() const {
1135
+ const OrtOptionalTypeInfo* info;
1136
+ ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info));
1137
+ return ConstOptionalTypeInfo{info};
1138
+ }
1139
+
1140
+ } // namespace detail
1141
+
1142
+ namespace detail {
1143
+
1144
+ template <typename T>
1145
+ template <typename R>
1146
+ inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
1147
+ ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
1148
+ }
1149
+
1150
+ template <typename T>
1151
+ inline bool ConstValueImpl<T>::IsTensor() const {
1152
+ int out;
1153
+ ThrowOnError(GetApi().IsTensor(this->p_, &out));
1154
+ return out != 0;
1155
+ }
1156
+
1157
+ template <typename T>
1158
+ inline bool ConstValueImpl<T>::HasValue() const {
1159
+ int out;
1160
+ ThrowOnError(GetApi().HasValue(this->p_, &out));
1161
+ return out != 0;
1162
+ }
1163
+
1164
+ template <typename T>
1165
+ inline size_t ConstValueImpl<T>::GetCount() const {
1166
+ size_t out;
1167
+ ThrowOnError(GetApi().GetValueCount(this->p_, &out));
1168
+ return out;
1169
+ }
1170
+
1171
+ template <typename T>
1172
+ inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
1173
+ OrtValue* out;
1174
+ ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
1175
+ return Value{out};
1176
+ }
1177
+
1178
+ template <typename T>
1179
+ inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
1180
+ size_t out;
1181
+ ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
1182
+ return out;
1183
+ }
1184
+
1185
+ template <typename T>
1186
+ inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
1187
+ size_t out;
1188
+ ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
1189
+ return out;
1190
+ }
1191
+
1192
+ template <typename T>
1193
+ template <typename R>
1194
+ inline const R* ConstValueImpl<T>::GetTensorData() const {
1195
+ R* out;
1196
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
1197
+ return out;
1198
+ }
1199
+
1200
+ template <typename T>
1201
+ inline const void* ConstValueImpl<T>::GetTensorRawData() const {
1202
+ void* out;
1203
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
1204
+ return out;
1205
+ }
1206
+
1207
+ template <typename T>
1208
+ inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
1209
+ OrtTypeInfo* output;
1210
+ ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
1211
+ return TypeInfo{output};
1212
+ }
1213
+
1214
+ template <typename T>
1215
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
1216
+ OrtTensorTypeAndShapeInfo* output;
1217
+ ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
1218
+ return TensorTypeAndShapeInfo{output};
1219
+ }
1220
+
1221
+ template <typename T>
1222
+ inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
1223
+ const OrtMemoryInfo* mem_info;
1224
+ ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
1225
+ return ConstMemoryInfo(mem_info);
1226
+ }
1227
+
1228
+ template <typename T>
1229
+ inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
1230
+ ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
1231
+ }
1232
+
1233
+ template <typename T>
1234
+ inline std::string ConstValueImpl<T>::GetStringTensorElement(size_t element_index) const {
1235
+ size_t buffer_length;
1236
+ ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length));
1237
+
1238
+ std::string s;
1239
+ s.resize(buffer_length);
1240
+ ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0]));
1241
+ return s;
1242
+ }
1243
+
1244
+ template <typename T>
1245
+ inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
1246
+ ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
1247
+ }
1248
+
1249
+ #if !defined(DISABLE_SPARSE_TENSORS)
1250
+ template <typename T>
1251
+ inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
1252
+ OrtSparseFormat format;
1253
+ ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
1254
+ return format;
1255
+ }
1256
+
1257
+ template <typename T>
1258
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
1259
+ OrtTensorTypeAndShapeInfo* output;
1260
+ ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
1261
+ return TensorTypeAndShapeInfo{output};
1262
+ }
1263
+
1264
+ template <typename T>
1265
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
1266
+ OrtTensorTypeAndShapeInfo* output;
1267
+ ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
1268
+ return TensorTypeAndShapeInfo{output};
1269
+ }
1270
+
1271
+ template <typename T>
1272
+ template <typename R>
1273
+ inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
1274
+ const void* out;
1275
+ ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
1276
+ return reinterpret_cast<const R*>(out);
1277
+ }
1278
+
1279
+ template <typename T>
1280
+ inline bool ConstValueImpl<T>::IsSparseTensor() const {
1281
+ int out;
1282
+ ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
1283
+ return out != 0;
1284
+ }
1285
+
1286
+ template <typename T>
1287
+ template <typename R>
1288
+ inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
1289
+ const void* out;
1290
+ ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
1291
+ return reinterpret_cast<const R*>(out);
1292
+ }
1293
+
1294
+ #endif
1295
+
1296
+ template <typename T>
1297
+ void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
1298
+ ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
1299
+ }
1300
+
1301
+ template <typename T>
1302
+ void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
1303
+ ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
1304
+ }
1305
+
1306
+ template <typename T>
1307
+ inline char* ValueImpl<T>::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) {
1308
+ char* result;
1309
+ ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result));
1310
+ return result;
1311
+ }
1312
+
1313
+ template <typename T>
1314
+ void* ValueImpl<T>::GetTensorMutableRawData() {
1315
+ void* out;
1316
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
1317
+ return out;
1318
+ }
1319
+
1320
+ template <typename T>
1321
+ template <typename R>
1322
+ R* ValueImpl<T>::GetTensorMutableData() {
1323
+ R* out;
1324
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
1325
+ return out;
1326
+ }
1327
+
1328
+ template <typename T>
1329
+ template <typename R>
1330
+ R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
1331
+ static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
1332
+ R* out;
1333
+ ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
1334
+ return *out;
1335
+ }
1336
+
1337
+ #if !defined(DISABLE_SPARSE_TENSORS)
1338
+ template <typename T>
1339
+ void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
1340
+ ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
1341
+ }
1342
+
1343
+ template <typename T>
1344
+ void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
1345
+ ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
1346
+ }
1347
+
1348
+ template <typename T>
1349
+ void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
1350
+ ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
1351
+ }
1352
+
1353
+ template <typename T>
1354
+ void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
1355
+ const int64_t* indices_data, size_t indices_num) {
1356
+ ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
1357
+ values_param.values_shape_len, values_param.data.p_data,
1358
+ indices_data, indices_num));
1359
+ }
1360
+
1361
+ template <typename T>
1362
+ void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1363
+ const OrtSparseValuesParam& values,
1364
+ const int64_t* inner_indices_data, size_t inner_indices_num,
1365
+ const int64_t* outer_indices_data, size_t outer_indices_num) {
1366
+ ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1367
+ inner_indices_data, inner_indices_num,
1368
+ outer_indices_data, outer_indices_num));
1369
+ }
1370
+
1371
+ template <typename T>
1372
+ void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1373
+ const OrtSparseValuesParam& values,
1374
+ const Shape& indices_shape,
1375
+ const int32_t* indices_data) {
1376
+ ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1377
+ indices_shape.shape, indices_shape.shape_len,
1378
+ indices_data));
1379
+ }
1380
+
1381
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1382
+
1383
+ } // namespace detail
1384
+
1385
+ template <typename T>
1386
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
1387
+ return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
1388
+ }
1389
+
1390
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1391
+ ONNXTensorElementDataType type) {
1392
+ OrtValue* out;
1393
+ ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1394
+ return Value{out};
1395
+ }
1396
+
1397
+ template <typename T>
1398
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
1399
+ return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
1400
+ }
1401
+
1402
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
1403
+ OrtValue* out;
1404
+ ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1405
+ return Value{out};
1406
+ }
1407
+
1408
+ #if !defined(DISABLE_SPARSE_TENSORS)
1409
+
1410
+ template <typename T>
1411
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1412
+ const Shape& values_shape) {
1413
+ return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
1414
+ }
1415
+
1416
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1417
+ const Shape& values_shape, ONNXTensorElementDataType type) {
1418
+ OrtValue* out;
1419
+ ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1420
+ values_shape.shape, values_shape.shape_len, type, &out));
1421
+ return Value{out};
1422
+ }
1423
+
1424
+ template <typename T>
1425
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
1426
+ return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
1427
+ }
1428
+
1429
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
1430
+ ONNXTensorElementDataType type) {
1431
+ OrtValue* out;
1432
+ ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
1433
+ return Value{out};
1434
+ }
1435
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1436
+
1437
+ inline Value Value::CreateMap(Value& keys, Value& values) {
1438
+ OrtValue* out;
1439
+ OrtValue* inputs[2] = {keys, values};
1440
+ ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
1441
+ return Value{out};
1442
+ }
1443
+
1444
+ inline Value Value::CreateSequence(std::vector<Value>& values) {
1445
+ OrtValue* out;
1446
+ std::vector<OrtValue*> values_ort{values.data(), values.data() + values.size()};
1447
+ ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1448
+ return Value{out};
1449
+ }
1450
+
1451
+ template <typename T>
1452
+ inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
1453
+ OrtValue* out;
1454
+ ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
1455
+ return Value{out};
1456
+ }
1457
+
1458
+ //
1459
+ // Custom OP Inlines
1460
+ //
1461
+ inline Logger::Logger(const OrtLogger* logger) : logger_(logger) {
1462
+ Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_));
1463
+ }
1464
+
1465
+ inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept {
1466
+ return cached_severity_level_;
1467
+ }
1468
+
1469
+ inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
1470
+ const char* func_name, const char* message) const noexcept {
1471
+ OrtStatus* status = GetApi().Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number,
1472
+ func_name);
1473
+ return Status{status};
1474
+ }
1475
+
1476
+ // Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security)
1477
+ // for gcc and clang. The alternative is to use actual C-style variadic parameters and apply
1478
+ // __attribute__(format(printf...)), which does not work with variadic templates.
1479
+ #if defined(__GNUC__)
1480
+ #pragma GCC diagnostic push
1481
+ #pragma GCC diagnostic ignored "-Wformat-nonliteral"
1482
+ #pragma GCC diagnostic ignored "-Wformat-security"
1483
+ #elif defined(__clang__)
1484
+ #pragma clang diagnostic push
1485
+ #pragma clang diagnostic ignored "-Wformat-nonliteral"
1486
+ #pragma clang diagnostic ignored "-Wformat-security"
1487
+ #endif
1488
+ template <typename... Args>
1489
+ inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
1490
+ int line_number, const char* func_name, const char* format,
1491
+ Args&&... args) const noexcept {
1492
+ int msg_len = std::snprintf(nullptr, 0U, format, std::forward<Args>(args)...);
1493
+
1494
+ if (msg_len < 0) { // Formatting error
1495
+ return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL);
1496
+ }
1497
+
1498
+ OrtStatus* status = nullptr;
1499
+ const size_t buffer_size = static_cast<size_t>(msg_len) + 1U;
1500
+
1501
+ constexpr size_t kStackBufferSize = 1024;
1502
+
1503
+ if (buffer_size < kStackBufferSize) {
1504
+ char buffer[kStackBufferSize];
1505
+ snprintf(buffer, kStackBufferSize, format, std::forward<Args>(args)...);
1506
+ status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name);
1507
+ } else {
1508
+ // std::make_unique is only supported starting at C++14.
1509
+ #if (__cplusplus >= 201402L) || (_MSC_VER >= 1900)
1510
+ auto buffer = std::make_unique<char[]>(buffer_size);
1511
+ #else
1512
+ std::unique_ptr<char[]> buffer(new char[buffer_size]);
1513
+ #endif
1514
+ std::snprintf(buffer.get(), buffer_size, format, std::forward<Args>(args)...);
1515
+ status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name);
1516
+ }
1517
+
1518
+ return Status{status};
1519
+ }
1520
+ // Re-enable -Wformat-nonliteral and -Wformat-security
1521
+ #if defined(__GNUC__)
1522
+ #pragma GCC diagnostic pop
1523
+ #elif defined(__clang__)
1524
+ #pragma clang diagnostic pop
1525
+ #endif
1526
+
1527
+ inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
1528
+ }
1529
+
1530
+ inline size_t KernelContext::GetInputCount() const {
1531
+ size_t out = 0;
1532
+ Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
1533
+ return out;
1534
+ }
1535
+
1536
+ inline size_t KernelContext::GetOutputCount() const {
1537
+ size_t out = 0;
1538
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
1539
+ return out;
1540
+ }
1541
+
1542
+ inline ConstValue KernelContext::GetInput(size_t index) const {
1543
+ const OrtValue* out = nullptr;
1544
+ Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
1545
+ return ConstValue{out};
1546
+ }
1547
+
1548
+ inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
1549
+ OrtValue* out = nullptr;
1550
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
1551
+ return UnownedValue(out);
1552
+ }
1553
+
1554
+ inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
1555
+ OrtValue* out = nullptr;
1556
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
1557
+ return UnownedValue(out);
1558
+ }
1559
+
1560
+ inline void* KernelContext::GetGPUComputeStream() const {
1561
+ void* out = nullptr;
1562
+ Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
1563
+ return out;
1564
+ }
1565
+
1566
+ inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const {
1567
+ OrtAllocator* out = nullptr;
1568
+ Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out));
1569
+ return out;
1570
+ }
1571
+
1572
+ inline Logger KernelContext::GetLogger() const {
1573
+ const OrtLogger* out = nullptr;
1574
+ ThrowOnError(GetApi().KernelContext_GetLogger(this->ctx_, &out));
1575
+ return Logger{out};
1576
+ }
1577
+
1578
+ inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
1579
+ Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
1580
+ }
1581
+
1582
+ namespace detail {
1583
+ template <typename T>
1584
+ inline KernelInfo KernelInfoImpl<T>::Copy() const {
1585
+ OrtKernelInfo* info_copy = nullptr;
1586
+ Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
1587
+ return KernelInfo{info_copy};
1588
+ }
1589
+
1590
+ template <typename T>
1591
+ inline size_t KernelInfoImpl<T>::GetInputCount() const {
1592
+ size_t out = 0;
1593
+ ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
1594
+ return out;
1595
+ }
1596
+
1597
+ template <typename T>
1598
+ inline size_t KernelInfoImpl<T>::GetOutputCount() const {
1599
+ size_t out = 0;
1600
+ ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
1601
+ return out;
1602
+ }
1603
+
1604
+ template <typename T>
1605
+ inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
1606
+ size_t size = 0;
1607
+
1608
+ // Feed nullptr for the data buffer to query the true size of the string value
1609
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
1610
+
1611
+ std::string out;
1612
+ out.resize(size);
1613
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
1614
+ out.resize(size - 1); // remove the terminating character '\0'
1615
+
1616
+ return out;
1617
+ }
1618
+
1619
+ template <typename T>
1620
+ inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
1621
+ size_t size = 0;
1622
+
1623
+ // Feed nullptr for the data buffer to query the true size of the string value
1624
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
1625
+
1626
+ std::string out;
1627
+ out.resize(size);
1628
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
1629
+ out.resize(size - 1); // remove the terminating character '\0'
1630
+
1631
+ return out;
1632
+ }
1633
+
1634
+ template <typename T>
1635
+ inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
1636
+ OrtTypeInfo* out = nullptr;
1637
+ ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
1638
+ return TypeInfo{out};
1639
+ }
1640
+
1641
+ template <typename T>
1642
+ inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
1643
+ OrtTypeInfo* out = nullptr;
1644
+ ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
1645
+ return TypeInfo{out};
1646
+ }
1647
+
1648
+ template <typename T>
1649
+ inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
1650
+ OrtValue* out = nullptr;
1651
+ ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
1652
+ return Value{out};
1653
+ }
1654
+
1655
+ template <typename T>
1656
+ inline ConstValue KernelInfoImpl<T>::GetTensorConstantInput(size_t index, int* is_constant) const {
1657
+ const OrtValue* out = nullptr;
1658
+ ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out));
1659
+ return ConstValue{out};
1660
+ }
1661
+
1662
+ template <typename T>
1663
+ inline std::string KernelInfoImpl<T>::GetNodeName() const {
1664
+ size_t size = 0;
1665
+
1666
+ // Feed nullptr for the data buffer to query the true size of the string value
1667
+ Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size));
1668
+
1669
+ std::string out;
1670
+ out.resize(size);
1671
+ Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size));
1672
+ out.resize(size - 1); // remove the terminating character '\0'
1673
+
1674
+ return out;
1675
+ }
1676
+
1677
+ template <typename T>
1678
+ inline Logger KernelInfoImpl<T>::GetLogger() const {
1679
+ const OrtLogger* out = nullptr;
1680
+ ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out));
1681
+ return Logger{out};
1682
+ }
1683
+
1684
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
1685
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
1686
+ }
1687
+
1688
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
1689
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
1690
+ }
1691
+
1692
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
1693
+ size_t size = 0;
1694
+ // Feed nullptr for the data buffer to query the true size of the string attribute
1695
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
1696
+
1697
+ std::string out;
1698
+ out.resize(size);
1699
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
1700
+ out.resize(size - 1); // remove the terminating character '\0'
1701
+ out.swap(result);
1702
+ }
1703
+
1704
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
1705
+ size_t size = 0;
1706
+ // Feed nullptr for the data buffer to query the true size of the attribute
1707
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
1708
+
1709
+ std::vector<float> out;
1710
+ out.resize(size);
1711
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
1712
+ out.swap(result);
1713
+ }
1714
+
1715
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
1716
+ size_t size = 0;
1717
+
1718
+ // Feed nullptr for the data buffer to query the true size of the attribute
1719
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
1720
+
1721
+ std::vector<int64_t> out;
1722
+ out.resize(size);
1723
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
1724
+ out.swap(result);
1725
+ }
1726
+ } // namespace detail
1727
+
1728
+ inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
1729
+
1730
+ inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
1731
+
1732
+ inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
1733
+ const char** type_constraint_names,
1734
+ const ONNXTensorElementDataType* type_constraint_values,
1735
+ size_t type_constraint_count,
1736
+ const OpAttr* attr_values, size_t attr_count,
1737
+ size_t input_count, size_t output_count) {
1738
+ static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
1739
+ "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
1740
+ auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
1741
+ OrtOp* op;
1742
+ Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1743
+ static_cast<int>(type_constraint_count),
1744
+ attr_input_values,
1745
+ static_cast<int>(attr_count),
1746
+ static_cast<int>(input_count),
1747
+ static_cast<int>(output_count), &op));
1748
+ return Op{op};
1749
+ }
1750
+
1751
+ inline void Op::Invoke(const OrtKernelContext* context,
1752
+ const Value* input_values,
1753
+ size_t input_count,
1754
+ Value* output_values,
1755
+ size_t output_count) {
1756
+ static_assert(sizeof(Value) == sizeof(OrtValue*),
1757
+ "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1758
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1759
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1760
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
1761
+ ort_output_values, static_cast<int>(output_count)));
1762
+ }
1763
+
1764
+ inline void Op::Invoke(const OrtKernelContext* context,
1765
+ const OrtValue* const* input_values,
1766
+ size_t input_count,
1767
+ OrtValue* const* output_values,
1768
+ size_t output_count) {
1769
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
1770
+ output_values, static_cast<int>(output_count)));
1771
+ }
1772
+
1773
+ inline void CustomOpApi::ThrowOnError(OrtStatus* status) {
1774
+ Ort::ThrowOnError(status);
1775
+ }
1776
+
1777
+ template <>
1778
+ inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1779
+ float out;
1780
+ Ort::ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
1781
+ return out;
1782
+ }
1783
+
1784
+ template <>
1785
+ inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1786
+ int64_t out;
1787
+ Ort::ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
1788
+ return out;
1789
+ }
1790
+
1791
+ template <>
1792
+ inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1793
+ size_t size = 0;
1794
+ std::string out;
1795
+
1796
+ // Feed nullptr for the data buffer to query the true size of the string attribute
1797
+ OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
1798
+
1799
+ if (status == nullptr) {
1800
+ out.resize(size);
1801
+ Ort::ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
1802
+ out.resize(size - 1); // remove the terminating character '\0'
1803
+ } else {
1804
+ Ort::ThrowOnError(status);
1805
+ }
1806
+ return out;
1807
+ }
1808
+
1809
+ template <>
1810
+ inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1811
+ size_t size = 0;
1812
+ std::vector<float> out;
1813
+
1814
+ // Feed nullptr for the data buffer to query the true size of the attribute
1815
+ OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
1816
+
1817
+ if (status == nullptr) {
1818
+ out.resize(size);
1819
+ Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
1820
+ } else {
1821
+ Ort::ThrowOnError(status);
1822
+ }
1823
+ return out;
1824
+ }
1825
+
1826
+ template <>
1827
+ inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1828
+ size_t size = 0;
1829
+ std::vector<int64_t> out;
1830
+
1831
+ // Feed nullptr for the data buffer to query the true size of the attribute
1832
+ OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
1833
+
1834
+ if (status == nullptr) {
1835
+ out.resize(size);
1836
+ Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
1837
+ } else {
1838
+ Ort::ThrowOnError(status);
1839
+ }
1840
+ return out;
1841
+ }
1842
+ inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
1843
+ OrtTensorTypeAndShapeInfo* out;
1844
+ Ort::ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
1845
+ return out;
1846
+ }
1847
+
1848
+ inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
1849
+ size_t out;
1850
+ Ort::ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
1851
+ return out;
1852
+ }
1853
+
1854
+ inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) {
1855
+ ONNXTensorElementDataType out;
1856
+ Ort::ThrowOnError(api_.GetTensorElementType(info, &out));
1857
+ return out;
1858
+ }
1859
+
1860
+ inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
1861
+ size_t out;
1862
+ Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
1863
+ return out;
1864
+ }
1865
+
1866
+ inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
1867
+ Ort::ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
1868
+ }
1869
+
1870
+ inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
1871
+ Ort::ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
1872
+ }
1873
+
1874
+ template <typename T>
1875
+ inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) {
1876
+ T* data;
1877
+ Ort::ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
1878
+ return data;
1879
+ }
1880
+
1881
+ inline const OrtMemoryInfo* CustomOpApi::GetTensorMemoryInfo(_In_ const OrtValue* value) {
1882
+ const OrtMemoryInfo* mem_info;
1883
+ Ort::ThrowOnError(api_.GetTensorMemoryInfo(value, &mem_info));
1884
+ return mem_info;
1885
+ }
1886
+
1887
+ template <typename T>
1888
+ inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) {
1889
+ T* data = nullptr;
1890
+ Ort::ThrowOnError(api_.GetTensorMutableData(const_cast<OrtValue*>(value), reinterpret_cast<void**>(&data)));
1891
+ return data;
1892
+ }
1893
+
1894
+ inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) {
1895
+ size_t out;
1896
+ Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
1897
+ std::vector<int64_t> output(out);
1898
+ Ort::ThrowOnError(api_.GetDimensions(info, output.data(), out));
1899
+ return output;
1900
+ }
1901
+
1902
+ inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
1903
+ api_.ReleaseTensorTypeAndShapeInfo(input);
1904
+ }
1905
+
1906
+ inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) {
1907
+ size_t out;
1908
+ Ort::ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
1909
+ return out;
1910
+ }
1911
+
1912
+ inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) {
1913
+ const OrtValue* out;
1914
+ Ort::ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
1915
+ return out;
1916
+ }
1917
+
1918
+ inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) {
1919
+ size_t out;
1920
+ Ort::ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
1921
+ return out;
1922
+ }
1923
+
1924
+ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
1925
+ _In_ const int64_t* dim_values, size_t dim_count) {
1926
+ OrtValue* out;
1927
+ Ort::ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
1928
+ return out;
1929
+ }
1930
+
1931
+ inline void* CustomOpApi::KernelContext_GetGPUComputeStream(const OrtKernelContext* context) {
1932
+ void* out;
1933
+ Ort::ThrowOnError(api_.KernelContext_GetGPUComputeStream(context, &out));
1934
+ return out;
1935
+ }
1936
+
1937
+ inline OrtOpAttr* CustomOpApi::CreateOpAttr(_In_ const char* name,
1938
+ _In_ const void* data,
1939
+ _In_ int len,
1940
+ _In_ OrtOpAttrType type) {
1941
+ OrtOpAttr* op_attr{};
1942
+ Ort::ThrowOnError(api_.CreateOpAttr(name, data, len, type, &op_attr));
1943
+ return op_attr;
1944
+ }
1945
+
1946
+ inline void CustomOpApi::ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr) {
1947
+ api_.ReleaseOpAttr(op_attr);
1948
+ }
1949
+
1950
+ inline OrtOp* CustomOpApi::CreateOp(_In_ const OrtKernelInfo* info,
1951
+ _In_z_ const char* op_name,
1952
+ _In_z_ const char* domain,
1953
+ int version,
1954
+ _In_reads_(type_constraint_count) const char** type_constraint_names,
1955
+ _In_reads_(type_constraint_count) const ONNXTensorElementDataType* type_constraint_values,
1956
+ int type_constraint_count,
1957
+ _In_reads_(attr_count) const OrtOpAttr* const* attr_values,
1958
+ int attr_count,
1959
+ int input_count,
1960
+ int output_count) {
1961
+ OrtOp* ort_op{};
1962
+ Ort::ThrowOnError(api_.CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1963
+ type_constraint_count, attr_values, attr_count, input_count, output_count, &ort_op));
1964
+ return ort_op;
1965
+ }
1966
+
1967
+ inline void CustomOpApi::InvokeOp(_In_ const OrtKernelContext* context,
1968
+ _In_ const OrtOp* ort_op,
1969
+ _In_ const OrtValue* const* input_values,
1970
+ _In_ int input_count,
1971
+ _Inout_ OrtValue* const* output_values,
1972
+ _In_ int output_count) {
1973
+ Ort::ThrowOnError(api_.InvokeOp(context, ort_op, input_values, input_count, output_values, output_count));
1974
+ }
1975
+
1976
+ inline void CustomOpApi::ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op) {
1977
+ api_.ReleaseOp(ort_op);
1978
+ }
1979
+
1980
+ inline OrtKernelInfo* CustomOpApi::CopyKernelInfo(_In_ const OrtKernelInfo* info) {
1981
+ OrtKernelInfo* info_copy{};
1982
+ Ort::ThrowOnError(api_.CopyKernelInfo(info, &info_copy));
1983
+ return info_copy;
1984
+ }
1985
+
1986
+ inline void CustomOpApi::ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy) {
1987
+ api_.ReleaseKernelInfo(info_copy);
1988
+ }
1989
+
1990
+ inline std::string GetVersionString() {
1991
+ return OrtGetApiBase()->GetVersionString();
1992
+ }
1993
+
1994
+ inline std::string GetBuildInfoString() {
1995
+ return GetApi().GetBuildInfoString();
1996
+ }
1997
+
1998
+ inline std::vector<std::string> GetAvailableProviders() {
1999
+ char** providers;
2000
+ int len;
2001
+
2002
+ auto release_fn = [&len](char** providers) {
2003
+ // This should always return nullptr.
2004
+ ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
2005
+ };
2006
+
2007
+ ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
2008
+ std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
2009
+ std::vector<std::string> available_providers;
2010
+ available_providers.reserve(static_cast<size_t>(len));
2011
+ for (int i = 0; i < len; ++i) {
2012
+ available_providers.emplace_back(providers[i]);
2013
+ }
2014
+ return available_providers;
2015
+ }
2016
+
2017
+ template <typename TOp, typename TKernel>
2018
+ void CustomOpBase<TOp, TKernel>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
2019
+ ConstSessionOptions options) const {
2020
+ const TOp* derived = static_cast<const TOp*>(this);
2021
+ std::vector<std::string> keys = derived->GetSessionConfigKeys();
2022
+
2023
+ out.reserve(keys.size());
2024
+
2025
+ std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
2026
+ const size_t prefix_size = config_entry_key.length();
2027
+
2028
+ for (const auto& key : keys) {
2029
+ config_entry_key.resize(prefix_size);
2030
+ config_entry_key.append(key);
2031
+ out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
2032
+ }
2033
+ }
2034
+
2035
+ } // namespace Ort
1.15.0/onnxruntime.xcframework/Info.plist ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3
+ <plist version="1.0">
4
+ <dict>
5
+ <key>AvailableLibraries</key>
6
+ <array>
7
+ <dict>
8
+ <key>LibraryIdentifier</key>
9
+ <string>ios-arm64_x86_64-simulator</string>
10
+ <key>LibraryPath</key>
11
+ <string>onnxruntime.a</string>
12
+ <key>SupportedArchitectures</key>
13
+ <array>
14
+ <string>arm64</string>
15
+ <string>x86_64</string>
16
+ </array>
17
+ <key>SupportedPlatform</key>
18
+ <string>ios</string>
19
+ <key>SupportedPlatformVariant</key>
20
+ <string>simulator</string>
21
+ </dict>
22
+ <dict>
23
+ <key>LibraryIdentifier</key>
24
+ <string>ios-arm64</string>
25
+ <key>LibraryPath</key>
26
+ <string>onnxruntime.a</string>
27
+ <key>SupportedArchitectures</key>
28
+ <array>
29
+ <string>arm64</string>
30
+ </array>
31
+ <key>SupportedPlatform</key>
32
+ <string>ios</string>
33
+ </dict>
34
+ </array>
35
+ <key>CFBundlePackageType</key>
36
+ <string>XFWK</string>
37
+ <key>XCFrameworkFormatVersion</key>
38
+ <string>1.0</string>
39
+ </dict>
40
+ </plist>
1.15.0/onnxruntime.xcframework/ios-arm64/libonnxruntime.a ADDED
@@ -0,0 +1 @@
 
 
1
+ onnxruntime.a
1.15.0/onnxruntime.xcframework/ios-arm64/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65cb03cfd4d2354b834e890a7d6acb5ea8566161e8f652fae620ee34ff82415d
3
+ size 58813232
1.15.0/onnxruntime.xcframework/ios-arm64_x86_64-simulator/libonnxruntime.a ADDED
@@ -0,0 +1 @@
 
 
1
+ onnxruntime.a
1.15.0/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c97eeaea5528cd6fc485eeea875689eba5e7aa95d28d36e87b6b9e183b5fc98
3
+ size 120135536