csukuangfj commited on
Commit
dcddb04
·
1 Parent(s): afda648

remove unused files

Browse files
AndroidManifest.xml DELETED
@@ -1,11 +0,0 @@
1
- <?xml version="1.0" encoding="utf-8"?>
2
- <manifest xmlns:android="http://schemas.android.com/apk/res/android"
3
- package="ai.onnxruntime"
4
- android:versionCode="11400"
5
- android:versionName="1.14.0" >
6
-
7
- <uses-sdk
8
- android:minSdkVersion="21"
9
- android:targetSdkVersion="24" />
10
-
11
- </manifest>
 
 
 
 
 
 
 
 
 
 
 
 
R.txt DELETED
File without changes
classes.jar DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6ce122aaefccba73aef0ca8723189db4ead908d31a25d724d865e9ca81e5a7f6
3
- size 86845
 
 
 
 
headers/cpu_provider_factory.h DELETED
@@ -1,19 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #include "onnxruntime_c_api.h"
5
-
6
- #ifdef __cplusplus
7
- extern "C" {
8
- #endif
9
-
10
- /**
11
- * \param use_arena zero: false. non-zero: true.
12
- */
13
- ORT_EXPORT
14
- ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
15
- ORT_ALL_ARGS_NONNULL;
16
-
17
- #ifdef __cplusplus
18
- }
19
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
headers/nnapi_provider_factory.h DELETED
@@ -1,62 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
- #pragma once
4
-
5
- #include "onnxruntime_c_api.h"
6
-
7
- // NNAPIFlags are bool options we want to set for NNAPI EP
8
- // This enum is defined as bit flags, and cannot have negative value
9
- // To generate an uint32_t nnapi_flags for using with OrtSessionOptionsAppendExecutionProvider_Nnapi below,
10
- // uint32_t nnapi_flags = 0;
11
- // nnapi_flags |= NNAPI_FLAG_USE_FP16;
12
- enum NNAPIFlags {
13
- NNAPI_FLAG_USE_NONE = 0x000,
14
-
15
- // Using fp16 relaxation in NNAPI EP, this may improve perf but may also reduce precision
16
- NNAPI_FLAG_USE_FP16 = 0x001,
17
-
18
- // Use NCHW layout in NNAPI EP, this is only available after Android API level 29
19
- // Please note that, for now, NNAPI performs worse using NCHW compared to using NHWC
20
- NNAPI_FLAG_USE_NCHW = 0x002,
21
-
22
- // Prevent NNAPI from using CPU devices.
23
- //
24
- // NNAPI is more efficient using GPU or NPU for execution, and NNAPI might fall back to its own CPU implementation
25
- // for operations not supported by GPU/NPU. The CPU implementation of NNAPI (which is called nnapi-reference)
26
- // might be less efficient than the optimized versions of the operation of ORT. It might be advantageous to disable
27
- // the NNAPI CPU fallback and handle execution using ORT kernels.
28
- //
29
- // For some models, if NNAPI would use CPU to execute an operation, and this flag is set, the execution of the
30
- // model may fall back to ORT kernels.
31
- //
32
- // This option is only available after Android API level 29, and will be ignored for Android API level 28-
33
- //
34
- // For NNAPI device assignments, see https://developer.android.com/ndk/guides/neuralnetworks#device-assignment
35
- // For NNAPI CPU fallback, see https://developer.android.com/ndk/guides/neuralnetworks#cpu-fallback
36
- //
37
- // Please note, the NNAPI EP will return error status if both NNAPI_FLAG_CPU_DISABLED
38
- // and NNAPI_FLAG_CPU_ONLY flags are set
39
- NNAPI_FLAG_CPU_DISABLED = 0x004,
40
-
41
- // Using CPU only in NNAPI EP, this may decrease the perf but will provide
42
- // reference output value without precision loss, which is useful for validation
43
- //
44
- // Please note, the NNAPI EP will return error status if both NNAPI_FLAG_CPU_DISABLED
45
- // and NNAPI_FLAG_CPU_ONLY flags are set
46
- NNAPI_FLAG_CPU_ONLY = 0x008,
47
-
48
- // Keep NNAPI_FLAG_LAST at the end of the enum definition
49
- // And assign the last NNAPIFlag to it
50
- NNAPI_FLAG_LAST = NNAPI_FLAG_CPU_ONLY,
51
- };
52
-
53
- #ifdef __cplusplus
54
- extern "C" {
55
- #endif
56
-
57
- ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi,
58
- _In_ OrtSessionOptions* options, uint32_t nnapi_flags);
59
-
60
- #ifdef __cplusplus
61
- }
62
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
headers/onnxruntime_c_api.h DELETED
The diff for this file is too large to render. See raw diff
 
headers/onnxruntime_cxx_api.h DELETED
@@ -1,1876 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- // Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5
- //
6
- // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7
- // and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
8
- // all the resources follow RAII and do not leak memory.
9
- //
10
- // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
11
- // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
12
- // until you assign an instance that actually holds an underlying object.
13
- //
14
- // For Ort objects only move assignment between objects is allowed, there are no copy constructors.
15
- // Some objects have explicit 'Clone' methods for this purpose.
16
- //
17
- // ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
18
- // by value or by reference. ConstXXXX types are restricted to const only interfaces.
19
- //
20
- // UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
21
- //
22
- // The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
23
- // have to fall back to C types and the API with the usual pitfalls. In general, do not use C API from your C++ code.
24
-
25
- #pragma once
26
- #include "onnxruntime_c_api.h"
27
- #include <cstddef>
28
- #include <array>
29
- #include <memory>
30
- #include <stdexcept>
31
- #include <string>
32
- #include <vector>
33
- #include <unordered_map>
34
- #include <utility>
35
- #include <type_traits>
36
-
37
- #ifdef ORT_NO_EXCEPTIONS
38
- #include <iostream>
39
- #endif
40
-
41
- /** \brief All C++ Onnxruntime APIs are defined inside this namespace
42
- *
43
- */
44
- namespace Ort {
45
-
46
- /** \brief All C++ methods that can fail will throw an exception of this type
47
- *
48
- * If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort()
49
- */
50
- struct Exception : std::exception {
51
- Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
52
-
53
- OrtErrorCode GetOrtErrorCode() const { return code_; }
54
- const char* what() const noexcept override { return message_.c_str(); }
55
-
56
- private:
57
- std::string message_;
58
- OrtErrorCode code_;
59
- };
60
-
61
- #ifdef ORT_NO_EXCEPTIONS
62
- // The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
63
- // NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
64
- #ifndef ORT_CXX_API_THROW
65
- #define ORT_CXX_API_THROW(string, code) \
66
- do { \
67
- std::cerr << Ort::Exception(string, code) \
68
- .what() \
69
- << std::endl; \
70
- abort(); \
71
- } while (false)
72
- #endif
73
- #else
74
- #define ORT_CXX_API_THROW(string, code) \
75
- throw Ort::Exception(string, code)
76
- #endif
77
-
78
- // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
79
- // it's in a template so that we can define a global variable in a header and make
80
- // it transparent to the users of the API.
81
- template <typename T>
82
- struct Global {
83
- static const OrtApi* api_;
84
- };
85
-
86
- // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
87
- template <typename T>
88
- #ifdef ORT_API_MANUAL_INIT
89
- const OrtApi* Global<T>::api_{};
90
- inline void InitApi() { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
91
-
92
- // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
93
- // required by C++ APIs.
94
- //
95
- // Example mycustomop.cc:
96
- //
97
- // #define ORT_API_MANUAL_INIT
98
- // #include <onnxruntime_cxx_api.h>
99
- // #undef ORT_API_MANUAL_INIT
100
- //
101
- // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
102
- // Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
103
- // // ...
104
- // }
105
- //
106
- inline void InitApi(const OrtApi* api) { Global<void>::api_ = api; }
107
- #else
108
- #if defined(_MSC_VER) && !defined(__clang__)
109
- #pragma warning(push)
110
- // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
111
- // Please define ORT_API_MANUAL_INIT if it concerns you.
112
- #pragma warning(disable : 26426)
113
- #endif
114
- const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
115
- #if defined(_MSC_VER) && !defined(__clang__)
116
- #pragma warning(pop)
117
- #endif
118
- #endif
119
-
120
- /// This returns a reference to the OrtApi interface in use
121
- inline const OrtApi& GetApi() { return *Global<void>::api_; }
122
-
123
- /// <summary>
124
- /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and
125
- /// returns a vector of strings representing the available execution providers.
126
- /// </summary>
127
- /// <returns>vector of strings</returns>
128
- std::vector<std::string> GetAvailableProviders();
129
-
130
- /** \brief IEEE 754 half-precision floating point data type
131
- * \details It is necessary for type dispatching to make use of C++ API
132
- * The type is implicitly convertible to/from uint16_t.
133
- * The size of the structure should align with uint16_t and one can freely cast
134
- * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
135
- *
136
- * Generally, you can feed any of your types as float16/bfloat16 data to create a tensor
137
- * on top of it, providing it can form a continuous buffer with 16-bit elements with no padding.
138
- * And you can also feed an array of uint16_t elements directly. For example,
139
- *
140
- * \code{.unparsed}
141
- * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
142
- * constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
143
- * std::vector<int64_t> dims = {values_length}; // one dimensional example
144
- * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
145
- * // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
146
- * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
147
- * dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
148
- * \endcode
149
- *
150
- * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
151
- * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
152
- * template specialization.
153
- *
154
- * \code{.unparsed}
155
- * namespace yours { struct half {}; } // assume this is your type, define this:
156
- * namespace Ort {
157
- * template<>
158
- * struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
159
- * } //namespace Ort
160
- *
161
- * std::vector<yours::half> values;
162
- * std::vector<int64_t> dims = {values.size()}; // one dimensional example
163
- * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
164
- * // Here we are passing element count -> values.size()
165
- * auto float16_tensor = Ort::Value::CreateTensor<yours::half>(info, values.data(), values.size(), dims.data(), dims.size());
166
- *
167
- * \endcode
168
- */
169
- struct Float16_t {
170
- uint16_t value;
171
- constexpr Float16_t() noexcept : value(0) {}
172
- constexpr Float16_t(uint16_t v) noexcept : value(v) {}
173
- constexpr operator uint16_t() const noexcept { return value; }
174
- constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; };
175
- constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; };
176
- };
177
-
178
- static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
179
-
180
- /** \brief bfloat16 (Brain Floating Point) data type
181
- * \details It is necessary for type dispatching to make use of C++ API
182
- * The type is implicitly convertible to/from uint16_t.
183
- * The size of the structure should align with uint16_t and one can freely cast
184
- * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
185
- *
186
- * See also code examples for Float16_t above.
187
- */
188
- struct BFloat16_t {
189
- uint16_t value;
190
- constexpr BFloat16_t() noexcept : value(0) {}
191
- constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
192
- constexpr operator uint16_t() const noexcept { return value; }
193
- constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; };
194
- constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; };
195
- };
196
-
197
- static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
198
-
199
- namespace detail {
200
- // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
201
- // This can't be done in the C API since C doesn't have function overloading.
202
- #define ORT_DEFINE_RELEASE(NAME) \
203
- inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
204
-
205
- ORT_DEFINE_RELEASE(Allocator);
206
- ORT_DEFINE_RELEASE(MemoryInfo);
207
- ORT_DEFINE_RELEASE(CustomOpDomain);
208
- ORT_DEFINE_RELEASE(ThreadingOptions);
209
- ORT_DEFINE_RELEASE(Env);
210
- ORT_DEFINE_RELEASE(RunOptions);
211
- ORT_DEFINE_RELEASE(Session);
212
- ORT_DEFINE_RELEASE(SessionOptions);
213
- ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
214
- ORT_DEFINE_RELEASE(SequenceTypeInfo);
215
- ORT_DEFINE_RELEASE(MapTypeInfo);
216
- ORT_DEFINE_RELEASE(TypeInfo);
217
- ORT_DEFINE_RELEASE(Value);
218
- ORT_DEFINE_RELEASE(ModelMetadata);
219
- ORT_DEFINE_RELEASE(IoBinding);
220
- ORT_DEFINE_RELEASE(ArenaCfg);
221
- ORT_DEFINE_RELEASE(Status);
222
- ORT_DEFINE_RELEASE(OpAttr);
223
- ORT_DEFINE_RELEASE(Op);
224
- ORT_DEFINE_RELEASE(KernelInfo);
225
-
226
- #undef ORT_DEFINE_RELEASE
227
-
228
- /** \brief This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object
229
- * has no ownership of the underlying C object.
230
- */
231
- template <typename T>
232
- struct Unowned {
233
- using Type = T;
234
- };
235
-
236
- /** \brief Used internally by the C++ API. C++ wrapper types inherit from this.
237
- * This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
238
- *
239
- * All of the C++ classes
240
- * a) serve as containers for pointers to objects that are created by the underlying C API.
241
- * Their size is just a pointer size, no need to dynamically allocate them. Use them by value.
242
- * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects.
243
- * they would release objects owned automatically when going out of scope, they are move-only.
244
- * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers.
245
- * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else
246
- * such as Onnxruntime or instances of XXXX classes.
247
- * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used
248
- * in C++ code.
249
- *
250
- */
251
-
252
- /// <summary>
253
- /// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction.
254
- /// </summary>
255
- template <typename T>
256
- struct Base {
257
- using contained_type = T;
258
-
259
- constexpr Base() = default;
260
- constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
261
- ~Base() { OrtRelease(p_); }
262
-
263
- Base(const Base&) = delete;
264
- Base& operator=(const Base&) = delete;
265
-
266
- Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
267
- Base& operator=(Base&& v) noexcept {
268
- OrtRelease(p_);
269
- p_ = v.release();
270
- return *this;
271
- }
272
-
273
- constexpr operator contained_type*() const noexcept { return p_; }
274
-
275
- /// \brief Relinquishes ownership of the contained C object pointer
276
- /// The underlying object is not destroyed
277
- contained_type* release() {
278
- T* p = p_;
279
- p_ = nullptr;
280
- return p;
281
- }
282
-
283
- protected:
284
- contained_type* p_{};
285
- };
286
-
287
- // Undefined. For const types use Base<Unowned<const T>>
288
- template <typename T>
289
- struct Base<const T>;
290
-
291
- /// <summary>
292
- /// Covers unowned pointers owned by either the ORT
293
- /// or some other instance of CPP wrappers.
294
- /// Used for ConstXXX and UnownedXXXX types that are copyable.
295
- /// Also convenient to wrap raw OrtXX pointers .
296
- /// </summary>
297
- /// <typeparam name="T"></typeparam>
298
- template <typename T>
299
- struct Base<Unowned<T>> {
300
- using contained_type = typename Unowned<T>::Type;
301
-
302
- constexpr Base() = default;
303
- constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
304
-
305
- ~Base() = default;
306
-
307
- Base(const Base&) = default;
308
- Base& operator=(const Base&) = default;
309
-
310
- Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
311
- Base& operator=(Base&& v) noexcept {
312
- p_ = nullptr;
313
- std::swap(p_, v.p_);
314
- return *this;
315
- }
316
-
317
- constexpr operator contained_type*() const noexcept { return p_; }
318
-
319
- protected:
320
- contained_type* p_{};
321
- };
322
-
323
- // Light functor to release memory with OrtAllocator
324
- struct AllocatedFree {
325
- OrtAllocator* allocator_;
326
- explicit AllocatedFree(OrtAllocator* allocator)
327
- : allocator_(allocator) {}
328
- void operator()(void* ptr) const {
329
- if (ptr) allocator_->Free(allocator_, ptr);
330
- }
331
- };
332
-
333
- } // namespace detail
334
-
335
- struct AllocatorWithDefaultOptions;
336
- struct Env;
337
- struct TypeInfo;
338
- struct Value;
339
- struct ModelMetadata;
340
-
341
- /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
342
- * and release them at the end of the scope. The lifespan of the given allocator
343
- * must eclipse the lifespan of AllocatedStringPtr instance
344
- */
345
- using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
346
-
347
- /** \brief The Status that holds ownership of OrtStatus received from C API
348
- * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate
349
- * constructors to construct an instance of a Status object from exceptions.
350
- */
351
- struct Status : detail::Base<OrtStatus> {
352
- explicit Status(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used
353
- explicit Status(OrtStatus* status); ///< Takes ownership of OrtStatus instance returned from the C API. Must be non-null
354
- explicit Status(const Exception&); ///< Creates status instance out of exception
355
- explicit Status(const std::exception&); ///< Creates status instance out of exception
356
- std::string GetErrorMessage() const;
357
- OrtErrorCode GetErrorCode() const;
358
- };
359
-
360
- /** \brief The ThreadingOptions
361
- *
362
- * The ThreadingOptions are used to set the options of the global thread pools of the Env.
363
- */
364
- struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
365
- /// \brief Wraps OrtApi::CreateThreadingOptions
366
- ThreadingOptions();
367
-
368
- /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
369
- ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
370
-
371
- /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
372
- ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
373
-
374
- /// \brief Wraps OrtApi::SetGlobalSpinControl
375
- ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
376
-
377
- /// \brief Wraps OrtApi::SetGlobalDenormalAsZero
378
- ThreadingOptions& SetGlobalDenormalAsZero();
379
-
380
- /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
381
- ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
382
-
383
- /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
384
- ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
385
-
386
- /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
387
- ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
388
- };
389
-
390
- /** \brief The Env (Environment)
391
- *
392
- * The Env holds the logging state used by all other objects.
393
- * <b>Note:</b> One Env must be created before using any other Onnxruntime functionality
394
- */
395
- struct Env : detail::Base<OrtEnv> {
396
- explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used
397
-
398
- /// \brief Wraps OrtApi::CreateEnv
399
- Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
400
-
401
- /// \brief Wraps OrtApi::CreateEnvWithCustomLogger
402
- Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
403
-
404
- /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
405
- Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
406
-
407
- /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
408
- Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
409
- OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
410
-
411
- /// \brief C Interop Helper
412
- explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
413
-
414
- Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents
415
- Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents
416
-
417
- Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel
418
-
419
- Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
420
- };
421
-
422
- /** \brief Custom Op Domain
423
- *
424
- */
425
- struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
426
- explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used
427
-
428
- /// \brief Wraps OrtApi::CreateCustomOpDomain
429
- explicit CustomOpDomain(const char* domain);
430
-
431
- // This does not take ownership of the op, simply registers it.
432
- void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add
433
- };
434
-
435
- /** \brief RunOptions
436
- *
437
- */
438
- struct RunOptions : detail::Base<OrtRunOptions> {
439
- explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
440
- RunOptions(); ///< Wraps OrtApi::CreateRunOptions
441
-
442
- RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
443
- int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
444
-
445
- RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
446
- int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
447
-
448
- RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag
449
- const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
450
-
451
- RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
452
-
453
- /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
454
- *
455
- * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error
456
- * Wraps OrtApi::RunOptionsSetTerminate
457
- */
458
- RunOptions& SetTerminate();
459
-
460
- /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
461
- *
462
- * Wraps OrtApi::RunOptionsUnsetTerminate
463
- */
464
- RunOptions& UnsetTerminate();
465
- };
466
-
467
-
468
- namespace detail {
469
- // Utility function that returns a SessionOption config entry key for a specific custom operator.
470
- // Ex: custom_op.[custom_op_name].[config]
471
- std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config);
472
- } // namespace detail
473
-
474
- /// <summary>
475
- /// Class that represents session configuration entries for one or more custom operators.
476
- ///
477
- /// Example:
478
- /// Ort::CustomOpConfigs op_configs;
479
- /// op_configs.AddConfig("my_custom_op", "device_type", "CPU");
480
- ///
481
- /// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary.
482
- /// </summary>
483
- struct CustomOpConfigs {
484
- CustomOpConfigs() = default;
485
- ~CustomOpConfigs() = default;
486
- CustomOpConfigs(const CustomOpConfigs&) = default;
487
- CustomOpConfigs& operator=(const CustomOpConfigs&) = default;
488
- CustomOpConfigs(CustomOpConfigs&& o) = default;
489
- CustomOpConfigs& operator=(CustomOpConfigs&& o) = default;
490
-
491
- /** \brief Adds a session configuration entry/value for a specific custom operator.
492
- *
493
- * \param custom_op_name The name of the custom operator for which to add a configuration entry.
494
- * Must match the name returned by the CustomOp's GetName() method.
495
- * \param config_key The name of the configuration entry.
496
- * \param config_value The value of the configuration entry.
497
- * \return A reference to this object to enable call chaining.
498
- */
499
- CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
500
-
501
- /** \brief Returns a flattened map of custom operator configuration entries and their values.
502
- *
503
- * The keys have been flattened to include both the custom operator name and the configuration entry key name.
504
- * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair
505
- * {"my_op.key", "value"}.
506
- *
507
- * \return An unordered map of flattened configurations.
508
- */
509
- const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
510
-
511
- private:
512
- std::unordered_map<std::string, std::string> flat_configs_;
513
- };
514
-
515
- /** \brief Options object used when creating a new Session object
516
- *
517
- * Wraps ::OrtSessionOptions object and methods
518
- */
519
-
520
- struct SessionOptions;
521
-
522
- namespace detail {
523
- // we separate const-only methods because passing const ptr to non-const methods
524
- // is only discovered when inline methods are compiled which is counter-intuitive
525
- template <typename T>
526
- struct ConstSessionOptionsImpl : Base<T> {
527
- using B = Base<T>;
528
- using B::B;
529
-
530
- SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
531
-
532
- std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
533
- bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
534
- std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def);
535
- };
536
-
537
- template <typename T>
538
- struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
539
- using B = ConstSessionOptionsImpl<T>;
540
- using B::B;
541
-
542
- SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
543
- SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
544
- SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
545
-
546
- SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
547
- SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
548
-
549
- SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath
550
-
551
- SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling
552
- SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling
553
-
554
- SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps
555
-
556
- SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern
557
- SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern
558
-
559
- SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
560
-
561
- SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
562
- SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
563
-
564
- SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain
565
-
566
- SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
567
-
568
- SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
569
-
570
- SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
571
- SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
572
-
573
- SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
574
- SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
575
- SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
576
- SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
577
- SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
578
- SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
579
- SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
580
- ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
581
- SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
582
- /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports SNPE and XNNPACK.
583
- SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
584
- const std::unordered_map<std::string, std::string>& provider_options = {});
585
-
586
- SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
587
- SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
588
- SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
589
-
590
- ///< Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2.
591
- ///< The custom operator configurations are optional. If provided, custom operator configs are set via
592
- ///< OrtApi::AddSessionConfigEntry.
593
- SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
594
-
595
- SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction
596
- };
597
- } // namespace detail
598
-
599
- using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>;
600
- using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>;
601
-
602
- /** \brief Wrapper around ::OrtSessionOptions
603
- *
604
- */
605
- struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
606
- explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used
607
- SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions
608
- explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {} ///< Used for interop with the C API
609
- UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; }
610
- ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; }
611
- };
612
-
613
- /** \brief Wrapper around ::OrtModelMetadata
614
- *
615
- */
616
- struct ModelMetadata : detail::Base<OrtModelMetadata> {
617
- explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
618
- explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
619
-
620
- /** \brief Returns a copy of the producer name.
621
- *
622
- * \param allocator to allocate memory for the copy of the name returned
623
- * \return a instance of smart pointer that would deallocate the buffer when out of scope.
624
- * The OrtAllocator instances must be valid at the point of memory release.
625
- */
626
- AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
627
-
628
- /** \brief Returns a copy of the graph name.
629
- *
630
- * \param allocator to allocate memory for the copy of the name returned
631
- * \return a instance of smart pointer that would deallocate the buffer when out of scope.
632
- * The OrtAllocator instances must be valid at the point of memory release.
633
- */
634
- AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
635
-
636
- /** \brief Returns a copy of the domain name.
637
- *
638
- * \param allocator to allocate memory for the copy of the name returned
639
- * \return a instance of smart pointer that would deallocate the buffer when out of scope.
640
- * The OrtAllocator instances must be valid at the point of memory release.
641
- */
642
- AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
643
-
644
- /** \brief Returns a copy of the description.
645
- *
646
- * \param allocator to allocate memory for the copy of the string returned
647
- * \return a instance of smart pointer that would deallocate the buffer when out of scope.
648
- * The OrtAllocator instances must be valid at the point of memory release.
649
- */
650
- AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription
651
-
652
- /** \brief Returns a copy of the graph description.
653
- *
654
- * \param allocator to allocate memory for the copy of the string returned
655
- * \return a instance of smart pointer that would deallocate the buffer when out of scope.
656
- * The OrtAllocator instances must be valid at the point of memory release.
657
- */
658
- AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription
659
-
660
- /** \brief Returns a vector of copies of the custom metadata keys.
661
- *
662
- * \param allocator to allocate memory for the copy of the string returned
663
- * \return a instance std::vector of smart pointers that would deallocate the buffers when out of scope.
664
- * The OrtAllocator instance must be valid at the point of memory release.
665
- */
666
- std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
667
-
668
- /** \brief Looks up a value by a key in the Custom Metadata map
669
- *
670
- * \param key zero terminated string key to lookup
671
- * \param allocator to allocate memory for the copy of the string returned
672
- * \return a instance of smart pointer that would deallocate the buffer when out of scope.
673
- * maybe nullptr if key is not found.
674
- *
675
- * The OrtAllocator instances must be valid at the point of memory release.
676
- */
677
- AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
678
-
679
- int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
680
- };
681
-
682
- struct IoBinding;
683
-
684
- namespace detail {
685
-
686
- // we separate const-only methods because passing const ptr to non-const methods
687
- // is only discovered when inline methods are compiled which is counter-intuitive
688
- template <typename T>
689
- struct ConstSessionImpl : Base<T> {
690
- using B = Base<T>;
691
- using B::B;
692
-
693
- size_t GetInputCount() const; ///< Returns the number of model inputs
694
- size_t GetOutputCount() const; ///< Returns the number of model outputs
695
- size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden
696
-
697
- /** \brief Returns a copy of input name at the specified index.
698
- *
699
- * \param index must less than the value returned by GetInputCount()
700
- * \param allocator to allocate memory for the copy of the name returned
701
- * \return a instance of smart pointer that would deallocate the buffer when out of scope.
702
- * The OrtAllocator instances must be valid at the point of memory release.
703
- */
704
- AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;
705
-
706
- /** \brief Returns a copy of output name at then specified index.
707
- *
708
- * \param index must less than the value returned by GetOutputCount()
709
- * \param allocator to allocate memory for the copy of the name returned
710
- * \return a instance of smart pointer that would deallocate the buffer when out of scope.
711
- * The OrtAllocator instances must be valid at the point of memory release.
712
- */
713
- AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;
714
-
715
- /** \brief Returns a copy of the overridable initializer name at then specified index.
716
- *
717
- * \param index must less than the value returned by GetOverridableInitializerCount()
718
- * \param allocator to allocate memory for the copy of the name returned
719
- * \return a instance of smart pointer that would deallocate the buffer when out of scope.
720
- * The OrtAllocator instances must be valid at the point of memory release.
721
- */
722
- AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
723
-
724
- uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
725
- ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
726
-
727
- TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
728
- TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
729
- TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
730
- };
731
-
732
- template <typename T>
733
- struct SessionImpl : ConstSessionImpl<T> {
734
- using B = ConstSessionImpl<T>;
735
- using B::B;
736
-
737
- /** \brief Run the model returning results in an Ort allocated vector.
738
- *
739
- * Wraps OrtApi::Run
740
- *
741
- * The caller provides a list of inputs and a list of the desired outputs to return.
742
- *
743
- * See the output logs for more information on warnings/errors that occur while processing the model.
744
- * Common errors are.. (TODO)
745
- *
746
- * \param[in] run_options
747
- * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
748
- * \param[in] input_values Array of Value objects of length input_count that is the list of input values
749
- * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
750
- * \param[in] output_names Array of C style strings of length output_count that is the list of output names
751
- * \param[in] output_count Number of outputs (the size of the output_names array)
752
- * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector)
753
- */
754
- std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
755
- const char* const* output_names, size_t output_count);
756
-
757
- /** \brief Run the model returning results in user provided outputs
758
- * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t)
759
- */
760
- void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
761
- const char* const* output_names, Value* output_values, size_t output_count);
762
-
763
- void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
764
-
765
- /** \brief End profiling and return a copy of the profiling file name.
766
- *
767
- * \param allocator to allocate memory for the copy of the string returned
768
- * \return a instance of smart pointer that would deallocate the buffer when out of scope.
769
- * The OrtAllocator instances must be valid at the point of memory release.
770
- */
771
- AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling
772
- };
773
-
774
- } // namespace detail
775
-
776
- using ConstSession = detail::ConstSessionImpl<detail::Unowned<const OrtSession>>;
777
- using UnownedSession = detail::SessionImpl<detail::Unowned<OrtSession>>;
778
-
779
- /** \brief Wrapper around ::OrtSession
780
- *
781
- */
782
- struct Session : detail::SessionImpl<OrtSession> {
783
- explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
784
- Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
785
- Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
786
- OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
787
- Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
788
- Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
789
- OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
790
-
791
- ConstSession GetConst() const { return ConstSession{this->p_}; }
792
- UnownedSession GetUnowned() const { return UnownedSession{this->p_}; }
793
- };
794
-
795
- namespace detail {
796
- template <typename T>
797
- struct MemoryInfoImpl : Base<T> {
798
- using B = Base<T>;
799
- using B::B;
800
-
801
- std::string GetAllocatorName() const;
802
- OrtAllocatorType GetAllocatorType() const;
803
- int GetDeviceId() const;
804
- OrtMemoryInfoDeviceType GetDeviceType() const;
805
- OrtMemType GetMemoryType() const;
806
-
807
- template <typename U>
808
- bool operator==(const MemoryInfoImpl<U>& o) const;
809
- };
810
- } // namespace detail
811
-
812
- // Const object holder that does not own the underlying object
813
- using ConstMemoryInfo = detail::MemoryInfoImpl<detail::Unowned<const OrtMemoryInfo>>;
814
-
815
- /** \brief Wrapper around ::OrtMemoryInfo
816
- *
817
- */
818
- struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
819
- static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
820
- explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created
821
- explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {} ///< Take ownership of a pointer created by C Api
822
- MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
823
- ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; }
824
- };
825
-
826
- namespace detail {
827
- template <typename T>
828
- struct TensorTypeAndShapeInfoImpl : Base<T> {
829
- using B = Base<T>;
830
- using B::B;
831
-
832
- ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
833
- size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount
834
-
835
- size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount
836
-
837
- /** \deprecated use GetShape() returning std::vector
838
- * [[deprecated]]
839
- * This interface is unsafe to use
840
- */
841
- [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions
842
-
843
- void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
844
-
845
- std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
846
- };
847
-
848
- } // namespace detail
849
-
850
- using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl<detail::Unowned<const OrtTensorTypeAndShapeInfo>>;
851
-
852
- /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
853
- *
854
- */
855
- struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
856
- explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
857
- explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
858
- ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; }
859
- };
860
-
861
- namespace detail {
862
- template <typename T>
863
- struct SequenceTypeInfoImpl : Base<T> {
864
- using B = Base<T>;
865
- using B::B;
866
- TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType
867
- };
868
-
869
- } // namespace detail
870
-
871
- using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl<detail::Unowned<const OrtSequenceTypeInfo>>;
872
-
873
- /** \brief Wrapper around ::OrtSequenceTypeInfo
874
- *
875
- */
876
- struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
877
- explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
878
- explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
879
- ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; }
880
- };
881
-
882
- namespace detail {
883
- template <typename T>
884
- struct MapTypeInfoImpl : detail::Base<T> {
885
- using B = Base<T>;
886
- using B::B;
887
- ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType
888
- TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType
889
- };
890
-
891
- } // namespace detail
892
-
893
- using ConstMapTypeInfo = detail::MapTypeInfoImpl<detail::Unowned<const OrtMapTypeInfo>>;
894
-
895
- /** \brief Wrapper around ::OrtMapTypeInfo
896
- *
897
- */
898
- struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
899
- explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
900
- explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
901
- ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
902
- };
903
-
904
- namespace detail {
905
- template <typename T>
906
- struct TypeInfoImpl : detail::Base<T> {
907
- using B = Base<T>;
908
- using B::B;
909
-
910
- ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo
911
- ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
912
- ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
913
-
914
- ONNXType GetONNXType() const;
915
- };
916
- } // namespace detail
917
-
918
- /// <summary>
919
- /// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value.
920
- /// Provides access to const OrtTypeInfo APIs.
921
- /// </summary>
922
- using ConstTypeInfo = detail::TypeInfoImpl<detail::Unowned<const OrtTypeInfo>>;
923
-
924
- /// <summary>
925
- /// Type information that may contain either TensorTypeAndShapeInfo or
926
- /// the information about contained sequence or map depending on the ONNXType.
927
- /// </summary>
928
- struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
929
- explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
930
- explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop
931
-
932
- ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; }
933
- };
934
-
935
- namespace detail {
936
- // This structure is used to feed sparse tensor values
937
- // information for use with FillSparseTensor<Format>() API
938
- // if the data type for the sparse tensor values is numeric
939
- // use data.p_data, otherwise, use data.str pointer to feed
940
- // values. data.str is an array of const char* that are zero terminated.
941
- // number of strings in the array must match shape size.
942
- // For fully sparse tensors use shape {0} and set p_data/str
943
- // to nullptr.
944
- struct OrtSparseValuesParam {
945
- const int64_t* values_shape;
946
- size_t values_shape_len;
947
- union {
948
- const void* p_data;
949
- const char** str;
950
- } data;
951
- };
952
-
953
- // Provides a way to pass shape in a single
954
- // argument
955
- struct Shape {
956
- const int64_t* shape;
957
- size_t shape_len;
958
- };
959
-
960
- template <typename T>
961
- struct ConstValueImpl : Base<T> {
962
- using B = Base<T>;
963
- using B::B;
964
-
965
- /// <summary>
966
- /// Obtains a pointer to a user defined data for experimental purposes
967
- /// </summary>
968
- template <typename R>
969
- void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue
970
-
971
- bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
972
- bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None
973
-
974
- size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
975
- Value GetValue(int index, OrtAllocator* allocator) const;
976
-
977
- /// <summary>
978
- /// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
979
- /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
980
- /// for allocating necessary memory and calling GetStringTensorContent().
981
- /// </summary>
982
- /// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
983
- size_t GetStringTensorDataLength() const;
984
-
985
- /// <summary>
986
- /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
987
- /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
988
- /// The user must also allocate offsets buffer with the number of entries equal to that of the contained
989
- /// strings.
990
- ///
991
- /// Strings are always assumed to be on CPU, no X-device copy.
992
- /// </summary>
993
- /// <param name="buffer">user allocated buffer</param>
994
- /// <param name="buffer_length">length in bytes of the allocated buffer</param>
995
- /// <param name="offsets">a pointer to the offsets user allocated buffer</param>
996
- /// <param name="offsets_count">count of offsets, must be equal to the number of strings contained.
997
- /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
998
- /// for sparse tensors</param>
999
- void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
1000
-
1001
- /// <summary>
1002
- /// Returns a const typed pointer to the tensor contained data.
1003
- /// No type checking is performed, the caller must ensure the type matches the tensor type.
1004
- /// </summary>
1005
- /// <typeparam name="T"></typeparam>
1006
- /// <returns>const pointer to data, no copies made</returns>
1007
- template <typename R>
1008
- const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// <summary>
1009
-
1010
- /// <summary>
1011
- /// Returns a non-typed pointer to a tensor contained data.
1012
- /// </summary>
1013
- /// <returns>const pointer to data, no copies made</returns>
1014
- const void* GetTensorRawData() const;
1015
-
1016
- /// <summary>
1017
- /// The API returns type information for data contained in a tensor. For sparse
1018
- /// tensors it returns type information for contained non-zero values.
1019
- /// It returns dense shape for sparse tensors.
1020
- /// </summary>
1021
- /// <returns>TypeInfo</returns>
1022
- TypeInfo GetTypeInfo() const;
1023
-
1024
- /// <summary>
1025
- /// The API returns type information for data contained in a tensor. For sparse
1026
- /// tensors it returns type information for contained non-zero values.
1027
- /// It returns dense shape for sparse tensors.
1028
- /// </summary>
1029
- /// <returns>TensorTypeAndShapeInfo</returns>
1030
- TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
1031
-
1032
- /// <summary>
1033
- /// This API returns information about the memory allocation used to hold data.
1034
- /// </summary>
1035
- /// <returns>Non owning instance of MemoryInfo</returns>
1036
- ConstMemoryInfo GetTensorMemoryInfo() const;
1037
-
1038
- /// <summary>
1039
- /// The API copies UTF-8 encoded bytes for the requested string element
1040
- /// contained within a tensor or a sparse tensor into a provided buffer.
1041
- /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
1042
- /// </summary>
1043
- /// <param name="buffer_length"></param>
1044
- /// <param name="element_index"></param>
1045
- /// <param name="buffer"></param>
1046
- void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
1047
-
1048
- /// <summary>
1049
- /// The API returns a byte length of UTF-8 encoded string element
1050
- /// contained in either a tensor or a spare tensor values.
1051
- /// </summary>
1052
- /// <param name="element_index"></param>
1053
- /// <returns>byte length for the specified string element</returns>
1054
- size_t GetStringTensorElementLength(size_t element_index) const;
1055
-
1056
- #if !defined(DISABLE_SPARSE_TENSORS)
1057
- /// <summary>
1058
- /// The API returns the sparse data format this OrtValue holds in a sparse tensor.
1059
- /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
1060
- /// the value returned is ORT_SPARSE_UNDEFINED.
1061
- /// </summary>
1062
- /// <returns>Format enum</returns>
1063
- OrtSparseFormat GetSparseFormat() const;
1064
-
1065
- /// <summary>
1066
- /// The API returns type and shape information for stored non-zero values of the
1067
- /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
1068
- /// </summary>
1069
- /// <returns>TensorTypeAndShapeInfo values information</returns>
1070
- TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
1071
-
1072
- /// <summary>
1073
- /// The API returns type and shape information for the specified indices. Each supported
1074
- /// indices have their own enum values even if a give format has more than one kind of indices.
1075
- /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
1076
- /// </summary>
1077
- /// <param name="format">enum requested</param>
1078
- /// <returns>type and shape information</returns>
1079
- TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const;
1080
-
1081
- /// <summary>
1082
- /// The API retrieves a pointer to the internal indices buffer. The API merely performs
1083
- /// a convenience data type casting on the return type pointer. Make sure you are requesting
1084
- /// the right type, use GetSparseTensorIndicesTypeShapeInfo();
1085
- /// </summary>
1086
- /// <typeparam name="T">type to cast to</typeparam>
1087
- /// <param name="indices_format">requested indices kind</param>
1088
- /// <param name="num_indices">number of indices entries</param>
1089
- /// <returns>Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
1090
- template <typename R>
1091
- const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
1092
-
1093
- /// <summary>
1094
- /// Returns true if the OrtValue contains a sparse tensor
1095
- /// </summary>
1096
- /// <returns></returns>
1097
- bool IsSparseTensor() const;
1098
-
1099
- /// <summary>
1100
- /// The API returns a pointer to an internal buffer of the sparse tensor
1101
- /// containing non-zero values. The API merely does casting. Make sure you
1102
- /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
1103
- /// first.
1104
- /// </summary>
1105
- /// <typeparam name="T">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
1106
- /// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
1107
- template <typename R>
1108
- const R* GetSparseTensorValues() const;
1109
-
1110
- #endif
1111
- };
1112
-
1113
- template <typename T>
1114
- struct ValueImpl : ConstValueImpl<T> {
1115
- using B = ConstValueImpl<T>;
1116
- using B::B;
1117
-
1118
- /// <summary>
1119
- /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer
1120
- /// No type checking is performed, the caller must ensure the type matches the tensor type.
1121
- /// </summary>
1122
- /// <returns>non-const pointer to data, no copies made</returns>
1123
- template <typename R>
1124
- R* GetTensorMutableData();
1125
-
1126
- /// <summary>
1127
- /// Returns a non-typed non-const pointer to a tensor contained data.
1128
- /// </summary>
1129
- /// <returns>pointer to data, no copies made</returns>
1130
- void* GetTensorMutableRawData();
1131
-
1132
- /// <summary>
1133
- // Obtain a reference to an element of data at the location specified
1134
- /// by the vector of dims.
1135
- /// </summary>
1136
- /// <typeparam name="R"></typeparam>
1137
- /// <param name="location">[in] expressed by a vecotr of dimensions offsets</param>
1138
- /// <returns></returns>
1139
- template <typename R>
1140
- R& At(const std::vector<int64_t>& location);
1141
-
1142
- /// <summary>
1143
- /// Set all strings at once in a string tensor
1144
- /// </summary>
1145
- /// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param>
1146
- /// <param name="s_len">[in] Count of strings in s (Must match the size of \p value's tensor shape)</param>
1147
- void FillStringTensor(const char* const* s, size_t s_len);
1148
-
1149
- /// <summary>
1150
- /// Set a single string in a string tensor
1151
- /// </summary>
1152
- /// <param name="s">[in] A null terminated UTF-8 encoded string</param>
1153
- /// <param name="index">[in] Index of the string in the tensor to set</param>
1154
- void FillStringTensorElement(const char* s, size_t index);
1155
-
1156
- #if !defined(DISABLE_SPARSE_TENSORS)
1157
- /// <summary>
1158
- /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
1159
- /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1160
- /// allocated buffers lifespan must eclipse that of the OrtValue.
1161
- /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1162
- /// </summary>
1163
- /// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
1164
- /// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
1165
- void UseCooIndices(int64_t* indices_data, size_t indices_num);
1166
-
1167
- /// <summary>
1168
- /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
1169
- /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1170
- /// allocated buffers lifespan must eclipse that of the OrtValue.
1171
- /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1172
- /// </summary>
1173
- /// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
1174
- /// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
1175
- /// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
1176
- /// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
1177
- void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
1178
-
1179
- /// <summary>
1180
- /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
1181
- /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1182
- /// allocated buffers lifespan must eclipse that of the OrtValue.
1183
- /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1184
- /// </summary>
1185
- /// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
1186
- /// <param name="indices_data">user allocated buffer with indices or nullptr for fully sparse tensors</param>
1187
- void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
1188
-
1189
- /// <summary>
1190
- /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1191
- /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
1192
- /// at a different device than the allocator, a cross-device copy will be performed if possible.
1193
- /// </summary>
1194
- /// <param name="data_mem_info">specified buffer memory description</param>
1195
- /// <param name="values_param">values buffer information.</param>
1196
- /// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
1197
- /// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
1198
- void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
1199
- const int64_t* indices_data, size_t indices_num);
1200
-
1201
- /// <summary>
1202
- /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1203
- /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
1204
- /// at a different device than the allocator, a cross-device copy will be performed if possible.
1205
- /// </summary>
1206
- /// <param name="data_mem_info">specified buffer memory description</param>
1207
- /// <param name="values">values buffer information</param>
1208
- /// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
1209
- /// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
1210
- /// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
1211
- /// <param name="outer_indices_num">number of csr outer indices or 0</param>
1212
- void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1213
- const OrtSparseValuesParam& values,
1214
- const int64_t* inner_indices_data, size_t inner_indices_num,
1215
- const int64_t* outer_indices_data, size_t outer_indices_num);
1216
-
1217
- /// <summary>
1218
- /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1219
- /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
1220
- /// at a different device than the allocator, a cross-device copy will be performed if possible.
1221
- /// </summary>
1222
- /// <param name="data_mem_info">specified buffer memory description</param>
1223
- /// <param name="values">values buffer information</param>
1224
- /// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
1225
- /// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
1226
- void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1227
- const OrtSparseValuesParam& values,
1228
- const Shape& indices_shape,
1229
- const int32_t* indices_data);
1230
-
1231
- #endif
1232
- };
1233
-
1234
- } // namespace detail
1235
-
1236
- using ConstValue = detail::ConstValueImpl<detail::Unowned<const OrtValue>>;
1237
- using UnownedValue = detail::ValueImpl<detail::Unowned<OrtValue>>;
1238
-
1239
- /** \brief Wrapper around ::OrtValue
1240
- *
1241
- */
1242
- struct Value : detail::ValueImpl<OrtValue> {
1243
- using Base = detail::ValueImpl<OrtValue>;
1244
- using OrtSparseValuesParam = detail::OrtSparseValuesParam;
1245
- using Shape = detail::Shape;
1246
-
1247
- explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
1248
- explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
1249
- Value(Value&&) = default;
1250
- Value& operator=(Value&&) = default;
1251
-
1252
- ConstValue GetConst() const { return ConstValue{this->p_}; }
1253
- UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
1254
-
1255
- /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1256
- * \tparam T The numeric datatype. This API is not suitable for strings.
1257
- * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1258
- * \param p_data Pointer to the data buffer.
1259
- * \param p_data_element_count The number of elements in the data buffer.
1260
- * \param shape Pointer to the tensor shape dimensions.
1261
- * \param shape_len The number of tensor shape dimensions.
1262
- */
1263
- template <typename T>
1264
- static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
1265
-
1266
- /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1267
- * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1268
- * \param p_data Pointer to the data buffer.
1269
- * \param p_data_byte_count The number of bytes in the data buffer.
1270
- * \param shape Pointer to the tensor shape dimensions.
1271
- * \param shape_len The number of tensor shape dimensions.
1272
- * \param type The data type.
1273
- */
1274
- static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1275
- ONNXTensorElementDataType type);
1276
-
1277
- /** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
1278
- * \tparam T The numeric datatype. This API is not suitable for strings.
1279
- * \param allocator The allocator to use.
1280
- * \param shape Pointer to the tensor shape dimensions.
1281
- * \param shape_len The number of tensor shape dimensions.
1282
- */
1283
- template <typename T>
1284
- static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
1285
-
1286
- /** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
1287
- * \param allocator The allocator to use.
1288
- * \param shape Pointer to the tensor shape dimensions.
1289
- * \param shape_len The number of tensor shape dimensions.
1290
- * \param type The data type.
1291
- */
1292
- static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
1293
-
1294
- static Value CreateMap(Value& keys, Value& values); ///< Wraps OrtApi::CreateValue
1295
- static Value CreateSequence(std::vector<Value>& values); ///< Wraps OrtApi::CreateValue
1296
-
1297
- template <typename T>
1298
- static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue
1299
-
1300
- #if !defined(DISABLE_SPARSE_TENSORS)
1301
- /// <summary>
1302
- /// This is a simple forwarding method to the other overload that helps deducing
1303
- /// data type enum value from the type of the buffer.
1304
- /// </summary>
1305
- /// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
1306
- /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1307
- /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1308
- /// <param name="dense_shape">a would be dense shape of the tensor</param>
1309
- /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1310
- /// <returns></returns>
1311
- template <typename T>
1312
- static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1313
- const Shape& values_shape);
1314
-
1315
- /// <summary>
1316
- /// Creates an OrtValue instance containing SparseTensor. This constructs
1317
- /// a sparse tensor that makes use of user allocated buffers. It does not make copies
1318
- /// of the user provided data and does not modify it. The lifespan of user provided buffers should
1319
- /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contains
1320
- /// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
1321
- /// to supply a sparse format specific indices.
1322
- /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
1323
- /// can be properly copied into the allocated buffer.
1324
- /// </summary>
1325
- /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1326
- /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1327
- /// <param name="dense_shape">a would be dense shape of the tensor</param>
1328
- /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1329
- /// <param name="type">data type</param>
1330
- /// <returns>Ort::Value instance containing SparseTensor</returns>
1331
- static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1332
- const Shape& values_shape, ONNXTensorElementDataType type);
1333
-
1334
- /// <summary>
1335
- /// This is a simple forwarding method to the below CreateSparseTensor.
1336
- /// This helps to specify data type enum in terms of C++ data type.
1337
- /// Use CreateSparseTensor<T>
1338
- /// </summary>
1339
- /// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
1340
- /// <param name="allocator">allocator to use</param>
1341
- /// <param name="dense_shape">a would be dense shape of the tensor</param>
1342
- /// <returns>Ort::Value</returns>
1343
- template <typename T>
1344
- static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
1345
-
1346
- /// <summary>
1347
- /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
1348
- /// The data must be supplied by one of the FillSparseTensor<Format>() methods that take both non-zero values
1349
- /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
1350
- /// Use this API to create OrtValues that contain sparse tensors with all supported data types including
1351
- /// strings.
1352
- /// </summary>
1353
- /// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
1354
- /// <param name="dense_shape">a would be dense shape of the tensor</param>
1355
- /// <param name="type">data type</param>
1356
- /// <returns>an instance of Ort::Value</returns>
1357
- static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
1358
-
1359
- #endif // !defined(DISABLE_SPARSE_TENSORS)
1360
- };
1361
-
1362
- /// <summary>
1363
- /// Represents native memory allocation coming from one of the
1364
- /// OrtAllocators registered with OnnxRuntime.
1365
- /// Use it to wrap an allocation made by an allocator
1366
- /// so it can be automatically released when no longer needed.
1367
- /// </summary>
1368
- struct MemoryAllocation {
1369
- MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1370
- ~MemoryAllocation();
1371
- MemoryAllocation(const MemoryAllocation&) = delete;
1372
- MemoryAllocation& operator=(const MemoryAllocation&) = delete;
1373
- MemoryAllocation(MemoryAllocation&&) noexcept;
1374
- MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
1375
-
1376
- void* get() { return p_; }
1377
- size_t size() const { return size_; }
1378
-
1379
- private:
1380
- OrtAllocator* allocator_;
1381
- void* p_;
1382
- size_t size_;
1383
- };
1384
-
1385
- namespace detail {
1386
- template <typename T>
1387
- struct AllocatorImpl : Base<T> {
1388
- using B = Base<T>;
1389
- using B::B;
1390
-
1391
- void* Alloc(size_t size);
1392
- MemoryAllocation GetAllocation(size_t size);
1393
- void Free(void* p);
1394
- ConstMemoryInfo GetInfo() const;
1395
- };
1396
-
1397
- } // namespace detail
1398
-
1399
- /** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime
1400
- *
1401
- */
1402
- struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
1403
- explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1404
- AllocatorWithDefaultOptions();
1405
- };
1406
-
1407
- /** \brief Wrapper around ::OrtAllocator
1408
- *
1409
- */
1410
- struct Allocator : detail::AllocatorImpl<OrtAllocator> {
1411
- explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1412
- Allocator(const Session& session, const OrtMemoryInfo*);
1413
- };
1414
-
1415
- using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
1416
-
1417
- namespace detail {
1418
- namespace binding_utils {
1419
- // Bring these out of template
1420
- std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
1421
- std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
1422
- } // namespace binding_utils
1423
-
1424
- template <typename T>
1425
- struct ConstIoBindingImpl : Base<T> {
1426
- using B = Base<T>;
1427
- using B::B;
1428
-
1429
- std::vector<std::string> GetOutputNames() const;
1430
- std::vector<std::string> GetOutputNames(OrtAllocator*) const;
1431
- std::vector<Value> GetOutputValues() const;
1432
- std::vector<Value> GetOutputValues(OrtAllocator*) const;
1433
- };
1434
-
1435
- template <typename T>
1436
- struct IoBindingImpl : ConstIoBindingImpl<T> {
1437
- using B = ConstIoBindingImpl<T>;
1438
- using B::B;
1439
-
1440
- void BindInput(const char* name, const Value&);
1441
- void BindOutput(const char* name, const Value&);
1442
- void BindOutput(const char* name, const OrtMemoryInfo*);
1443
- void ClearBoundInputs();
1444
- void ClearBoundOutputs();
1445
- void SynchronizeInputs();
1446
- void SynchronizeOutputs();
1447
- };
1448
-
1449
- } // namespace detail
1450
-
1451
- using ConstIoBinding = detail::ConstIoBindingImpl<detail::Unowned<const OrtIoBinding>>;
1452
- using UnownedIoBinding = detail::IoBindingImpl<detail::Unowned<OrtIoBinding>>;
1453
-
1454
- /** \brief Wrapper around ::OrtIoBinding
1455
- *
1456
- */
1457
- struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
1458
- explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later.
1459
- explicit IoBinding(Session& session);
1460
- ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
1461
- UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
1462
- };
1463
-
1464
- /*! \struct Ort::ArenaCfg
1465
- * \brief it is a structure that represents the configuration of an arena based allocator
1466
- * \details Please see docs/C_API.md for details
1467
- */
1468
- struct ArenaCfg : detail::Base<OrtArenaCfg> {
1469
- explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
1470
- /**
1471
- * Wraps OrtApi::CreateArenaCfg
1472
- * \param max_mem - use 0 to allow ORT to choose the default
1473
- * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
1474
- * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
1475
- * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
1476
- * See docs/C_API.md for details on what the following parameters mean and how to choose these values
1477
- */
1478
- ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
1479
- };
1480
-
1481
- //
1482
- // Custom OPs (only needed to implement custom OPs)
1483
- //
1484
-
1485
- /// <summary>
1486
- /// This struct provides life time management for custom op attribute
1487
- /// </summary>
1488
- struct OpAttr : detail::Base<OrtOpAttr> {
1489
- OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
1490
- };
1491
-
1492
- /// <summary>
1493
- /// This class wraps a raw pointer OrtKernelContext* that is being passed
1494
- /// to the custom kernel Compute() method. Use it to safely access context
1495
- /// attributes, input and output parameters with exception safety guarantees.
1496
- /// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
1497
- /// </summary>
1498
- struct KernelContext {
1499
- explicit KernelContext(OrtKernelContext* context);
1500
- size_t GetInputCount() const;
1501
- size_t GetOutputCount() const;
1502
- ConstValue GetInput(size_t index) const;
1503
- UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
1504
- UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
1505
- void* GetGPUComputeStream() const;
1506
-
1507
- private:
1508
- OrtKernelContext* ctx_;
1509
- };
1510
-
1511
- struct KernelInfo;
1512
-
1513
- namespace detail {
1514
- namespace attr_utils {
1515
- void GetAttr(const OrtKernelInfo* p, const char* name, float&);
1516
- void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
1517
- void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
1518
- void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
1519
- void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
1520
- } // namespace attr_utils
1521
-
1522
- template <typename T>
1523
- struct KernelInfoImpl : Base<T> {
1524
- using B = Base<T>;
1525
- using B::B;
1526
-
1527
- KernelInfo Copy() const;
1528
-
1529
- template <typename R> // R is only implemented for float, int64_t, and string
1530
- R GetAttribute(const char* name) const {
1531
- R val;
1532
- attr_utils::GetAttr(this->p_, name, val);
1533
- return val;
1534
- }
1535
-
1536
- template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
1537
- std::vector<R> GetAttributes(const char* name) const {
1538
- std::vector<R> result;
1539
- attr_utils::GetAttrs(this->p_, name, result);
1540
- return result;
1541
- }
1542
-
1543
- Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
1544
-
1545
- size_t GetInputCount() const;
1546
- size_t GetOutputCount() const;
1547
-
1548
- std::string GetInputName(size_t index) const;
1549
- std::string GetOutputName(size_t index) const;
1550
-
1551
- TypeInfo GetInputTypeInfo(size_t index) const;
1552
- TypeInfo GetOutputTypeInfo(size_t index) const;
1553
- };
1554
-
1555
- } // namespace detail
1556
-
1557
- using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>;
1558
-
1559
- /// <summary>
1560
- /// This struct owns the OrtKernelInfo* pointer when a copy is made.
1561
- /// For convenient wrapping of OrtKernelInfo* passed to kernel constructor
1562
- /// and query attributes, wrap the pointer with an Ort::Unowned<KernelInfo> instance
1563
- /// so it does not destroy the pointer the kernel does not own.
1564
- /// </summary>
1565
- struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
1566
- explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
1567
- explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
1568
- ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
1569
- };
1570
-
1571
- /// <summary>
1572
- /// Create and own custom defined operation.
1573
- /// </summary>
1574
- struct Op : detail::Base<OrtOp> {
1575
- explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used
1576
-
1577
- explicit Op(OrtOp*); ///< Take ownership of the OrtOp
1578
-
1579
- static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
1580
- int version, const char** type_constraint_names,
1581
- const ONNXTensorElementDataType* type_constraint_values,
1582
- size_t type_constraint_count,
1583
- const OpAttr* attr_values,
1584
- size_t attr_count,
1585
- size_t input_count, size_t output_count);
1586
-
1587
- void Invoke(const OrtKernelContext* context,
1588
- const Value* input_values,
1589
- size_t input_count,
1590
- Value* output_values,
1591
- size_t output_count);
1592
-
1593
- // For easier refactoring
1594
- void Invoke(const OrtKernelContext* context,
1595
- const OrtValue* const* input_values,
1596
- size_t input_count,
1597
- OrtValue* const* output_values,
1598
- size_t output_count);
1599
- };
1600
-
1601
- /// <summary>
1602
- /// This entire structure is deprecated, but we are not marking
1603
- /// it as a whole yet since we want to preserve it for the next release.
1604
- /// </summary>
1605
- struct CustomOpApi {
1606
- CustomOpApi(const OrtApi& api) : api_(api) {}
1607
-
1608
- /** \deprecated use Ort::Value::GetTensorTypeAndShape()
1609
- * [[deprecated]]
1610
- * This interface produces a pointer that must be released. Not exception safe.
1611
- */
1612
- [[deprecated("use Ort::Value::GetTensorTypeAndShape()")]] OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
1613
-
1614
- /** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementCount()
1615
- * [[deprecated]]
1616
- * This interface is redundant.
1617
- */
1618
- [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementCount()")]] size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info);
1619
-
1620
- /** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementType()
1621
- * [[deprecated]]
1622
- * This interface is redundant.
1623
- */
1624
- [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementType()")]] ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info);
1625
-
1626
- /** \deprecated use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()
1627
- * [[deprecated]]
1628
- * This interface is redundant.
1629
- */
1630
- [[deprecated("use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()")]] size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info);
1631
-
1632
- /** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape()
1633
- * [[deprecated]]
1634
- * This interface is redundant.
1635
- */
1636
- [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
1637
-
1638
- /** \deprecated
1639
- * [[deprecated]]
1640
- * This interface sets dimensions to TensorTypeAndShapeInfo, but has no effect on the OrtValue.
1641
- */
1642
- [[deprecated("Do not use")]] void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
1643
-
1644
- /** \deprecated use Ort::Value::GetTensorMutableData()
1645
- * [[deprecated]]
1646
- * This interface is redundant.
1647
- */
1648
- template <typename T>
1649
- [[deprecated("use Ort::Value::GetTensorMutableData()")]] T* GetTensorMutableData(_Inout_ OrtValue* value);
1650
-
1651
- /** \deprecated use Ort::Value::GetTensorData()
1652
- * [[deprecated]]
1653
- * This interface is redundant.
1654
- */
1655
- template <typename T>
1656
- [[deprecated("use Ort::Value::GetTensorData()")]] const T* GetTensorData(_Inout_ const OrtValue* value);
1657
-
1658
- /** \deprecated use Ort::Value::GetTensorMemoryInfo()
1659
- * [[deprecated]]
1660
- * This interface is redundant.
1661
- */
1662
- [[deprecated("use Ort::Value::GetTensorMemoryInfo()")]] const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value);
1663
-
1664
- /** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape()
1665
- * [[deprecated]]
1666
- * This interface is redundant.
1667
- */
1668
- [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
1669
-
1670
- /** \deprecated use TensorTypeAndShapeInfo instances for automatic ownership.
1671
- * [[deprecated]]
1672
- * This interface is not exception safe.
1673
- */
1674
- [[deprecated("use TensorTypeAndShapeInfo")]] void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
1675
-
1676
- /** \deprecated use Ort::KernelContext::GetInputCount
1677
- * [[deprecated]]
1678
- * This interface is redundant.
1679
- */
1680
- [[deprecated("use Ort::KernelContext::GetInputCount")]] size_t KernelContext_GetInputCount(const OrtKernelContext* context);
1681
-
1682
- /** \deprecated use Ort::KernelContext::GetInput
1683
- * [[deprecated]]
1684
- * This interface is redundant.
1685
- */
1686
- [[deprecated("use Ort::KernelContext::GetInput")]] const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index);
1687
-
1688
- /** \deprecated use Ort::KernelContext::GetOutputCount
1689
- * [[deprecated]]
1690
- * This interface is redundant.
1691
- */
1692
- [[deprecated("use Ort::KernelContext::GetOutputCount")]] size_t KernelContext_GetOutputCount(const OrtKernelContext* context);
1693
-
1694
- /** \deprecated use Ort::KernelContext::GetOutput
1695
- * [[deprecated]]
1696
- * This interface is redundant.
1697
- */
1698
- [[deprecated("use Ort::KernelContext::GetOutput")]] OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
1699
-
1700
- /** \deprecated use Ort::KernelContext::GetGPUComputeStream
1701
- * [[deprecated]]
1702
- * This interface is redundant.
1703
- */
1704
- [[deprecated("use Ort::KernelContext::GetGPUComputeStream")]] void* KernelContext_GetGPUComputeStream(const OrtKernelContext* context);
1705
-
1706
- /** \deprecated use Ort::ThrowOnError()
1707
- * [[deprecated]]
1708
- * This interface is redundant.
1709
- */
1710
- [[deprecated("use Ort::ThrowOnError()")]] void ThrowOnError(OrtStatus* result);
1711
-
1712
- /** \deprecated use Ort::OpAttr
1713
- * [[deprecated]]
1714
- * This interface is not exception safe.
1715
- */
1716
- [[deprecated("use Ort::OpAttr")]] OrtOpAttr* CreateOpAttr(_In_ const char* name,
1717
- _In_ const void* data,
1718
- _In_ int len,
1719
- _In_ OrtOpAttrType type);
1720
-
1721
- /** \deprecated use Ort::OpAttr
1722
- * [[deprecated]]
1723
- * This interface is not exception safe.
1724
- */
1725
- [[deprecated("use Ort::OpAttr")]] void ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr);
1726
-
1727
- /** \deprecated use Ort::Op
1728
- * [[deprecated]]
1729
- * This interface is not exception safe.
1730
- */
1731
- [[deprecated("use Ort::Op")]] OrtOp* CreateOp(_In_ const OrtKernelInfo* info,
1732
- _In_ const char* op_name,
1733
- _In_ const char* domain,
1734
- _In_ int version,
1735
- _In_opt_ const char** type_constraint_names,
1736
- _In_opt_ const ONNXTensorElementDataType* type_constraint_values,
1737
- _In_opt_ int type_constraint_count,
1738
- _In_opt_ const OrtOpAttr* const* attr_values,
1739
- _In_opt_ int attr_count,
1740
- _In_ int input_count,
1741
- _In_ int output_count);
1742
-
1743
- /** \deprecated use Ort::Op::Invoke
1744
- * [[deprecated]]
1745
- * This interface is redundant
1746
- */
1747
- [[deprecated("use Ort::Op::Invoke")]] void InvokeOp(_In_ const OrtKernelContext* context,
1748
- _In_ const OrtOp* ort_op,
1749
- _In_ const OrtValue* const* input_values,
1750
- _In_ int input_count,
1751
- _Inout_ OrtValue* const* output_values,
1752
- _In_ int output_count);
1753
-
1754
- /** \deprecated use Ort::Op for automatic lifespan management.
1755
- * [[deprecated]]
1756
- * This interface is not exception safe.
1757
- */
1758
- [[deprecated("use Ort::Op")]] void ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op);
1759
-
1760
- /** \deprecated use Ort::KernelInfo for automatic lifespan management or for
1761
- * querying attributes
1762
- * [[deprecated]]
1763
- * This interface is redundant
1764
- */
1765
- template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
1766
- [[deprecated("use Ort::KernelInfo::GetAttribute")]] T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
1767
-
1768
- /** \deprecated use Ort::KernelInfo::Copy
1769
- * querying attributes
1770
- * [[deprecated]]
1771
- * This interface is not exception safe
1772
- */
1773
- [[deprecated("use Ort::KernelInfo::Copy")]] OrtKernelInfo* CopyKernelInfo(_In_ const OrtKernelInfo* info);
1774
-
1775
- /** \deprecated use Ort::KernelInfo for lifespan management
1776
- * querying attributes
1777
- * [[deprecated]]
1778
- * This interface is not exception safe
1779
- */
1780
- [[deprecated("use Ort::KernelInfo")]] void ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy);
1781
-
1782
- private:
1783
- const OrtApi& api_;
1784
- };
1785
-
1786
- template <typename TOp, typename TKernel>
1787
- struct CustomOpBase : OrtCustomOp {
1788
- CustomOpBase() {
1789
- OrtCustomOp::version = ORT_API_VERSION;
1790
- OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
1791
- OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
1792
-
1793
- OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
1794
-
1795
- OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
1796
- OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
1797
- OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
1798
-
1799
- OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
1800
- OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
1801
-
1802
- OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
1803
- #if defined(_MSC_VER) && !defined(__clang__)
1804
- #pragma warning(push)
1805
- #pragma warning(disable : 26409)
1806
- #endif
1807
- OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
1808
- #if defined(_MSC_VER) && !defined(__clang__)
1809
- #pragma warning(pop)
1810
- #endif
1811
- OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
1812
- OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
1813
-
1814
- OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
1815
- OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
1816
- OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
1817
- OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
1818
- }
1819
-
1820
- // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
1821
- const char* GetExecutionProviderType() const { return nullptr; }
1822
-
1823
- // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
1824
- // (inputs and outputs are required by default)
1825
- OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
1826
- return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
1827
- }
1828
-
1829
- OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
1830
- return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
1831
- }
1832
-
1833
- // Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault
1834
- OrtMemType GetInputMemoryType(size_t /*index*/) const {
1835
- return OrtMemTypeDefault;
1836
- }
1837
-
1838
- // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
1839
- // should expect at least 1 argument.
1840
- int GetVariadicInputMinArity() const {
1841
- return 1;
1842
- }
1843
-
1844
- // Default implementation of GetVariadicInputHomegeneity() returns true to specify that all arguments
1845
- // to a variadic input should be of the same type.
1846
- bool GetVariadicInputHomogeneity() const {
1847
- return true;
1848
- }
1849
-
1850
- // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
1851
- // should produce at least 1 output value.
1852
- int GetVariadicOutputMinArity() const {
1853
- return 1;
1854
- }
1855
-
1856
- // Default implementation of GetVariadicOutputHomegeneity() returns true to specify that all output values
1857
- // produced by a variadic output should be of the same type.
1858
- bool GetVariadicOutputHomogeneity() const {
1859
- return true;
1860
- }
1861
-
1862
- // Declare list of session config entries used by this Custom Op.
1863
- // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
1864
- // This default implementation returns an empty vector of config entries.
1865
- std::vector<std::string> GetSessionConfigKeys() const {
1866
- return std::vector<std::string>{};
1867
- }
1868
-
1869
- protected:
1870
- // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
1871
- void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
1872
- };
1873
-
1874
- } // namespace Ort
1875
-
1876
- #include "onnxruntime_cxx_inline.h"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
headers/onnxruntime_cxx_inline.h DELETED
@@ -1,1874 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
5
- // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
6
- //
7
- // These are the inline implementations of the C++ header APIs. They're in this separate file as to not clutter
8
- // the main C++ file with implementation details.
9
-
10
- namespace Ort {
11
-
12
- namespace detail {
13
- inline void ThrowStatus(const Status& st) {
14
- std::string error_message = st.GetErrorMessage();
15
- OrtErrorCode error_code = st.GetErrorCode();
16
- ORT_CXX_API_THROW(std::move(error_message), error_code);
17
- }
18
- } // namespace detail
19
-
20
- inline void ThrowOnError(OrtStatus* ort_status) {
21
- if (ort_status) {
22
- Ort::Status st(ort_status);
23
- detail::ThrowStatus(st);
24
- }
25
- }
26
-
27
- inline void ThrowOnError(const Status& st) {
28
- if (st) {
29
- detail::ThrowStatus(st);
30
- }
31
- }
32
-
33
- inline Status::Status(OrtStatus* status) : Base<OrtStatus>{status} {
34
- }
35
-
36
- inline Status::Status(const std::exception& e) {
37
- p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
38
- }
39
-
40
- inline Status::Status(const Exception& e) {
41
- p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
42
- }
43
-
44
- inline std::string Status::GetErrorMessage() const {
45
- std::string message(GetApi().GetErrorMessage(p_));
46
- return message;
47
- }
48
-
49
- inline OrtErrorCode Status::GetErrorCode() const {
50
- return GetApi().GetErrorCode(p_);
51
- }
52
-
53
- // This template converts a C++ type into it's ONNXTensorElementDataType
54
- template <typename T>
55
- struct TypeToTensorType;
56
- template <>
57
- struct TypeToTensorType<float> {
58
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
59
- };
60
- template <>
61
- struct TypeToTensorType<Float16_t> {
62
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
63
- };
64
- template <>
65
- struct TypeToTensorType<BFloat16_t> {
66
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
67
- };
68
- template <>
69
- struct TypeToTensorType<double> {
70
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
71
- };
72
- template <>
73
- struct TypeToTensorType<int8_t> {
74
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
75
- };
76
- template <>
77
- struct TypeToTensorType<int16_t> {
78
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
79
- };
80
- template <>
81
- struct TypeToTensorType<int32_t> {
82
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
83
- };
84
- template <>
85
- struct TypeToTensorType<int64_t> {
86
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
87
- };
88
- template <>
89
- struct TypeToTensorType<uint8_t> {
90
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
91
- };
92
- template <>
93
- struct TypeToTensorType<uint16_t> {
94
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
95
- };
96
- template <>
97
- struct TypeToTensorType<uint32_t> {
98
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
99
- };
100
- template <>
101
- struct TypeToTensorType<uint64_t> {
102
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
103
- };
104
- template <>
105
- struct TypeToTensorType<bool> {
106
- static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
107
- };
108
-
109
- inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
110
- : allocator_(allocator), p_(p), size_(size) {
111
- }
112
-
113
- inline MemoryAllocation::~MemoryAllocation() {
114
- if (p_ != nullptr) {
115
- // We do not throw out of destructor
116
- auto ret = GetApi().AllocatorFree(allocator_, p_);
117
- static_cast<void>(ret);
118
- }
119
- }
120
-
121
- inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
122
- *this = std::move(o);
123
- }
124
-
125
- inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
126
- OrtAllocator* alloc = nullptr;
127
- void* p = nullptr;
128
- size_t sz = 0;
129
-
130
- // Swap out this
131
- std::swap(alloc, allocator_);
132
- std::swap(p, p_);
133
- std::swap(sz, size_);
134
-
135
- // Swap with incoming
136
- std::swap(allocator_, o.allocator_);
137
- std::swap(p_, o.p_);
138
- std::swap(size_, o.size_);
139
-
140
- // Destroy this instance if needed
141
- MemoryAllocation this_alloc(alloc, p, sz);
142
- return *this;
143
- }
144
-
145
- namespace detail {
146
-
147
- template <typename T>
148
- inline void* AllocatorImpl<T>::Alloc(size_t size) {
149
- void* out;
150
- ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
151
- return out;
152
- }
153
-
154
- template <typename T>
155
- inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
156
- void* out;
157
- ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
158
- MemoryAllocation result(this->p_, out, size);
159
- return result;
160
- }
161
-
162
- template <typename T>
163
- inline void AllocatorImpl<T>::Free(void* p) {
164
- ThrowOnError(GetApi().AllocatorFree(this->p_, p));
165
- }
166
-
167
- template <typename T>
168
- inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
169
- const OrtMemoryInfo* out;
170
- ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
171
- return ConstMemoryInfo{out};
172
- }
173
-
174
- } // namespace detail
175
-
176
- inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
177
- ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
178
- }
179
-
180
- inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
181
- ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
182
- }
183
-
184
- namespace detail {
185
-
186
- template <typename T>
187
- inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
188
- const char* name = nullptr;
189
- ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
190
- return std::string(name);
191
- }
192
-
193
- template <typename T>
194
- inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
195
- OrtAllocatorType type;
196
- ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
197
- return type;
198
- }
199
-
200
- template <typename T>
201
- inline int MemoryInfoImpl<T>::GetDeviceId() const {
202
- int id = 0;
203
- ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
204
- return id;
205
- }
206
-
207
- template <typename T>
208
- inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
209
- OrtMemoryInfoDeviceType type;
210
- GetApi().MemoryInfoGetDeviceType(this->p_, &type);
211
- return type;
212
- }
213
-
214
- template <typename T>
215
- inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
216
- OrtMemType type;
217
- ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
218
- return type;
219
- }
220
-
221
- template <typename T>
222
- template <typename U>
223
- inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
224
- int comp_result = 0;
225
- ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
226
- return comp_result == 0;
227
- }
228
-
229
- } // namespace detail
230
-
231
- inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
232
- OrtMemoryInfo* p;
233
- ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
234
- return MemoryInfo(p);
235
- }
236
-
237
- inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
238
- ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
239
- }
240
-
241
- namespace detail {
242
- template <typename T>
243
- inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
244
- AllocatorWithDefaultOptions allocator;
245
- return binding_utils::GetOutputNamesHelper(this->p_, allocator);
246
- }
247
-
248
- template <typename T>
249
- inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
250
- return binding_utils::GetOutputNamesHelper(this->p_, allocator);
251
- }
252
-
253
- template <typename T>
254
- inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
255
- AllocatorWithDefaultOptions allocator;
256
- return binding_utils::GetOutputValuesHelper(this->p_, allocator);
257
- }
258
-
259
- template <typename T>
260
- inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
261
- return binding_utils::GetOutputValuesHelper(this->p_, allocator);
262
- }
263
-
264
- template <typename T>
265
- inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
266
- ThrowOnError(GetApi().BindInput(this->p_, name, value));
267
- }
268
-
269
- template <typename T>
270
- inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
271
- ThrowOnError(GetApi().BindOutput(this->p_, name, value));
272
- }
273
-
274
- template <typename T>
275
- inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
276
- ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
277
- }
278
-
279
- template <typename T>
280
- inline void IoBindingImpl<T>::ClearBoundInputs() {
281
- GetApi().ClearBoundInputs(this->p_);
282
- }
283
-
284
- template <typename T>
285
- inline void IoBindingImpl<T>::ClearBoundOutputs() {
286
- GetApi().ClearBoundOutputs(this->p_);
287
- }
288
-
289
- template <typename T>
290
- inline void IoBindingImpl<T>::SynchronizeInputs() {
291
- ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
292
- }
293
-
294
- template <typename T>
295
- inline void IoBindingImpl<T>::SynchronizeOutputs() {
296
- ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
297
- }
298
-
299
- namespace binding_utils {
300
- inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
301
- std::vector<std::string> result;
302
- auto free_fn = detail::AllocatedFree(allocator);
303
- using Ptr = std::unique_ptr<void, decltype(free_fn)>;
304
-
305
- char* buffer = nullptr;
306
- size_t* lengths = nullptr;
307
- size_t count = 0;
308
- ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
309
-
310
- if (count == 0) {
311
- return result;
312
- }
313
-
314
- Ptr buffer_g(buffer, free_fn);
315
- Ptr lengths_g(lengths, free_fn);
316
-
317
- result.reserve(count);
318
- for (size_t i = 0; i < count; ++i) {
319
- auto sz = *lengths;
320
- result.emplace_back(buffer, sz);
321
- buffer += sz;
322
- ++lengths;
323
- }
324
- return result;
325
- }
326
-
327
- inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
328
- std::vector<Value> result;
329
- size_t owned = 0;
330
- size_t output_count = 0;
331
- // Lambda to release the buffer when no longer needed and
332
- // make sure that we destroy all instances on exception
333
- auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
334
- if (buffer) {
335
- while (owned < output_count) {
336
- auto* p = buffer + owned++;
337
- GetApi().ReleaseValue(*p);
338
- }
339
- allocator->Free(allocator, buffer);
340
- }
341
- };
342
- using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
343
-
344
- OrtValue** output_buffer = nullptr;
345
- ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
346
- if (output_count == 0) {
347
- return result;
348
- }
349
-
350
- Ptr buffer_g(output_buffer, free_fn);
351
-
352
- result.reserve(output_count);
353
- for (size_t i = 0; i < output_count; ++i) {
354
- result.emplace_back(output_buffer[i]);
355
- ++owned;
356
- }
357
- return result;
358
- }
359
-
360
- } // namespace binding_utils
361
- } // namespace detail
362
-
363
- inline IoBinding::IoBinding(Session& session) {
364
- ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
365
- }
366
-
367
- inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
368
- ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
369
- }
370
-
371
- inline ThreadingOptions::ThreadingOptions() {
372
- ThrowOnError(GetApi().CreateThreadingOptions(&p_));
373
- }
374
-
375
- inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
376
- ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
377
- return *this;
378
- }
379
-
380
- inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
381
- ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
382
- return *this;
383
- }
384
-
385
- inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
386
- ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
387
- return *this;
388
- }
389
-
390
- inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
391
- ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
392
- return *this;
393
- }
394
-
395
- inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
396
- ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
397
- return *this;
398
- }
399
-
400
- inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
401
- ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
402
- return *this;
403
- }
404
-
405
- inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
406
- ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
407
- return *this;
408
- }
409
-
410
- inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
411
- ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
412
- if (strcmp(logid, "onnxruntime-node") == 0) {
413
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
414
- } else {
415
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
416
- }
417
- }
418
-
419
- inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
420
- ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
421
- if (strcmp(logid, "onnxruntime-node") == 0) {
422
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
423
- } else {
424
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
425
- }
426
- }
427
-
428
- inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
429
- ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
430
- if (strcmp(logid, "onnxruntime-node") == 0) {
431
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
432
- } else {
433
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
434
- }
435
- }
436
-
437
- inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
438
- OrtLoggingLevel logging_level, _In_ const char* logid) {
439
- ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
440
- if (strcmp(logid, "onnxruntime-node") == 0) {
441
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
442
- } else {
443
- ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
444
- }
445
- }
446
-
447
- inline Env& Env::EnableTelemetryEvents() {
448
- ThrowOnError(GetApi().EnableTelemetryEvents(p_));
449
- return *this;
450
- }
451
-
452
- inline Env& Env::DisableTelemetryEvents() {
453
- ThrowOnError(GetApi().DisableTelemetryEvents(p_));
454
- return *this;
455
- }
456
-
457
- inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
458
- ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
459
- return *this;
460
- }
461
-
462
- inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
463
- ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
464
- return *this;
465
- }
466
-
467
- inline CustomOpDomain::CustomOpDomain(const char* domain) {
468
- ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
469
- }
470
-
471
- inline void CustomOpDomain::Add(const OrtCustomOp* op) {
472
- ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
473
- }
474
-
475
- inline RunOptions::RunOptions() {
476
- ThrowOnError(GetApi().CreateRunOptions(&p_));
477
- }
478
-
479
- inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
480
- ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
481
- return *this;
482
- }
483
-
484
- inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
485
- ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
486
- return *this;
487
- }
488
-
489
- inline int RunOptions::GetRunLogVerbosityLevel() const {
490
- int out;
491
- ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
492
- return out;
493
- }
494
-
495
- inline int RunOptions::GetRunLogSeverityLevel() const {
496
- int out;
497
- ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
498
- return out;
499
- }
500
-
501
- inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
502
- ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
503
- return *this;
504
- }
505
-
506
- inline const char* RunOptions::GetRunTag() const {
507
- const char* out;
508
- ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
509
- return out;
510
- }
511
-
512
- inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
513
- ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
514
- return *this;
515
- }
516
-
517
- inline RunOptions& RunOptions::SetTerminate() {
518
- ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
519
- return *this;
520
- }
521
-
522
- inline RunOptions& RunOptions::UnsetTerminate() {
523
- ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
524
- return *this;
525
- }
526
-
527
- namespace detail {
528
-
529
- template <typename T>
530
- inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
531
- OrtSessionOptions* out;
532
- ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
533
- return SessionOptions{out};
534
- }
535
-
536
- template <typename T>
537
- inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
538
- size_t size = 0;
539
- // Feed nullptr for the data buffer to query the true size of the string value
540
- Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
541
-
542
- std::string out;
543
- out.resize(size);
544
- Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
545
- out.resize(size - 1); // remove the terminating character '\0'
546
-
547
- return out;
548
- }
549
-
550
- template <typename T>
551
- inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
552
- int out = 0;
553
- Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
554
- return static_cast<bool>(out);
555
- }
556
-
557
- template <typename T>
558
- inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
559
- if (!this->HasConfigEntry(config_key)) {
560
- return def;
561
- }
562
-
563
- return this->GetConfigEntry(config_key);
564
- }
565
-
566
- template <typename T>
567
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
568
- ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
569
- return *this;
570
- }
571
-
572
- template <typename T>
573
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
574
- ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
575
- return *this;
576
- }
577
-
578
- template <typename T>
579
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
580
- ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
581
- return *this;
582
- }
583
-
584
- template <typename T>
585
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
586
- ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
587
- return *this;
588
- }
589
-
590
- template <typename T>
591
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
592
- ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
593
- return *this;
594
- }
595
-
596
- template <typename T>
597
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
598
- ThrowOnError(GetApi().DisableProfiling(this->p_));
599
- return *this;
600
- }
601
-
602
- template <typename T>
603
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
604
- ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
605
- return *this;
606
- }
607
-
608
- template <typename T>
609
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
610
- ThrowOnError(GetApi().EnableMemPattern(this->p_));
611
- return *this;
612
- }
613
-
614
- template <typename T>
615
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
616
- ThrowOnError(GetApi().DisableMemPattern(this->p_));
617
- return *this;
618
- }
619
-
620
- template <typename T>
621
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
622
- ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
623
- return *this;
624
- }
625
-
626
- template <typename T>
627
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
628
- ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
629
- return *this;
630
- }
631
-
632
- template <typename T>
633
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
634
- ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
635
- return *this;
636
- }
637
-
638
- template <typename T>
639
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
640
- ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
641
- return *this;
642
- }
643
-
644
- template <typename T>
645
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
646
- ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
647
- return *this;
648
- }
649
-
650
- template <typename T>
651
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
652
- ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
653
- return *this;
654
- }
655
-
656
- template <typename T>
657
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
658
- ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
659
- return *this;
660
- }
661
-
662
- template <typename T>
663
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
664
- ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
665
- return *this;
666
- }
667
-
668
- template <typename T>
669
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
670
- ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
671
- return *this;
672
- }
673
-
674
- template <typename T>
675
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
676
- const std::vector<Value>& ort_values) {
677
- const size_t inputs_num = names.size();
678
- if (inputs_num != ort_values.size()) {
679
- ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
680
- }
681
- std::vector<const char*> names_ptr;
682
- std::vector<const OrtValue*> ort_values_ptrs;
683
- names_ptr.reserve(inputs_num);
684
- ort_values_ptrs.reserve(inputs_num);
685
- for (size_t i = 0; i < inputs_num; ++i) {
686
- names_ptr.push_back(names[i].c_str());
687
- ort_values_ptrs.push_back(ort_values[i]);
688
- }
689
- ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
690
- return *this;
691
- }
692
-
693
- template <typename T>
694
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
695
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
696
- return *this;
697
- }
698
-
699
- template <typename T>
700
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
701
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
702
- return *this;
703
- }
704
-
705
- template <typename T>
706
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
707
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
708
- return *this;
709
- }
710
-
711
- template <typename T>
712
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
713
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
714
- return *this;
715
- }
716
-
717
- template <typename T>
718
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
719
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
720
- return *this;
721
- }
722
-
723
- template <typename T>
724
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
725
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
726
- return *this;
727
- }
728
-
729
- template <typename T>
730
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
731
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
732
- return *this;
733
- }
734
-
735
- template <typename T>
736
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
737
- const std::string& provider_name,
738
- const std::unordered_map<std::string, std::string>& provider_options) {
739
- auto num_entries = provider_options.size();
740
- std::vector<const char*> keys, values;
741
- if (num_entries > 0) {
742
- keys.reserve(num_entries);
743
- values.reserve(num_entries);
744
-
745
- for (const auto& entry : provider_options) {
746
- keys.push_back(entry.first.c_str());
747
- values.push_back(entry.second.c_str());
748
- }
749
- }
750
-
751
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
752
- keys.data(), values.data(), num_entries));
753
-
754
- return *this;
755
- }
756
-
757
- template <typename T>
758
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
759
- ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
760
- return *this;
761
- }
762
-
763
- template <typename T>
764
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
765
- ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
766
- return *this;
767
- }
768
-
769
- template <typename T>
770
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
771
- ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
772
- return *this;
773
- }
774
-
775
- template <typename T>
776
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
777
- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
778
- return *this;
779
- }
780
-
781
- template <typename T>
782
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
783
- const CustomOpConfigs& custom_op_configs) {
784
- // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
785
- // the custom op library.
786
- for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
787
- AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
788
- }
789
-
790
- ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
791
- return *this;
792
- }
793
-
794
- template <typename T>
795
- inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
796
- ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
797
- return *this;
798
- }
799
-
800
- /// Session
801
- template <typename T>
802
- inline size_t ConstSessionImpl<T>::GetInputCount() const {
803
- size_t out;
804
- ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
805
- return out;
806
- }
807
-
808
- template <typename T>
809
- inline size_t ConstSessionImpl<T>::GetOutputCount() const {
810
- size_t out;
811
- ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
812
- return out;
813
- }
814
-
815
- template <typename T>
816
- inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
817
- size_t out;
818
- ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
819
- return out;
820
- }
821
-
822
- template <typename T>
823
- inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
824
- char* out;
825
- ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
826
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
827
- }
828
-
829
- template <typename T>
830
- inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
831
- char* out;
832
- ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
833
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
834
- }
835
-
836
- template <typename T>
837
- inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
838
- char* out;
839
- ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
840
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
841
- }
842
-
843
- template <typename T>
844
- inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
845
- uint64_t out;
846
- ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
847
- return out;
848
- }
849
-
850
- template <typename T>
851
- inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
852
- OrtModelMetadata* out;
853
- ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
854
- return ModelMetadata{out};
855
- }
856
-
857
- template <typename T>
858
- inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
859
- OrtTypeInfo* out;
860
- ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
861
- return TypeInfo{out};
862
- }
863
-
864
- template <typename T>
865
- inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
866
- OrtTypeInfo* out;
867
- ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
868
- return TypeInfo{out};
869
- }
870
-
871
- template <typename T>
872
- inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
873
- OrtTypeInfo* out;
874
- ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
875
- return TypeInfo{out};
876
- }
877
-
878
- template <typename T>
879
- inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
880
- const char* const* output_names, size_t output_count) {
881
- std::vector<Value> output_values;
882
- output_values.reserve(output_count);
883
- for (size_t i = 0; i < output_count; i++)
884
- output_values.emplace_back(nullptr);
885
- Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
886
- return output_values;
887
- }
888
-
889
- template <typename T>
890
- inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
891
- const char* const* output_names, Value* output_values, size_t output_count) {
892
- static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
893
- auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
894
- auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
895
- ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
896
- }
897
-
898
- template <typename T>
899
- inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
900
- ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
901
- }
902
-
903
- template <typename T>
904
- inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
905
- char* out = nullptr;
906
- ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
907
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
908
- }
909
-
910
- } // namespace detail
911
-
912
- inline SessionOptions::SessionOptions() {
913
- ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
914
- }
915
-
916
- /// CustomOpConfigs
917
- inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
918
- std::string config_key = "custom_op.";
919
-
920
- config_key += custom_op_name;
921
- config_key += ".";
922
- config_key += config;
923
-
924
- return config_key;
925
- }
926
-
927
- inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
928
- const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
929
- flat_configs_[full_flat_key] = config_value;
930
- return *this;
931
- }
932
-
933
- inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
934
- return flat_configs_;
935
- }
936
-
937
- inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
938
- ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
939
- }
940
-
941
- inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
942
- OrtPrepackedWeightsContainer* prepacked_weights_container) {
943
- ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
944
- }
945
-
946
- inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
947
- ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
948
- }
949
-
950
- inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
951
- const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
952
- ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
953
- prepacked_weights_container, &this->p_));
954
- }
955
-
956
- inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
957
- char* out;
958
- ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
959
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
960
- }
961
-
962
- inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
963
- char* out;
964
- ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
965
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
966
- }
967
-
968
- inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
969
- char* out;
970
- ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
971
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
972
- }
973
-
974
- inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
975
- char* out;
976
- ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
977
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
978
- }
979
-
980
- inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
981
- char* out;
982
- ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
983
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
984
- }
985
-
986
- inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
987
- char* out;
988
- ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
989
- return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
990
- }
991
-
992
- inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
993
- auto deletor = detail::AllocatedFree(allocator);
994
- std::vector<AllocatedStringPtr> result;
995
-
996
- char** out = nullptr;
997
- int64_t num_keys = 0;
998
- ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
999
- if (num_keys <= 0) {
1000
- return result;
1001
- }
1002
-
1003
- // array of pointers will be freed
1004
- std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
1005
- // reserve may throw
1006
- auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
1007
- std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
1008
- result.reserve(static_cast<size_t>(num_keys));
1009
- strings_guard.release();
1010
- for (int64_t i = 0; i < num_keys; ++i) {
1011
- result.push_back(AllocatedStringPtr(out[i], deletor));
1012
- }
1013
-
1014
- return result;
1015
- }
1016
-
1017
- inline int64_t ModelMetadata::GetVersion() const {
1018
- int64_t out;
1019
- ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
1020
- return out;
1021
- }
1022
-
1023
- namespace detail {
1024
-
1025
- template <typename T>
1026
- inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
1027
- ONNXTensorElementDataType out;
1028
- ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
1029
- return out;
1030
- }
1031
-
1032
- template <typename T>
1033
- inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
1034
- size_t out;
1035
- ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
1036
- return static_cast<size_t>(out);
1037
- }
1038
-
1039
- template <typename T>
1040
- inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
1041
- size_t out;
1042
- ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
1043
- return out;
1044
- }
1045
-
1046
- template <typename T>
1047
- inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
1048
- ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
1049
- }
1050
-
1051
- template <typename T>
1052
- inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
1053
- ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
1054
- }
1055
-
1056
- template <typename T>
1057
- inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
1058
- std::vector<int64_t> out(GetDimensionsCount(), 0);
1059
- ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
1060
- return out;
1061
- }
1062
-
1063
- } // namespace detail
1064
-
1065
- namespace detail {
1066
- template <typename T>
1067
- inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
1068
- const OrtTensorTypeAndShapeInfo* out;
1069
- ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
1070
- return ConstTensorTypeAndShapeInfo{out};
1071
- }
1072
-
1073
- template <typename T>
1074
- inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
1075
- const OrtSequenceTypeInfo* out;
1076
- ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
1077
- return ConstSequenceTypeInfo{out};
1078
- }
1079
-
1080
- template <typename T>
1081
- inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
1082
- const OrtMapTypeInfo* out;
1083
- ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
1084
- return ConstMapTypeInfo{out};
1085
- }
1086
-
1087
- template <typename T>
1088
- inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
1089
- ONNXType out;
1090
- ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
1091
- return out;
1092
- }
1093
-
1094
- } // namespace detail
1095
-
1096
- namespace detail {
1097
- template <typename T>
1098
- inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
1099
- OrtTypeInfo* output;
1100
- ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
1101
- return TypeInfo{output};
1102
- }
1103
-
1104
- } // namespace detail
1105
-
1106
- namespace detail {
1107
- template <typename T>
1108
- inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
1109
- ONNXTensorElementDataType out;
1110
- ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
1111
- return out;
1112
- }
1113
-
1114
- template <typename T>
1115
- inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
1116
- OrtTypeInfo* output;
1117
- ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
1118
- return TypeInfo{output};
1119
- }
1120
- } // namespace detail
1121
-
1122
- namespace detail {
1123
-
1124
- template <typename T>
1125
- template <typename R>
1126
- inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
1127
- ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
1128
- }
1129
-
1130
- template <typename T>
1131
- inline bool ConstValueImpl<T>::IsTensor() const {
1132
- int out;
1133
- ThrowOnError(GetApi().IsTensor(this->p_, &out));
1134
- return out != 0;
1135
- }
1136
-
1137
- template <typename T>
1138
- inline bool ConstValueImpl<T>::HasValue() const {
1139
- int out;
1140
- ThrowOnError(GetApi().HasValue(this->p_, &out));
1141
- return out != 0;
1142
- }
1143
-
1144
- template <typename T>
1145
- inline size_t ConstValueImpl<T>::GetCount() const {
1146
- size_t out;
1147
- ThrowOnError(GetApi().GetValueCount(this->p_, &out));
1148
- return out;
1149
- }
1150
-
1151
- template <typename T>
1152
- inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
1153
- OrtValue* out;
1154
- ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
1155
- return Value{out};
1156
- }
1157
-
1158
- template <typename T>
1159
- inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
1160
- size_t out;
1161
- ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
1162
- return out;
1163
- }
1164
-
1165
- template <typename T>
1166
- inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
1167
- size_t out;
1168
- ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
1169
- return out;
1170
- }
1171
-
1172
- template <typename T>
1173
- template <typename R>
1174
- inline const R* ConstValueImpl<T>::GetTensorData() const {
1175
- R* out;
1176
- ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
1177
- return out;
1178
- }
1179
-
1180
- template <typename T>
1181
- inline const void* ConstValueImpl<T>::GetTensorRawData() const {
1182
- void* out;
1183
- ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
1184
- return out;
1185
- }
1186
-
1187
- template <typename T>
1188
- inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
1189
- OrtTypeInfo* output;
1190
- ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
1191
- return TypeInfo{output};
1192
- }
1193
-
1194
- template <typename T>
1195
- inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
1196
- OrtTensorTypeAndShapeInfo* output;
1197
- ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
1198
- return TensorTypeAndShapeInfo{output};
1199
- }
1200
-
1201
- template <typename T>
1202
- inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
1203
- const OrtMemoryInfo* mem_info;
1204
- ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
1205
- return ConstMemoryInfo(mem_info);
1206
- }
1207
-
1208
- template <typename T>
1209
- inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
1210
- ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
1211
- }
1212
-
1213
- template <typename T>
1214
- inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
1215
- ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
1216
- }
1217
-
1218
- #if !defined(DISABLE_SPARSE_TENSORS)
1219
- template <typename T>
1220
- inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
1221
- OrtSparseFormat format;
1222
- ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
1223
- return format;
1224
- }
1225
-
1226
- template <typename T>
1227
- inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
1228
- OrtTensorTypeAndShapeInfo* output;
1229
- ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
1230
- return TensorTypeAndShapeInfo{output};
1231
- }
1232
-
1233
- template <typename T>
1234
- inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
1235
- OrtTensorTypeAndShapeInfo* output;
1236
- ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
1237
- return TensorTypeAndShapeInfo{output};
1238
- }
1239
-
1240
- template <typename T>
1241
- template <typename R>
1242
- inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
1243
- const void* out;
1244
- ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
1245
- return reinterpret_cast<const R*>(out);
1246
- }
1247
-
1248
- template <typename T>
1249
- inline bool ConstValueImpl<T>::IsSparseTensor() const {
1250
- int out;
1251
- ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
1252
- return out != 0;
1253
- }
1254
-
1255
- template <typename T>
1256
- template <typename R>
1257
- inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
1258
- const void* out;
1259
- ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
1260
- return reinterpret_cast<const R*>(out);
1261
- }
1262
-
1263
- #endif
1264
-
1265
- template <typename T>
1266
- void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
1267
- ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
1268
- }
1269
-
1270
- template <typename T>
1271
- void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
1272
- ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
1273
- }
1274
-
1275
- template <typename T>
1276
- void* ValueImpl<T>::GetTensorMutableRawData() {
1277
- void* out;
1278
- ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
1279
- return out;
1280
- }
1281
-
1282
- template <typename T>
1283
- template <typename R>
1284
- R* ValueImpl<T>::GetTensorMutableData() {
1285
- R* out;
1286
- ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
1287
- return out;
1288
- }
1289
-
1290
- template <typename T>
1291
- template <typename R>
1292
- R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
1293
- static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
1294
- R* out;
1295
- ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
1296
- return *out;
1297
- }
1298
-
1299
- #if !defined(DISABLE_SPARSE_TENSORS)
1300
- template <typename T>
1301
- void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
1302
- ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
1303
- }
1304
-
1305
- template <typename T>
1306
- void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
1307
- ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
1308
- }
1309
-
1310
- template <typename T>
1311
- void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
1312
- ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
1313
- }
1314
-
1315
- template <typename T>
1316
- void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
1317
- const int64_t* indices_data, size_t indices_num) {
1318
- ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
1319
- values_param.values_shape_len, values_param.data.p_data,
1320
- indices_data, indices_num));
1321
- }
1322
-
1323
- template <typename T>
1324
- void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1325
- const OrtSparseValuesParam& values,
1326
- const int64_t* inner_indices_data, size_t inner_indices_num,
1327
- const int64_t* outer_indices_data, size_t outer_indices_num) {
1328
- ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1329
- inner_indices_data, inner_indices_num,
1330
- outer_indices_data, outer_indices_num));
1331
- }
1332
-
1333
- template <typename T>
1334
- void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1335
- const OrtSparseValuesParam& values,
1336
- const Shape& indices_shape,
1337
- const int32_t* indices_data) {
1338
- ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1339
- indices_shape.shape, indices_shape.shape_len,
1340
- indices_data));
1341
- }
1342
-
1343
- #endif // !defined(DISABLE_SPARSE_TENSORS)
1344
-
1345
- } // namespace detail
1346
-
1347
- template <typename T>
1348
- inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
1349
- return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
1350
- }
1351
-
1352
- inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1353
- ONNXTensorElementDataType type) {
1354
- OrtValue* out;
1355
- ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1356
- return Value{out};
1357
- }
1358
-
1359
- template <typename T>
1360
- inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
1361
- return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
1362
- }
1363
-
1364
- inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
1365
- OrtValue* out;
1366
- ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1367
- return Value{out};
1368
- }
1369
-
1370
- #if !defined(DISABLE_SPARSE_TENSORS)
1371
-
1372
- template <typename T>
1373
- inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1374
- const Shape& values_shape) {
1375
- return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
1376
- }
1377
-
1378
- inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1379
- const Shape& values_shape, ONNXTensorElementDataType type) {
1380
- OrtValue* out;
1381
- ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1382
- values_shape.shape, values_shape.shape_len, type, &out));
1383
- return Value{out};
1384
- }
1385
-
1386
- template <typename T>
1387
- inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
1388
- return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
1389
- }
1390
-
1391
- inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
1392
- ONNXTensorElementDataType type) {
1393
- OrtValue* out;
1394
- ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
1395
- return Value{out};
1396
- }
1397
- #endif // !defined(DISABLE_SPARSE_TENSORS)
1398
-
1399
- inline Value Value::CreateMap(Value& keys, Value& values) {
1400
- OrtValue* out;
1401
- OrtValue* inputs[2] = {keys, values};
1402
- ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
1403
- return Value{out};
1404
- }
1405
-
1406
- inline Value Value::CreateSequence(std::vector<Value>& values) {
1407
- OrtValue* out;
1408
- std::vector<OrtValue*> values_ort{values.data(), values.data() + values.size()};
1409
- ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1410
- return Value{out};
1411
- }
1412
-
1413
- template <typename T>
1414
- inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
1415
- OrtValue* out;
1416
- ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
1417
- return Value{out};
1418
- }
1419
-
1420
- //
1421
- // Custom OP Inlines
1422
- //
1423
- inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
1424
- }
1425
-
1426
- inline size_t KernelContext::GetInputCount() const {
1427
- size_t out = 0;
1428
- Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
1429
- return out;
1430
- }
1431
-
1432
- inline size_t KernelContext::GetOutputCount() const {
1433
- size_t out = 0;
1434
- Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
1435
- return out;
1436
- }
1437
-
1438
- inline ConstValue KernelContext::GetInput(size_t index) const {
1439
- const OrtValue* out = nullptr;
1440
- Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
1441
- return ConstValue{out};
1442
- }
1443
-
1444
- inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
1445
- OrtValue* out = nullptr;
1446
- Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
1447
- return UnownedValue(out);
1448
- }
1449
-
1450
- inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
1451
- OrtValue* out = nullptr;
1452
- Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
1453
- return UnownedValue(out);
1454
- }
1455
-
1456
- inline void* KernelContext::GetGPUComputeStream() const {
1457
- void* out = nullptr;
1458
- Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
1459
- return out;
1460
- }
1461
-
1462
- inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
1463
- Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
1464
- }
1465
-
1466
- namespace detail {
1467
- template <typename T>
1468
- inline KernelInfo KernelInfoImpl<T>::Copy() const {
1469
- OrtKernelInfo* info_copy = nullptr;
1470
- Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
1471
- return KernelInfo{info_copy};
1472
- }
1473
-
1474
- template <typename T>
1475
- inline size_t KernelInfoImpl<T>::GetInputCount() const {
1476
- size_t out = 0;
1477
- ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
1478
- return out;
1479
- }
1480
-
1481
- template <typename T>
1482
- inline size_t KernelInfoImpl<T>::GetOutputCount() const {
1483
- size_t out = 0;
1484
- ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
1485
- return out;
1486
- }
1487
-
1488
- template <typename T>
1489
- inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
1490
- size_t size = 0;
1491
-
1492
- // Feed nullptr for the data buffer to query the true size of the string value
1493
- Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
1494
-
1495
- std::string out;
1496
- out.resize(size);
1497
- Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
1498
- out.resize(size - 1); // remove the terminating character '\0'
1499
-
1500
- return out;
1501
- }
1502
-
1503
- template <typename T>
1504
- inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
1505
- size_t size = 0;
1506
-
1507
- // Feed nullptr for the data buffer to query the true size of the string value
1508
- Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
1509
-
1510
- std::string out;
1511
- out.resize(size);
1512
- Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
1513
- out.resize(size - 1); // remove the terminating character '\0'
1514
-
1515
- return out;
1516
- }
1517
-
1518
- template <typename T>
1519
- inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
1520
- OrtTypeInfo* out = nullptr;
1521
- ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
1522
- return TypeInfo{out};
1523
- }
1524
-
1525
- template <typename T>
1526
- inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
1527
- OrtTypeInfo* out = nullptr;
1528
- ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
1529
- return TypeInfo{out};
1530
- }
1531
-
1532
- template <typename T>
1533
- inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
1534
- OrtValue* out = nullptr;
1535
- ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
1536
- return Value{out};
1537
- }
1538
-
1539
- inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
1540
- Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
1541
- }
1542
-
1543
- inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
1544
- Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
1545
- }
1546
-
1547
- inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
1548
- size_t size = 0;
1549
- // Feed nullptr for the data buffer to query the true size of the string attribute
1550
- Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
1551
-
1552
- std::string out;
1553
- out.resize(size);
1554
- Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
1555
- out.resize(size - 1); // remove the terminating character '\0'
1556
- out.swap(result);
1557
- }
1558
-
1559
- inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
1560
- size_t size = 0;
1561
- // Feed nullptr for the data buffer to query the true size of the attribute
1562
- Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
1563
-
1564
- std::vector<float> out;
1565
- out.resize(size);
1566
- Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
1567
- out.swap(result);
1568
- }
1569
-
1570
- inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
1571
- size_t size = 0;
1572
-
1573
- // Feed nullptr for the data buffer to query the true size of the attribute
1574
- Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
1575
-
1576
- std::vector<int64_t> out;
1577
- out.resize(size);
1578
- Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
1579
- out.swap(result);
1580
- }
1581
- } // namespace detail
1582
-
1583
- inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
1584
-
1585
- inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
1586
-
1587
- inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
1588
- const char** type_constraint_names,
1589
- const ONNXTensorElementDataType* type_constraint_values,
1590
- size_t type_constraint_count,
1591
- const OpAttr* attr_values, size_t attr_count,
1592
- size_t input_count, size_t output_count) {
1593
- static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
1594
- "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
1595
- auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
1596
- OrtOp* op;
1597
- Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1598
- static_cast<int>(type_constraint_count),
1599
- attr_input_values,
1600
- static_cast<int>(attr_count),
1601
- static_cast<int>(input_count),
1602
- static_cast<int>(output_count), &op));
1603
- return Op{op};
1604
- }
1605
-
1606
- inline void Op::Invoke(const OrtKernelContext* context,
1607
- const Value* input_values,
1608
- size_t input_count,
1609
- Value* output_values,
1610
- size_t output_count) {
1611
- static_assert(sizeof(Value) == sizeof(OrtValue*),
1612
- "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1613
- auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1614
- auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1615
- Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
1616
- ort_output_values, static_cast<int>(output_count)));
1617
- }
1618
-
1619
- inline void Op::Invoke(const OrtKernelContext* context,
1620
- const OrtValue* const* input_values,
1621
- size_t input_count,
1622
- OrtValue* const* output_values,
1623
- size_t output_count) {
1624
- Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
1625
- output_values, static_cast<int>(output_count)));
1626
- }
1627
-
1628
- inline void CustomOpApi::ThrowOnError(OrtStatus* status) {
1629
- Ort::ThrowOnError(status);
1630
- }
1631
-
1632
- template <>
1633
- inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1634
- float out;
1635
- Ort::ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
1636
- return out;
1637
- }
1638
-
1639
- template <>
1640
- inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1641
- int64_t out;
1642
- Ort::ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
1643
- return out;
1644
- }
1645
-
1646
- template <>
1647
- inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1648
- size_t size = 0;
1649
- std::string out;
1650
-
1651
- // Feed nullptr for the data buffer to query the true size of the string attribute
1652
- OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
1653
-
1654
- if (status == nullptr) {
1655
- out.resize(size);
1656
- Ort::ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
1657
- out.resize(size - 1); // remove the terminating character '\0'
1658
- } else {
1659
- Ort::ThrowOnError(status);
1660
- }
1661
- return out;
1662
- }
1663
-
1664
- template <>
1665
- inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1666
- size_t size = 0;
1667
- std::vector<float> out;
1668
-
1669
- // Feed nullptr for the data buffer to query the true size of the attribute
1670
- OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
1671
-
1672
- if (status == nullptr) {
1673
- out.resize(size);
1674
- Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
1675
- } else {
1676
- Ort::ThrowOnError(status);
1677
- }
1678
- return out;
1679
- }
1680
-
1681
- template <>
1682
- inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1683
- size_t size = 0;
1684
- std::vector<int64_t> out;
1685
-
1686
- // Feed nullptr for the data buffer to query the true size of the attribute
1687
- OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
1688
-
1689
- if (status == nullptr) {
1690
- out.resize(size);
1691
- Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
1692
- } else {
1693
- Ort::ThrowOnError(status);
1694
- }
1695
- return out;
1696
- }
1697
- inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
1698
- OrtTensorTypeAndShapeInfo* out;
1699
- Ort::ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
1700
- return out;
1701
- }
1702
-
1703
- inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
1704
- size_t out;
1705
- Ort::ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
1706
- return out;
1707
- }
1708
-
1709
- inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) {
1710
- ONNXTensorElementDataType out;
1711
- Ort::ThrowOnError(api_.GetTensorElementType(info, &out));
1712
- return out;
1713
- }
1714
-
1715
- inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
1716
- size_t out;
1717
- Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
1718
- return out;
1719
- }
1720
-
1721
- inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
1722
- Ort::ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
1723
- }
1724
-
1725
- inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
1726
- Ort::ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
1727
- }
1728
-
1729
- template <typename T>
1730
- inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) {
1731
- T* data;
1732
- Ort::ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
1733
- return data;
1734
- }
1735
-
1736
- inline const OrtMemoryInfo* CustomOpApi::GetTensorMemoryInfo(_In_ const OrtValue* value) {
1737
- const OrtMemoryInfo* mem_info;
1738
- Ort::ThrowOnError(api_.GetTensorMemoryInfo(value, &mem_info));
1739
- return mem_info;
1740
- }
1741
-
1742
- template <typename T>
1743
- inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) {
1744
- T* data = nullptr;
1745
- Ort::ThrowOnError(api_.GetTensorMutableData(const_cast<OrtValue*>(value), reinterpret_cast<void**>(&data)));
1746
- return data;
1747
- }
1748
-
1749
- inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) {
1750
- size_t out;
1751
- Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
1752
- std::vector<int64_t> output(out);
1753
- Ort::ThrowOnError(api_.GetDimensions(info, output.data(), out));
1754
- return output;
1755
- }
1756
-
1757
- inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
1758
- api_.ReleaseTensorTypeAndShapeInfo(input);
1759
- }
1760
-
1761
- inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) {
1762
- size_t out;
1763
- Ort::ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
1764
- return out;
1765
- }
1766
-
1767
- inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) {
1768
- const OrtValue* out;
1769
- Ort::ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
1770
- return out;
1771
- }
1772
-
1773
- inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) {
1774
- size_t out;
1775
- Ort::ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
1776
- return out;
1777
- }
1778
-
1779
- inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
1780
- _In_ const int64_t* dim_values, size_t dim_count) {
1781
- OrtValue* out;
1782
- Ort::ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
1783
- return out;
1784
- }
1785
-
1786
- inline void* CustomOpApi::KernelContext_GetGPUComputeStream(const OrtKernelContext* context) {
1787
- void* out;
1788
- Ort::ThrowOnError(api_.KernelContext_GetGPUComputeStream(context, &out));
1789
- return out;
1790
- }
1791
-
1792
- inline OrtOpAttr* CustomOpApi::CreateOpAttr(_In_ const char* name,
1793
- _In_ const void* data,
1794
- _In_ int len,
1795
- _In_ OrtOpAttrType type) {
1796
- OrtOpAttr* op_attr{};
1797
- Ort::ThrowOnError(api_.CreateOpAttr(name, data, len, type, &op_attr));
1798
- return op_attr;
1799
- }
1800
-
1801
- inline void CustomOpApi::ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr) {
1802
- api_.ReleaseOpAttr(op_attr);
1803
- }
1804
-
1805
- inline OrtOp* CustomOpApi::CreateOp(_In_ const OrtKernelInfo* info,
1806
- _In_ const char* op_name,
1807
- _In_ const char* domain,
1808
- _In_ int version,
1809
- _In_opt_ const char** type_constraint_names,
1810
- _In_opt_ const ONNXTensorElementDataType* type_constraint_values,
1811
- _In_opt_ int type_constraint_count,
1812
- _In_opt_ const OrtOpAttr* const* attr_values,
1813
- _In_opt_ int attr_count,
1814
- _In_ int input_count,
1815
- _In_ int output_count) {
1816
- OrtOp* ort_op{};
1817
- Ort::ThrowOnError(api_.CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1818
- type_constraint_count, attr_values, attr_count, input_count, output_count, &ort_op));
1819
- return ort_op;
1820
- }
1821
-
1822
- inline void CustomOpApi::InvokeOp(_In_ const OrtKernelContext* context,
1823
- _In_ const OrtOp* ort_op,
1824
- _In_ const OrtValue* const* input_values,
1825
- _In_ int input_count,
1826
- _Inout_ OrtValue* const* output_values,
1827
- _In_ int output_count) {
1828
- Ort::ThrowOnError(api_.InvokeOp(context, ort_op, input_values, input_count, output_values, output_count));
1829
- }
1830
-
1831
- inline void CustomOpApi::ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op) {
1832
- api_.ReleaseOp(ort_op);
1833
- }
1834
-
1835
- inline OrtKernelInfo* CustomOpApi::CopyKernelInfo(_In_ const OrtKernelInfo* info) {
1836
- OrtKernelInfo* info_copy{};
1837
- Ort::ThrowOnError(api_.CopyKernelInfo(info, &info_copy));
1838
- return info_copy;
1839
- }
1840
-
1841
- inline void CustomOpApi::ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy) {
1842
- api_.ReleaseKernelInfo(info_copy);
1843
- }
1844
-
1845
- inline std::vector<std::string> GetAvailableProviders() {
1846
- int len;
1847
- char** providers;
1848
- ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
1849
- std::vector<std::string> available_providers(providers, providers + len);
1850
- ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
1851
- return available_providers;
1852
- }
1853
-
1854
// NOTE(review): namespace-scope declaration only; no definition is visible in this section.
// It looks like a stray copy of a SessionOptions member declaration - confirm before relying on it.
SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
1855
-
1856
- template <typename TOp, typename TKernel>
1857
- void CustomOpBase<TOp, TKernel>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
1858
- ConstSessionOptions options) const {
1859
- const TOp* derived = static_cast<const TOp*>(this);
1860
- std::vector<std::string> keys = derived->GetSessionConfigKeys();
1861
-
1862
- out.reserve(keys.size());
1863
-
1864
- std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
1865
- const size_t prefix_size = config_entry_key.length();
1866
-
1867
- for (const auto& key : keys) {
1868
- config_entry_key.resize(prefix_size);
1869
- config_entry_key.append(key);
1870
- out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
1871
- }
1872
- }
1873
-
1874
- } // namespace Ort