csukuangfj committed on
Commit
b974d63
·
1 Parent(s): 0a9fe68

remove extra files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. arm64-v8a/include/onnxruntime/core/common/basic_types.h +0 -19
  2. arm64-v8a/include/onnxruntime/core/common/code_location.h +0 -58
  3. arm64-v8a/include/onnxruntime/core/common/common.h +0 -287
  4. arm64-v8a/include/onnxruntime/core/common/const_pointer_container.h +0 -85
  5. arm64-v8a/include/onnxruntime/core/common/denormal.h +0 -12
  6. arm64-v8a/include/onnxruntime/core/common/eigen_common_wrapper.h +0 -62
  7. arm64-v8a/include/onnxruntime/core/common/exceptions.h +0 -71
  8. arm64-v8a/include/onnxruntime/core/common/gpu_profiler_common.h +0 -472
  9. arm64-v8a/include/onnxruntime/core/common/gsl.h +0 -6
  10. arm64-v8a/include/onnxruntime/core/common/hash_combine.h +0 -21
  11. arm64-v8a/include/onnxruntime/core/common/inlined_containers.h +0 -175
  12. arm64-v8a/include/onnxruntime/core/common/inlined_containers_fwd.h +0 -147
  13. arm64-v8a/include/onnxruntime/core/common/logging/capture.h +0 -115
  14. arm64-v8a/include/onnxruntime/core/common/logging/isink.h +0 -41
  15. arm64-v8a/include/onnxruntime/core/common/logging/logging.h +0 -337
  16. arm64-v8a/include/onnxruntime/core/common/logging/macros.h +0 -278
  17. arm64-v8a/include/onnxruntime/core/common/logging/severity.h +0 -22
  18. arm64-v8a/include/onnxruntime/core/common/make_string.h +0 -126
  19. arm64-v8a/include/onnxruntime/core/common/narrow.h +0 -77
  20. arm64-v8a/include/onnxruntime/core/common/optional.h +0 -23
  21. arm64-v8a/include/onnxruntime/core/common/parse_string.h +0 -85
  22. arm64-v8a/include/onnxruntime/core/common/profiler_common.h +0 -93
  23. arm64-v8a/include/onnxruntime/core/common/span_utils.h +0 -88
  24. arm64-v8a/include/onnxruntime/core/common/spin_pause.h +0 -28
  25. arm64-v8a/include/onnxruntime/core/common/status.h +0 -195
  26. arm64-v8a/include/onnxruntime/core/common/string_helper.h +0 -11
  27. arm64-v8a/include/onnxruntime/core/framework/alloc_kind.h +0 -36
  28. arm64-v8a/include/onnxruntime/core/framework/allocator.h +0 -194
  29. arm64-v8a/include/onnxruntime/core/framework/buffer_deleter.h +0 -36
  30. arm64-v8a/include/onnxruntime/core/framework/customregistry.h +0 -60
  31. arm64-v8a/include/onnxruntime/core/framework/data_types.h +0 -1062
  32. arm64-v8a/include/onnxruntime/core/framework/data_types_internal.h +0 -569
  33. arm64-v8a/include/onnxruntime/core/framework/endian.h +0 -27
  34. arm64-v8a/include/onnxruntime/core/framework/execution_provider.h +0 -340
  35. arm64-v8a/include/onnxruntime/core/framework/float16.h +0 -159
  36. arm64-v8a/include/onnxruntime/core/framework/framework_common.h +0 -22
  37. arm64-v8a/include/onnxruntime/core/framework/func_api.h +0 -27
  38. arm64-v8a/include/onnxruntime/core/framework/kernel_def_builder.h +0 -353
  39. arm64-v8a/include/onnxruntime/core/framework/kernel_registry.h +0 -91
  40. arm64-v8a/include/onnxruntime/core/framework/op_kernel.h +0 -387
  41. arm64-v8a/include/onnxruntime/core/framework/op_kernel_context.h +0 -237
  42. arm64-v8a/include/onnxruntime/core/framework/op_kernel_info.h +0 -63
  43. arm64-v8a/include/onnxruntime/core/framework/op_node_proto_helper.h +0 -167
  44. arm64-v8a/include/onnxruntime/core/framework/ort_value.h +0 -123
  45. arm64-v8a/include/onnxruntime/core/framework/ortdevice.h +0 -74
  46. arm64-v8a/include/onnxruntime/core/framework/ortmemoryinfo.h +0 -87
  47. arm64-v8a/include/onnxruntime/core/framework/provider_options.h +0 -18
  48. arm64-v8a/include/onnxruntime/core/framework/provider_options_utils.h +0 -164
  49. arm64-v8a/include/onnxruntime/core/framework/provider_shutdown.h +0 -8
  50. arm64-v8a/include/onnxruntime/core/framework/run_options.h +0 -49
arm64-v8a/include/onnxruntime/core/common/basic_types.h DELETED
@@ -1,19 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <cstdint>
7
-
8
- namespace onnxruntime {
9
-
10
- /** A computed hash value. */
11
- using HashValue = uint64_t;
12
-
13
- /** The type of an argument (input or output).*/
14
- enum class ArgType : uint8_t {
15
- kInput,
16
- kOutput,
17
- };
18
-
19
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/code_location.h DELETED
@@ -1,58 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <sstream>
7
- #include <string>
8
- #include <vector>
9
-
10
- namespace onnxruntime {
11
- /**
12
- CodeLocation captures information on where in the source code a message came from.
13
- */
14
- struct CodeLocation {
15
- /**
16
- @param file_path Usually the value of __FILE__
17
- @param line Usually the value of __LINE__
18
- @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__
19
- */
20
- CodeLocation(const char* file_path, const int line, const char* func)
21
- : file_and_path{file_path}, line_num{line}, function{func} {
22
- }
23
-
24
- /**
25
- @param file_path Usually the value of __FILE__
26
- @param line Usually the value of __LINE__
27
- @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__
28
- @param stacktrace Stacktrace from source of message.
29
- */
30
- CodeLocation(const char* file_path, const int line, const char* func, const std::vector<std::string>& stacktrace)
31
- : file_and_path{file_path}, line_num{line}, function{func}, stacktrace(stacktrace) {
32
- }
33
-
34
- std::string FileNoPath() const {
35
- // assuming we always have work to do, so not trying to avoid creating a new string if
36
- // no path was removed.
37
- return file_and_path.substr(file_and_path.find_last_of("/\\") + 1);
38
- }
39
-
40
- enum Format {
41
- kFilename,
42
- kFilenameAndPath
43
- };
44
-
45
- std::string ToString(Format format = Format::kFilename) const {
46
- std::ostringstream out;
47
- out << (format == Format::kFilename ? FileNoPath() : file_and_path) << ":" << line_num << " " << function;
48
- return out.str();
49
- }
50
- //utf-8. Because on Windows we compile our code with "/utf-8". And we assume the other platforms only use utf-8.
51
- const std::string file_and_path;
52
- const int line_num;
53
- //utf-8
54
- const std::string function;
55
- const std::vector<std::string> stacktrace;
56
- };
57
-
58
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/common.h DELETED
@@ -1,287 +0,0 @@
1
- /**
2
- * Copyright (c) 2016-present, Facebook, Inc.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
- // Portions Copyright (c) Microsoft Corporation
17
-
18
- #pragma once
19
-
20
- #include <climits>
21
- #include <cstring>
22
- #include <algorithm>
23
- #include <chrono>
24
- #include <functional>
25
- #include <memory>
26
- #include <numeric>
27
- #include <set>
28
- #include <sstream>
29
- #include <string>
30
- #include <type_traits>
31
- #include <unordered_map>
32
- #include <utility>
33
- #include <vector>
34
-
35
- #include "core/common/code_location.h"
36
- #include "core/common/exceptions.h"
37
- #include "core/common/make_string.h"
38
- #include "core/common/status.h"
39
-
40
-
41
- namespace onnxruntime {
42
-
43
- using TimePoint = std::chrono::high_resolution_clock::time_point;
44
-
45
- #ifdef _WIN32
46
- #define ORT_UNUSED_PARAMETER(x) (x)
47
- #else
48
- #define ORT_UNUSED_PARAMETER(x) (void)(x)
49
- #endif
50
-
51
- #ifndef ORT_HAVE_ATTRIBUTE
52
- #ifdef __has_attribute
53
- #define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x)
54
- #else
55
- #define ORT_HAVE_ATTRIBUTE(x) 0
56
- #endif
57
- #endif
58
-
59
- // ORT_ATTRIBUTE_UNUSED
60
- //
61
- // Prevents the compiler from complaining about or optimizing away variables
62
- // that appear unused on Linux
63
- #if ORT_HAVE_ATTRIBUTE(unused) || (defined(__GNUC__) && !defined(__clang__))
64
- #undef ORT_ATTRIBUTE_UNUSED
65
- #define ORT_ATTRIBUTE_UNUSED __attribute__((__unused__))
66
- #else
67
- #define ORT_ATTRIBUTE_UNUSED
68
- #endif
69
-
70
- #ifdef ORT_NO_EXCEPTIONS
71
- // Print the given final message, the message must be a null terminated char*
72
- // ORT will abort after printing the message.
73
- // For Android, will print to Android system log
74
- // For other platforms, will print to stderr
75
- void PrintFinalMessage(const char* msg);
76
- #endif
77
-
78
- // macro to explicitly ignore the return value from a function call so Code Analysis doesn't complain
79
- #define ORT_IGNORE_RETURN_VALUE(fn) \
80
- static_cast<void>(fn)
81
-
82
- std::vector<std::string> GetStackTrace();
83
- // this is a helper function that gets defined by platform/Telemetry
84
- void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
85
- const char* function, uint32_t line);
86
-
87
- // __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER
88
- // so we only define it as one for MSVC
89
- #if (_MSC_VER && !defined(__PRETTY_FUNCTION__))
90
- #define __PRETTY_FUNCTION__ __FUNCTION__
91
- #endif
92
-
93
- // Capture where a message is coming from. Use __FUNCTION__ rather than the much longer __PRETTY_FUNCTION__
94
- #define ORT_WHERE ::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast<const char*>(__FUNCTION__))
95
-
96
- #define ORT_WHERE_WITH_STACK \
97
- ::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast<const char*>(__PRETTY_FUNCTION__), ::onnxruntime::GetStackTrace())
98
-
99
- #ifdef ORT_NO_EXCEPTIONS
100
-
101
- #define ORT_TRY if (true)
102
- #define ORT_CATCH(x) else if (false)
103
- #define ORT_RETHROW
104
-
105
- // In order to ignore the catch statement when a specific exception (not ... ) is caught and referred
106
- // in the body of the catch statements, it is necessary to wrap the body of the catch statement into
107
- // a lambda function. Otherwise the referenced exception will be undefined and cause a build break.
108
- #define ORT_HANDLE_EXCEPTION(func)
109
-
110
- // Throw an exception with optional message.
111
- // NOTE: The arguments get streamed into a string via ostringstream::operator<<
112
- // DO NOT use a printf format string, as that will not work as you expect.
113
- #define ORT_THROW(...) \
114
- do { \
115
- ::onnxruntime::PrintFinalMessage( \
116
- ::onnxruntime::OnnxRuntimeException( \
117
- ORT_WHERE_WITH_STACK, ::onnxruntime::MakeString(__VA_ARGS__)) \
118
- .what()); \
119
- abort(); \
120
- } while (false)
121
-
122
- // Just in order to mark things as not implemented. Do not use in final code.
123
- #define ORT_NOT_IMPLEMENTED(...) \
124
- do { \
125
- ::onnxruntime::PrintFinalMessage( \
126
- ::onnxruntime::NotImplementedException(::onnxruntime::MakeString(__VA_ARGS__)) \
127
- .what()); \
128
- abort(); \
129
- } while (false)
130
-
131
- // Check condition.
132
- // NOTE: The arguments get streamed into a string via ostringstream::operator<<
133
- // DO NOT use a printf format string, as that will not work as you expect.
134
- #define ORT_ENFORCE(condition, ...) \
135
- do { \
136
- if (!(condition)) { \
137
- ::onnxruntime::PrintFinalMessage( \
138
- ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, #condition, \
139
- ::onnxruntime::MakeString(__VA_ARGS__)) \
140
- .what()); \
141
- abort(); \
142
- } \
143
- } while (false)
144
-
145
- #define ORT_THROW_EX(ex, ...) \
146
- do { \
147
- ::onnxruntime::PrintFinalMessage( \
148
- ::onnxruntime::MakeString(#ex, "(", ::onnxruntime::MakeString(__VA_ARGS__), ")").c_str()); \
149
- abort(); \
150
- } while (false)
151
-
152
- #else
153
-
154
- #define ORT_TRY try
155
- #define ORT_CATCH(x) catch (x)
156
- #define ORT_RETHROW throw;
157
-
158
- #define ORT_HANDLE_EXCEPTION(func) func()
159
-
160
- // Throw an exception with optional message.
161
- // NOTE: The arguments get streamed into a string via ostringstream::operator<<
162
- // DO NOT use a printf format string, as that will not work as you expect.
163
- #define ORT_THROW(...) \
164
- throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, ::onnxruntime::MakeString(__VA_ARGS__))
165
-
166
- // Just in order to mark things as not implemented. Do not use in final code.
167
- #define ORT_NOT_IMPLEMENTED(...) \
168
- throw ::onnxruntime::NotImplementedException(::onnxruntime::MakeString(__VA_ARGS__))
169
-
170
- // Check condition.
171
- // NOTE: The arguments get streamed into a string via ostringstream::operator<<
172
- // DO NOT use a printf format string, as that will not work as you expect.
173
- #define ORT_ENFORCE(condition, ...) \
174
- do { \
175
- if (!(condition)) { \
176
- throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, #condition, \
177
- ::onnxruntime::MakeString(__VA_ARGS__)); \
178
- } \
179
- } while (false)
180
-
181
- #define ORT_THROW_EX(ex, ...) \
182
- throw ex(__VA_ARGS__)
183
-
184
- #endif
185
-
186
- #define ORT_MAKE_STATUS(category, code, ...) \
187
- ::onnxruntime::common::Status(::onnxruntime::common::category, \
188
- ::onnxruntime::common::code, \
189
- ::onnxruntime::MakeString(__VA_ARGS__))
190
-
191
- // Check condition. if met, return status.
192
- #define ORT_RETURN_IF(condition, ...) \
193
- do { \
194
- if (condition) { \
195
- return ::onnxruntime::common::Status(::onnxruntime::common::ONNXRUNTIME, \
196
- ::onnxruntime::common::FAIL, \
197
- ::onnxruntime::MakeString(ORT_WHERE.ToString(), " ", __VA_ARGS__)); \
198
- } \
199
- } while (false)
200
-
201
- // Check condition. if not met, return status.
202
- #define ORT_RETURN_IF_NOT(condition, ...) \
203
- ORT_RETURN_IF(!(condition), __VA_ARGS__)
204
-
205
- // Macros to disable the copy and/or move ctor and assignment methods
206
- // These are usually placed in the private: declarations for a class.
207
-
208
- #define ORT_DISALLOW_COPY(TypeName) TypeName(const TypeName&) = delete
209
-
210
- #define ORT_DISALLOW_ASSIGNMENT(TypeName) TypeName& operator=(const TypeName&) = delete
211
-
212
- #define ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName) \
213
- ORT_DISALLOW_COPY(TypeName); \
214
- ORT_DISALLOW_ASSIGNMENT(TypeName)
215
-
216
- #define ORT_DISALLOW_MOVE(TypeName) \
217
- TypeName(TypeName&&) = delete; \
218
- TypeName& operator=(TypeName&&) = delete
219
-
220
- #define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName) \
221
- ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName); \
222
- ORT_DISALLOW_MOVE(TypeName)
223
-
224
- #define ORT_RETURN_IF_ERROR_SESSIONID(expr, session_id) \
225
- do { \
226
- auto _status = (expr); \
227
- if ((!_status.IsOK())) { \
228
- ::onnxruntime::LogRuntimeError(session_id, _status, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__); \
229
- return _status; \
230
- } \
231
- } while (0)
232
-
233
- #define ORT_RETURN_IF_ERROR_SESSIONID_(expr) ORT_RETURN_IF_ERROR_SESSIONID(expr, session_id_)
234
- #define ORT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR_SESSIONID(expr, 0)
235
-
236
- #define ORT_THROW_IF_ERROR(expr) \
237
- do { \
238
- auto _status = (expr); \
239
- if ((!_status.IsOK())) { \
240
- ::onnxruntime::LogRuntimeError(0, _status, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__); \
241
- ORT_THROW(_status); \
242
- } \
243
- } while (0)
244
-
245
- // Use this macro when an early return is not possible.
246
- #define ORT_CHECK_AND_SET_RETVAL(expr) \
247
- do { \
248
- if (retval.IsOK()) { \
249
- retval = (expr); \
250
- } \
251
- } while (0)
252
-
253
- inline long long TimeDiffMicroSeconds(TimePoint start_time) {
254
- auto end_time = std::chrono::high_resolution_clock::now();
255
- return std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
256
- }
257
-
258
- inline long long TimeDiffMicroSeconds(TimePoint start_time, TimePoint end_time) {
259
- return std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
260
- }
261
-
262
- struct null_type {};
263
- inline std::string ToUTF8String(const std::string& s) { return s; }
264
- #ifdef _WIN32
265
- /**
266
- * Convert a wide character string to a UTF-8 string
267
- */
268
- std::string ToUTF8String(const std::wstring& s);
269
-
270
- std::wstring ToWideString(const std::string& s);
271
- inline std::wstring ToWideString(const std::wstring& s) { return s; }
272
- #else
273
- inline std::string ToWideString(const std::string& s) { return s; }
274
- #endif
275
-
276
- constexpr size_t kMaxStrLen = 2048;
277
-
278
- // Returns whether `key` is in `container`.
279
- // Like C++20's map/set contains() member function.
280
- template <typename Key, typename... OtherContainerArgs,
281
- template <typename...> typename AssociativeContainer,
282
- typename LookupKey>
283
- inline bool Contains(const AssociativeContainer<Key, OtherContainerArgs...>& container, LookupKey&& key) {
284
- return container.find(std::forward<LookupKey>(key)) != container.end();
285
- }
286
-
287
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/const_pointer_container.h DELETED
@@ -1,85 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <type_traits>
7
-
8
- namespace onnxruntime {
9
- /**
10
- Container has T* entries. e.g. std::vector<T*>, and this class provides const access to those
11
- via iterators and direct access, as the standard behavior only makes the pointer constant,
12
- and not what is pointed to, i.e. you get a const pointer to T, not a pointer to const T, without this wrapper.
13
- See https://stackoverflow.com/questions/8017036/understanding-const-iterator-with-pointers
14
- */
15
- template <typename Container>
16
- class ConstPointerContainer {
17
- public:
18
- using T = typename std::remove_pointer<typename Container::value_type>::type;
19
-
20
- class ConstIterator {
21
- public:
22
- using const_iterator = typename Container::const_iterator;
23
- using iterator_category = std::input_iterator_tag;
24
- using value_type = T*;
25
- using difference_type = std::ptrdiff_t;
26
- using pointer = T**;
27
- using reference = T*&;
28
-
29
- /** Construct iterator for container that will return const T* entries.*/
30
- explicit ConstIterator(const_iterator position) noexcept : current_{position}, item_{nullptr} {}
31
- ConstIterator(const ConstIterator& other) = default;
32
- ConstIterator& operator=(const ConstIterator& other) = default;
33
-
34
- bool operator==(const ConstIterator& other) const noexcept { return current_ == other.current_; }
35
- bool operator!=(const ConstIterator& other) const noexcept { return current_ != other.current_; }
36
-
37
- ConstIterator& operator++() {
38
- ++current_;
39
- return *this;
40
- }
41
-
42
- ConstIterator operator++(int) {
43
- ConstIterator tmp{*this};
44
- ++(*this);
45
- return tmp;
46
- }
47
-
48
- const T*& operator*() const {
49
- item_ = *current_;
50
- return item_;
51
- }
52
-
53
- const T** operator->() const { return &(operator*()); };
54
-
55
- private:
56
- const_iterator current_;
57
- mutable const T* item_;
58
- };
59
-
60
- /**
61
- Construct wrapper class that will provide const access to the pointers in a container of non-const pointers.
62
- @param data Container with non-const pointers. e.g. std::vector<T*>
63
- */
64
- explicit ConstPointerContainer(const Container& data) noexcept : data_(data) {}
65
-
66
- size_t size() const noexcept { return data_.size(); }
67
- bool empty() const noexcept { return data_.empty(); }
68
-
69
- ConstIterator cbegin() const noexcept { return ConstIterator(data_.cbegin()); }
70
- ConstIterator cend() const noexcept { return ConstIterator(data_.cend()); }
71
-
72
- ConstIterator begin() const noexcept { return ConstIterator(data_.cbegin()); }
73
- ConstIterator end() const noexcept { return ConstIterator(data_.cend()); }
74
-
75
- const T* operator[](size_t index) const { return data_[index]; }
76
-
77
- const T* at(size_t index) const {
78
- ORT_ENFORCE(index < data_.size());
79
- return data_[index];
80
- }
81
-
82
- private:
83
- const Container& data_;
84
- };
85
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/denormal.h DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- namespace onnxruntime {
7
-
8
- // Set or unset flush-to-zero and denormal-as-zero if SSE3 instructions are supported.
9
- // Return true if SSE3 instruction is supported, otherwise return false.
10
- bool SetDenormalAsZero(bool on);
11
-
12
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/eigen_common_wrapper.h DELETED
@@ -1,62 +0,0 @@
1
- //-----------------------------------------------------------------------------
2
- //
3
- // Copyright (c) Microsoft Corporation. All rights reserved.
4
- // Licensed under the MIT License.
5
- //
6
- //-----------------------------------------------------------------------------
7
- #pragma once
8
- #include "onnxruntime_config.h"
9
- // build/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h:162:71:
10
- // error: ignoring attributes on template argument "Eigen::PacketType<const float, Eigen::DefaultDevice>::type {aka __vector(4) float}" [-Werror=ignored-attributes]
11
- #if defined(__GNUC__)
12
- #pragma GCC diagnostic push
13
- #if __GNUC__ >= 6
14
- #pragma GCC diagnostic ignored "-Wignored-attributes"
15
- #endif
16
- #pragma GCC diagnostic ignored "-Wunused-parameter"
17
- #pragma GCC diagnostic ignored "-Wunused-result"
18
- #ifdef HAS_DEPRECATED_COPY
19
- #pragma GCC diagnostic ignored "-Wdeprecated-copy"
20
- #endif
21
- // cmake/external/eigen/unsupported/Eigen/CXX11/../../../Eigen/src/Core/arch/NEON/PacketMath.h:1633:9:
22
- // error: ‘void* memcpy(void*, const void*, size_t)’ copying an object of non-trivial type ‘Eigen::internal::Packet4c’
23
- // {aka ‘struct Eigen::internal::eigen_packet_wrapper<int, 2>’} from an array of ‘const int8_t’
24
- // {aka ‘const signed char’} [-Werror=class-memaccess]
25
- #ifdef HAS_CLASS_MEMACCESS
26
- #pragma GCC diagnostic ignored "-Wclass-memaccess"
27
- #endif
28
-
29
- // cmake/external/eigen\Eigen/src/Core/util/Meta.h:454:25:
30
- // error: 'result_of<Eigen::internal::scalar_product_op<unsigned long long> (const unsigned long long &, const unsigned long long &)>'
31
- // is deprecated [-Werror,-Wdeprecated-declarations]
32
- // typedef typename std::result_of<T>::type type1;
33
- #ifdef HAS_DEPRECATED_DECLARATIONS
34
- #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
35
- #endif
36
-
37
- // cmake/external/eigen\Eigen/CXX11/src/Tensor/TensorTrace.h:130:9:
38
- // error: variable 'num_distinct_reduce_dims' set but not used [-Werror,-Wunused-but-set-variable]
39
- // int num_distinct_reduce_dims = 0;
40
- #ifdef HAS_UNUSED_BUT_SET_VARIABLE
41
- #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
42
- #endif
43
-
44
- #elif defined(_MSC_VER)
45
- // build\windows\debug\external\eigen3\unsupported\eigen\cxx11\src/Tensor/Tensor.h(76):
46
- // warning C4554: '&': check operator precedence for possible error; use parentheses to clarify precedence
47
-
48
- // unsupported\eigen\cxx11\src\Tensor\TensorUInt128.h(150,0): Warning C4245: 'initializing': conversion from '__int64'
49
- // to 'uint64_t', signed/unsigned mismatch
50
- #pragma warning(push)
51
- #pragma warning(disable : 4554)
52
- #pragma warning(disable : 4245)
53
- #pragma warning(disable : 4127)
54
- #endif
55
-
56
- #include "unsupported/Eigen/CXX11/Tensor"
57
-
58
- #if defined(__GNUC__)
59
- #pragma GCC diagnostic pop
60
- #elif defined(_MSC_VER)
61
- #pragma warning(pop)
62
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/exceptions.h DELETED
@@ -1,71 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <algorithm>
7
- #include <exception>
8
- #include <iterator>
9
- #include <stdexcept>
10
- #include <string>
11
- #include <vector>
12
-
13
- #include "core/common/common.h"
14
- #include "core/common/code_location.h"
15
-
16
- namespace onnxruntime {
17
-
18
- class NotImplementedException : public std::logic_error {
19
- public:
20
- explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){};
21
- explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){};
22
- };
23
-
24
- class TypeMismatchException : public std::logic_error {
25
- public:
26
- TypeMismatchException() noexcept : logic_error("Type mismatch"){};
27
- };
28
-
29
- class OnnxRuntimeException : public std::exception {
30
- public:
31
- OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept
32
- : OnnxRuntimeException(location, nullptr, msg) {
33
- }
34
-
35
- /**
36
- Create a new exception that captures the location it was thrown from.
37
- @param location Location in the source code the exception is being thrown from
38
- @param failed_condition Optional string containing the condition that failed.
39
- e.g. "tensor.Size() == input.Size()". May be nullptr.
40
- @param msg Message containing additional information about the exception cause.
41
- */
42
- OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg)
43
- : location_{location} {
44
- std::ostringstream ss;
45
-
46
- ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous
47
- if (failed_condition != nullptr) {
48
- ss << " " << failed_condition << " was false.";
49
- }
50
-
51
- ss << " " << msg << "\n";
52
- if (!location.stacktrace.empty()) {
53
- ss << "Stacktrace:\n";
54
- // skip the first entry in the stacktrace as we have that information from location.ToString()
55
- std::copy(std::next(location.stacktrace.begin()), location.stacktrace.end(), std::ostream_iterator<std::string>(ss, "\n"));
56
- }
57
-
58
- what_ = ss.str();
59
- }
60
-
61
- const char* what() const noexcept override {
62
- return what_.c_str();
63
- }
64
-
65
- private:
66
- const CodeLocation location_;
67
- const std::vector<std::string> stacktrace_;
68
- std::string what_;
69
- };
70
-
71
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/gpu_profiler_common.h DELETED
@@ -1,472 +0,0 @@
1
- #pragma once
2
-
3
- #include "core/common/profiler_common.h"
4
- #include "core/common/inlined_containers.h"
5
-
6
- #include <map>
7
- #include <memory>
8
- #include <mutex>
9
- #include <sstream>
10
- #include <string>
11
- #include <vector>
12
- #include <utility>
13
-
14
- namespace onnxruntime {
15
- namespace profiling {
16
-
17
- // The classes in this header are implemented as template/inline classes
18
- // to avoid having to export symbols from the main onnxruntime shared library
19
- // to ExecutionProvider (EP) shared libraries.
20
- // More context: The main onnxruntime shared library is optimized for size
21
- // using --gc-sections during link time to ensure that any unreferenced code
22
- // is not retained. This poses a problem in using a design pattern where the
23
- // (abstract) base class is implemented in the main onnxruntime shared library,
24
- // but (concrete) subclasses are implemented in EP shared libraries. Now, because
25
- // EP shared libraries are loaded at runtime (as of 11/2022), there will be no
26
- // references to the base class symbols when the main onnxruntime shared library
27
- // is compiled. Thus, the base class symbols will not be included in the
28
- // main onnxruntime shared library. This manifests in being unable to load
29
- // EP shared libs (because the base class symbols referenced by derived
30
- // classes are missing).
31
- // We solve this by implementing base classes that are common to all GPU profilers
32
- // inline in this header.
33
-
34
// Owning, copyable byte buffer holding raw profiler activity records.
//
// Copying performs a deep copy of the bytes; moving transfers ownership and
// leaves the source empty (null data, size 0).
class ProfilerActivityBuffer {
 public:
  // Creates an empty buffer (null data, size 0).
  ProfilerActivityBuffer() noexcept
      : data_(nullptr), size_(0) {}

  // Allocates `size` bytes and copies them from `data`.
  // The copy is skipped when there is nothing to copy, because calling memcpy
  // with a null source is undefined even for a zero length.
  ProfilerActivityBuffer(const char* data, size_t size) noexcept
      : data_(std::make_unique<char[]>(size)), size_(size) {
    if (data != nullptr && size_ != 0) {
      memcpy(data_.get(), data, size_);
    }
  }

  // Deep copy.
  ProfilerActivityBuffer(const ProfilerActivityBuffer& other) noexcept
      : ProfilerActivityBuffer(other.GetData(), other.GetSize()) {}

  // Takes over other's storage; other is left empty.
  ProfilerActivityBuffer(ProfilerActivityBuffer&& other) noexcept
      : data_(std::move(other.data_)), size_(other.size_) {
    other.size_ = 0;
  }

  // Deep-copy assignment. The copy is built first and then moved in, so the
  // previously owned storage is released by unique_ptr's move assignment
  // (constructing a new object on top of *this would leak it).
  ProfilerActivityBuffer& operator=(const ProfilerActivityBuffer& other) noexcept {
    if (&other == this) {
      return *this;
    }

    ProfilerActivityBuffer copy{other};
    data_ = std::move(copy.data_);
    size_ = copy.size_;
    return *this;
  }

  // Move assignment: releases the current storage, adopts other's, and leaves other empty.
  ProfilerActivityBuffer& operator=(ProfilerActivityBuffer&& other) noexcept {
    if (&other == this) {
      return *this;
    }

    data_ = std::move(other.data_);
    size_ = other.size_;
    other.size_ = 0;
    return *this;
  }

  // Wraps an already-allocated buffer without copying it.
  // `size` is stored as given; it is not checked against the real allocation.
  static ProfilerActivityBuffer CreateFromPreallocatedBuffer(std::unique_ptr<char[]>&& buffer_ptr, size_t size) {
    ProfilerActivityBuffer res{};
    res.data_ = std::move(buffer_ptr);
    res.size_ = size;
    return res;
  }

  // accessors
  char* GetData() { return data_.get(); }
  const char* GetData() const { return data_.get(); }
  size_t GetSize() const { return size_; }

 private:
  std::unique_ptr<char[]> data_;
  size_t size_;
}; /* end class ProfilerActivityBuffer */
87
-
88
- template <typename TDerived>
89
- class GPUTracerManager {
90
- public:
91
- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GPUTracerManager);
92
- virtual ~GPUTracerManager() {}
93
-
94
- uint64_t RegisterClient() {
95
- std::lock_guard<std::mutex> lock(manager_instance_mutex_);
96
- auto res = next_client_id_++;
97
- per_client_events_by_ext_correlation_.insert({res, {}});
98
- ++num_active_clients_;
99
- return res;
100
- }
101
-
102
- void DeregisterClient(uint64_t client_handle) {
103
- std::lock_guard<std::mutex> lock(manager_instance_mutex_);
104
- auto it = per_client_events_by_ext_correlation_.find(client_handle);
105
- if (it == per_client_events_by_ext_correlation_.end()) {
106
- return;
107
- }
108
- per_client_events_by_ext_correlation_.erase(it);
109
- --num_active_clients_;
110
- if (num_active_clients_ == 0 && tracing_enabled_) {
111
- StopLogging();
112
- }
113
- }
114
-
115
- void StartLogging() {
116
- std::lock_guard<std::mutex> lock(manager_instance_mutex_);
117
- if (tracing_enabled_) {
118
- return;
119
- }
120
-
121
- auto this_as_derived = static_cast<TDerived*>(this);
122
- tracing_enabled_ = this_as_derived->OnStartLogging();
123
- }
124
-
125
- void Consume(uint64_t client_handle, const TimePoint& start_time, std::map<uint64_t, Events>& events) {
126
- auto this_as_derived = static_cast<TDerived*>(this);
127
- events.clear();
128
- {
129
- // Flush any pending activity records before starting
130
- // to process the accumulated activity records.
131
- std::lock_guard<std::mutex> lock_manager(manager_instance_mutex_);
132
- if (!tracing_enabled_) {
133
- return;
134
- }
135
-
136
- this_as_derived->FlushActivities();
137
- }
138
-
139
- std::vector<ProfilerActivityBuffer> activity_buffers;
140
- {
141
- std::lock_guard<std::mutex> lock(unprocessed_activity_buffers_mutex_);
142
- std::swap(unprocessed_activity_buffers_, activity_buffers);
143
- unprocessed_activity_buffers_.clear();
144
- }
145
-
146
- {
147
- // Ensure that at most one thread is working through the activity buffers at any time.
148
- std::lock_guard<std::mutex> lock_two(activity_buffer_processor_mutex_);
149
- this_as_derived->ProcessActivityBuffers(activity_buffers, start_time);
150
- auto it = per_client_events_by_ext_correlation_.find(client_handle);
151
- if (it == per_client_events_by_ext_correlation_.end()) {
152
- return;
153
- }
154
- std::swap(events, it->second);
155
- }
156
- }
157
-
158
- void PushCorrelation(uint64_t client_handle,
159
- uint64_t external_correlation_id,
160
- TimePoint profiling_start_time) {
161
- auto this_as_derived = static_cast<TDerived*>(this);
162
- std::lock_guard<std::mutex> lock(manager_instance_mutex_);
163
- if (!tracing_enabled_) {
164
- return;
165
- }
166
-
167
- auto it = per_client_events_by_ext_correlation_.find(client_handle);
168
- if (it == per_client_events_by_ext_correlation_.end()) {
169
- // not a registered client, do nothing
170
- return;
171
- }
172
-
173
- // external_correlation_id is simply the timestamp of this event,
174
- // relative to profiling_start_time. i.e., it was computed as:
175
- // external_correlation_id =
176
- // std::chrono::duration_cast<std::chrono::microseconds>(event_start_time - profiling_start_time).count()
177
- //
178
- // Because of the relative nature of the external_correlation_id, the same
179
- // external_correlation_id can be reused across different clients, which then makes it
180
- // impossible to recover the client from the external_correlation_id, which in turn
181
- // makes it impossible to map events (which are tagged with external_correlation_id) to clients.
182
- //
183
- // To address these difficulties, we construct a new correlation_id (let's call it unique_cid)
184
- // as follows:
185
- // unique_cid =
186
- // external_correlation_id +
187
- // std::chrono::duration_cast<std::chrono::microseconds>(profiling_start_time.time_since_epoch()).count()
188
- // now, unique_cid is monotonically increasing with time, so it can be used to reliably map events to clients.
189
- //
190
- // Of course, clients expect lists of events to be returned (on a call to Consume()), that are
191
- // still keyed on the external_correlation_id that they've specified here, so we need to remember the
192
- // offset to be subtracted
193
- uint64_t offset = std::chrono::duration_cast<std::chrono::microseconds>(profiling_start_time.time_since_epoch()).count();
194
- auto unique_cid = external_correlation_id + offset;
195
- unique_correlation_id_to_client_offset_[unique_cid] = std::make_pair(client_handle, offset);
196
- this_as_derived->PushUniqueCorrelation(unique_cid);
197
- }
198
-
199
- void PopCorrelation(uint64_t& popped_external_correlation_id) {
200
- auto this_as_derived = static_cast<TDerived*>(this);
201
- std::lock_guard<std::mutex> lock(manager_instance_mutex_);
202
- if (!tracing_enabled_) {
203
- return;
204
- }
205
- uint64_t unique_cid;
206
- this_as_derived->PopUniqueCorrelation(unique_cid);
207
- // lookup the offset and subtract it before returning popped_external_correlation_id to the client
208
- auto client_it = unique_correlation_id_to_client_offset_.find(unique_cid);
209
- if (client_it == unique_correlation_id_to_client_offset_.end()) {
210
- popped_external_correlation_id = 0;
211
- return;
212
- }
213
- popped_external_correlation_id = unique_cid - client_it->second.second;
214
- }
215
-
216
- void PopCorrelation() {
217
- uint64_t unused;
218
- PopCorrelation(unused);
219
- }
220
-
221
- protected:
222
- GPUTracerManager() {
223
- auto this_as_derived = static_cast<TDerived*>(this);
224
- uint64_t gpu_ts1, gpu_ts2, cpu_ts;
225
-
226
- // Get the CPU and GPU timestamps to warm up
227
- gpu_ts1 = this_as_derived->GetGPUTimestampInNanoseconds();
228
- cpu_ts = this->GetCPUTimestampInNanoseconds();
229
-
230
- // Estimate the skew/offset between the CPU and GPU timestamps.
231
- gpu_ts1 = this_as_derived->GetGPUTimestampInNanoseconds();
232
- cpu_ts = this->GetCPUTimestampInNanoseconds();
233
- gpu_ts2 = this_as_derived->GetGPUTimestampInNanoseconds();
234
-
235
- auto gpu_ts = (gpu_ts1 + gpu_ts2) / 2;
236
- offset_to_add_to_gpu_timestamps_ = cpu_ts - gpu_ts;
237
- }
238
-
239
- #if 0
240
- // Functional API to be implemented by subclasses
241
- // Included here only for documentation purposes
242
- protected:
243
- bool OnStartLogging();
244
- void OnStopLogging();
245
- void ProcessActivityBuffers(const std::vector<ProfilerActivityBuffer>& buffers,
246
- const TimePoint& start_time);
247
- bool PushUniqueCorrelation(uint64_t unique_cid);
248
- void PopUniqueCorrelation(uint64_t& popped_unique_cid);
249
- void FlushActivities();
250
- uint64_t GetGPUTimestampInNanoseconds();
251
- #endif
252
-
253
- void EnqueueActivityBuffer(ProfilerActivityBuffer&& buffer) {
254
- std::lock_guard<std::mutex> lock(unprocessed_activity_buffers_mutex_);
255
- unprocessed_activity_buffers_.emplace_back(std::move(buffer));
256
- }
257
-
258
- // To be called by subclasses only from ProcessActivityBuffers
259
- void MapEventToClient(uint64_t tracer_correlation_id, EventRecord&& event) {
260
- auto it = tracer_correlation_to_unique_correlation_.find(tracer_correlation_id);
261
- if (it == tracer_correlation_to_unique_correlation_.end()) {
262
- // We're yet to receive a mapping to unique_correlation_id for this tracer_correlation_id
263
- DeferEventMapping(std::move(event), tracer_correlation_id);
264
- return;
265
- }
266
- auto unique_correlation_id = it->second;
267
- auto p_event_list = GetEventListForUniqueCorrelationId(unique_correlation_id);
268
- if (p_event_list != nullptr) {
269
- p_event_list->emplace_back(std::move(event));
270
- }
271
- }
272
-
273
- // To be called by subclasses only from ProcessActivityBuffers
274
- void NotifyNewCorrelation(uint64_t tracer_correlation_id, uint64_t unique_correlation_id) {
275
- tracer_correlation_to_unique_correlation_[tracer_correlation_id] = unique_correlation_id;
276
- auto pending_it = events_pending_client_mapping_.find(tracer_correlation_id);
277
- if (pending_it == events_pending_client_mapping_.end()) {
278
- return;
279
- }
280
- // Map the pending events to the right client
281
- MapEventsToClient(unique_correlation_id, std::move(pending_it->second));
282
- events_pending_client_mapping_.erase(pending_it);
283
- }
284
-
285
- uint64_t NormalizeGPUTimestampToCPUEpoch(uint64_t gpu_timestamp_in_nanoseconds) {
286
- return gpu_timestamp_in_nanoseconds + this->offset_to_add_to_gpu_timestamps_;
287
- }
288
-
289
- private:
290
- // Requires: manager_instance_mutex_ should be held
291
- void StopLogging() {
292
- auto this_as_derived = static_cast<TDerived*>(this);
293
- if (!tracing_enabled_) {
294
- return;
295
- }
296
- this_as_derived->OnStopLogging();
297
- tracing_enabled_ = false;
298
- Clear();
299
- }
300
-
301
- // Requires: manager_instance_mutex_ should be held
302
- void Clear() {
303
- unprocessed_activity_buffers_.clear();
304
- unique_correlation_id_to_client_offset_.clear();
305
- per_client_events_by_ext_correlation_.clear();
306
- tracer_correlation_to_unique_correlation_.clear();
307
- events_pending_client_mapping_.clear();
308
- }
309
-
310
- Events* GetEventListForUniqueCorrelationId(uint64_t unique_correlation_id) {
311
- auto client_it = unique_correlation_id_to_client_offset_.find(unique_correlation_id);
312
- if (client_it == unique_correlation_id_to_client_offset_.end()) {
313
- return nullptr;
314
- }
315
-
316
- // See the comments on the GetUniqueCorrelationId method for an explanation of
317
- // of this offset computation and why it's required.
318
- auto const& client_handle_offset = client_it->second;
319
- auto external_correlation = unique_correlation_id - client_handle_offset.second;
320
- auto& event_list = per_client_events_by_ext_correlation_[client_handle_offset.first][external_correlation];
321
- return &event_list;
322
- }
323
-
324
- void MapEventsToClient(uint64_t unique_correlation_id, std::vector<EventRecord>&& events) {
325
- auto p_event_list = GetEventListForUniqueCorrelationId(unique_correlation_id);
326
- if (p_event_list != nullptr) {
327
- p_event_list->insert(p_event_list->end(),
328
- std::make_move_iterator(events.begin()),
329
- std::make_move_iterator(events.end()));
330
- }
331
- }
332
-
333
- void DeferEventMapping(EventRecord&& event, uint64_t tracer_correlation_id) {
334
- events_pending_client_mapping_[tracer_correlation_id].emplace_back(std::move(event));
335
- }
336
-
337
- uint64_t GetCPUTimestampInNanoseconds() {
338
- return std::chrono::duration_cast<std::chrono::nanoseconds>(
339
- std::chrono::high_resolution_clock::now().time_since_epoch())
340
- .count();
341
- }
342
-
343
- std::mutex manager_instance_mutex_;
344
- uint64_t next_client_id_ = 1;
345
- uint64_t num_active_clients_ = 0;
346
- bool tracing_enabled_ = false;
347
- std::mutex unprocessed_activity_buffers_mutex_;
348
- std::mutex activity_buffer_processor_mutex_;
349
-
350
- // Unprocessed activity buffers
351
- std::vector<ProfilerActivityBuffer> unprocessed_activity_buffers_;
352
-
353
- // Keyed on unique_correlation_id -> (client_id/client_handle, offset)
354
- // unique_correlation_id - offset == external_correlation_id
355
- InlinedHashMap<uint64_t, std::pair<uint64_t, uint64_t>> unique_correlation_id_to_client_offset_;
356
-
357
- // Keyed on tracer_correlation_id -> unique_correlation_id
358
- InlinedHashMap<uint64_t, uint64_t> tracer_correlation_to_unique_correlation_;
359
-
360
- // client_id/client_handle -> external_correlation_id -> events
361
- InlinedHashMap<uint64_t, std::map<uint64_t, Events>> per_client_events_by_ext_correlation_;
362
-
363
- // Keyed on tracer correlation_id, keeps track of activity records
364
- // for which we haven't established the external_correlation_id yet.
365
- InlinedHashMap<uint64_t, std::vector<EventRecord>> events_pending_client_mapping_;
366
-
367
- // An offset to add to (the possibly skewed) GPU timestamps
368
- // to normalize GPU timestamps with CPU timestamps
369
- int64_t offset_to_add_to_gpu_timestamps_;
370
- }; /* class GPUTracerManager */
371
-
372
- // Base class for a GPU profiler
373
- template <typename TManager>
374
- class GPUProfilerBase : public EpProfiler {
375
- protected:
376
- GPUProfilerBase() = default;
377
- virtual ~GPUProfilerBase() {}
378
-
379
- void MergeEvents(std::map<uint64_t, Events>& events_to_merge, Events& events) {
380
- Events merged_events;
381
-
382
- auto event_iter = std::make_move_iterator(events.begin());
383
- auto event_end = std::make_move_iterator(events.end());
384
- for (auto& map_iter : events_to_merge) {
385
- if (map_iter.second.empty()) {
386
- continue;
387
- }
388
-
389
- auto ts = static_cast<long long>(map_iter.first);
390
-
391
- // find the last occurrence of a matching timestamp,
392
- // if one exists
393
- while (event_iter != event_end &&
394
- (event_iter->ts < ts ||
395
- (event_iter->ts == ts &&
396
- (event_iter + 1) != event_end &&
397
- (event_iter + 1)->ts == ts))) {
398
- merged_events.emplace_back(*event_iter);
399
- ++event_iter;
400
- }
401
-
402
- bool copy_op_names = false;
403
- std::string op_name;
404
- std::string parent_name;
405
-
406
- if (event_iter != event_end && event_iter->ts == ts) {
407
- // We've located a parent event, copy the op_name and set
408
- // this event's parent_name property to the name of the parent.
409
- copy_op_names = true;
410
- op_name = event_iter->args["op_name"];
411
- parent_name = event_iter->name;
412
- merged_events.emplace_back(*event_iter);
413
- ++event_iter;
414
- }
415
-
416
- for (auto& evt : map_iter.second) {
417
- if (copy_op_names) {
418
- // If we have found a matching parent event,
419
- // then inherit some names from the parent.
420
- evt.args["op_name"] = op_name;
421
- evt.args["parent_name"] = parent_name;
422
- }
423
- }
424
-
425
- merged_events.insert(merged_events.end(),
426
- std::make_move_iterator(map_iter.second.begin()),
427
- std::make_move_iterator(map_iter.second.end()));
428
- }
429
-
430
- // move any remaining events
431
- merged_events.insert(merged_events.end(), event_iter, event_end);
432
- std::swap(events, merged_events);
433
- }
434
-
435
- uint64_t client_handle_;
436
- TimePoint profiling_start_time_;
437
-
438
- public:
439
- virtual bool StartProfiling(TimePoint profiling_start_time) override {
440
- auto& manager = TManager::GetInstance();
441
- manager.StartLogging();
442
- profiling_start_time_ = profiling_start_time;
443
- return true;
444
- }
445
-
446
- virtual void EndProfiling(TimePoint start_time, Events& events) override {
447
- auto& manager = TManager::GetInstance();
448
- std::map<uint64_t, Events> event_map;
449
- manager.Consume(client_handle_, start_time, event_map);
450
- MergeEvents(event_map, events);
451
- }
452
-
453
- virtual void Start(uint64_t id) override {
454
- auto& manager = TManager::GetInstance();
455
- manager.PushCorrelation(client_handle_, id, profiling_start_time_);
456
- }
457
-
458
- virtual void Stop(uint64_t) override {
459
- auto& manager = TManager::GetInstance();
460
- manager.PopCorrelation();
461
- }
462
- }; /* class GPUProfilerBase */
463
-
464
// Formats a pointer by streaming it into a string stream whose integer base
// flag is set to hexadecimal, and returns the resulting text.
static inline std::string PointerToHexString(const void* ptr) {
  std::ostringstream formatter;
  formatter.setf(std::ios_base::hex, std::ios_base::basefield);
  formatter << ptr;
  return formatter.str();
}
470
-
471
- } /* end namespace profiling */
472
- } /* end namespace onnxruntime */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/gsl.h DELETED
@@ -1,6 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include "gsl/gsl"
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/hash_combine.h DELETED
@@ -1,21 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- namespace onnxruntime {
7
-
8
// Folds the already-computed hash value `h` into `seed`; `seed` is modified in place.
// TODO(edgchen1) find a better implementation? e.g., see a more recent version of boost::hash_combine()
inline void HashCombineWithHashValue(size_t h, size_t& seed) {
  const size_t shifted_left = seed << 6;
  const size_t shifted_right = seed >> 2;
  const size_t mixed = h + 0x9e3779b9 + shifted_left + shifted_right;
  seed ^= mixed;
}
13
-
14
- // Combine hash value `seed` with the hash value of `value`, updating `seed` in place.
15
- // The hash value computation is specified by the `Hash` template parameter.
16
- template <typename T, typename Hash = std::hash<T>>
17
- inline void HashCombine(const T& value, size_t& seed) {
18
- HashCombineWithHashValue(Hash{}(value), seed);
19
- }
20
-
21
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/inlined_containers.h DELETED
@@ -1,175 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <cmath>
7
-
8
- #include "core/common/inlined_containers_fwd.h"
9
-
10
- #ifndef DISABLE_ABSEIL
11
-
12
- #ifdef _MSC_VER
13
- #pragma warning(push)
14
- // C4127: conditional expression is constant
15
- #pragma warning(disable : 4127)
16
- // C4324: structure was padded due to alignment specifier
17
- // Usage of alignas causes some internal padding in places.
18
- #pragma warning(disable : 4324)
19
- #endif // _MSC_VER
20
-
21
- #include <absl/container/flat_hash_set.h>
22
- #include <absl/container/flat_hash_map.h>
23
-
24
- #include <absl/container/node_hash_set.h>
25
- #include <absl/container/node_hash_map.h>
26
-
27
- #ifdef _MSC_VER
28
- #pragma warning(pop)
29
- #endif // _MSC_VER
30
-
31
- #else // DISABLE_ABSEIL
32
-
33
- #include <unordered_set>
34
- #include <unordered_map>
35
- #include <set>
36
- #include <map>
37
-
38
- #endif // DISABLE_ABSEIL
39
-
40
- namespace onnxruntime {
41
-
42
- #ifndef DISABLE_ABSEIL
43
- // InlinedHashSet and InlinedHashMap are preferred
44
- // hash based containers. They store their values in the
45
- // buckets array that is allocated in one shot. It eliminates
46
- // per-node new/delete calls. Always call reserve() on any hash set/map
47
- // when the number of items is known in advance.
48
- // This does not allocate a dummy 'end' node on default construction.
49
- template <typename T, typename Allocator>
50
- class InlinedHashSet : public absl::flat_hash_set<T,
51
- absl::container_internal::hash_default_hash<T>,
52
- absl::container_internal::hash_default_eq<T>,
53
- Allocator> {
54
- using Base = absl::flat_hash_set<T,
55
- absl::container_internal::hash_default_hash<T>,
56
- absl::container_internal::hash_default_eq<T>,
57
- Allocator>;
58
-
59
- public:
60
- using Base::Base;
61
- };
62
-
63
- template <typename Key, typename Value,
64
- typename Allocator>
65
- class InlinedHashMap : public absl::flat_hash_map<Key, Value,
66
- absl::container_internal::hash_default_hash<Key>,
67
- absl::container_internal::hash_default_eq<Key>,
68
- Allocator> {
69
- using Base = absl::flat_hash_map<Key, Value,
70
- absl::container_internal::hash_default_hash<Key>,
71
- absl::container_internal::hash_default_eq<Key>,
72
- Allocator>;
73
-
74
- public:
75
- using Base::Base;
76
- };
77
-
78
- // Use this hash set/map where pointer stability is required, otherwise use
79
- // InlinedHashSet and InlinedHashMap
80
- // This does not allocate a dummy 'end' node on default construction.
81
- // Use reserve() when the number of elements is known.
82
- template <typename T, typename Allocator>
83
- class NodeHashSet : public absl::node_hash_set<T,
84
- absl::container_internal::hash_default_hash<T>,
85
- absl::container_internal::hash_default_eq<T>,
86
- Allocator> {
87
- using Base = absl::node_hash_set<T,
88
- absl::container_internal::hash_default_hash<T>,
89
- absl::container_internal::hash_default_eq<T>,
90
- Allocator>;
91
-
92
- public:
93
- using Base::Base;
94
- };
95
-
96
- template <typename Key, typename Value, typename Allocator>
97
- class NodeHashMap : public absl::node_hash_map<Key, Value,
98
- absl::container_internal::hash_default_hash<Key>,
99
- absl::container_internal::hash_default_eq<Key>,
100
- Allocator> {
101
- using Base = absl::node_hash_map<Key, Value,
102
- absl::container_internal::hash_default_hash<Key>,
103
- absl::container_internal::hash_default_eq<Key>,
104
- Allocator>;
105
-
106
- public:
107
- using Base::Base;
108
- };
109
-
110
- #else // DISABLE_ABSEIL
111
-
112
- template <typename T, typename Allocator>
113
- class InlinedHashSet : public std::unordered_set<T,
114
- std::hash<T>,
115
- std::equal_to<T>,
116
- Allocator> {
117
- using Base = std::unordered_set<T,
118
- std::hash<T>,
119
- std::equal_to<T>,
120
- Allocator>;
121
-
122
- public:
123
- using Base::Base;
124
- };
125
-
126
- template <typename Key, typename Value,
127
- typename Allocator>
128
- class InlinedHashMap : public std::unordered_map<Key, Value,
129
- std::hash<Key>,
130
- std::equal_to<Key>,
131
- Allocator> {
132
- using Base = std::unordered_map<Key, Value,
133
- std::hash<Key>,
134
- std::equal_to<Key>,
135
- Allocator>;
136
-
137
- public:
138
- using Base::Base;
139
- };
140
-
141
- // Use this hash set/map where pointer stability is required, otherwise use
142
- // InlinedHashSet and InlinedHashMap
143
- // This does not allocate a dummy 'end' node on default construction.
144
- // Use reserve() when the number of elements is known.
145
- template <typename T, typename Allocator>
146
- class NodeHashSet : public std::unordered_set<T,
147
- std::hash<T>,
148
- std::equal_to<T>,
149
- Allocator> {
150
- using Base = std::unordered_set<T,
151
- std::hash<T>,
152
- std::equal_to<T>,
153
- Allocator>;
154
-
155
- public:
156
- using Base::Base;
157
- };
158
-
159
- template <typename Key, typename Value, typename Allocator>
160
- class NodeHashMap : public std::unordered_map<Key, Value,
161
- std::hash<Key>,
162
- std::equal_to<Key>,
163
- Allocator> {
164
- using Base = std::unordered_map<Key, Value,
165
- std::hash<Key>,
166
- std::equal_to<Key>,
167
- Allocator>;
168
-
169
- public:
170
- using Base::Base;
171
- };
172
-
173
- #endif // DISABLE_ABSEIL
174
-
175
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/inlined_containers_fwd.h DELETED
@@ -1,147 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <memory>
7
- #include <utility>
8
-
9
- #ifndef DISABLE_ABSEIL
10
- #ifdef _MSC_VER
11
- #pragma warning(push)
12
- // C4127: conditional expression is constant
13
- #pragma warning(disable : 4127)
14
- // C4324: structure was padded due to alignment specifier
15
- // Usage of alignas causes some internal padding in places.
16
- #pragma warning(disable : 4324)
17
- #else
18
- // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=102329#c2
19
- #if !defined(__clang__) && defined(__GNUC__)
20
- #pragma GCC diagnostic push
21
- #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
22
- #endif
23
- #endif // _MSC_VER
24
-
25
- #include <absl/container/inlined_vector.h>
26
-
27
- #ifdef _MSC_VER
28
- #pragma warning(pop)
29
- #else
30
- #if !defined(__clang__) && defined(__GNUC__)
31
- #pragma GCC diagnostic pop
32
- #endif
33
- #endif // _MSC_VER
34
-
35
- #else
36
-
37
- #include <vector>
38
-
39
- #endif // DISABLE_ABSEIL
40
-
41
- // Forward declarations for contexts where abseil cannot be compiled and
42
- // is not really needed, but where we still want these names in the headers that are included,
43
- // e.g. CUDA 10 and .cu files.
44
- // InlinedVector seems to be fine with old CUDA.
45
-
46
- //===- llvm/ADT/SmallVector.h - 'Normally small' vectors --------*- C++ -*-===//
47
- //
48
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49
- // See https://llvm.org/LICENSE.txt for license information.
50
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
51
- //
52
- // This file contains code and comments derived from llvm/ADT/SmallVector.h
53
- //
54
- // Specifically CalculateInlinedVectorDefaultInlinedElements<T>() template is derived from
55
- // CalculateSmallVectorDefaultInlinedElements<T>() and its comments.
56
-
57
- namespace onnxruntime {
58
- #ifndef DISABLE_ABSEIL
59
- /// Inspired by LLVM SmallVector with ONNX Runtime adjustments for abseil.
60
- ///
61
- /// Helper class for calculating the default number of inline elements for
62
- /// `InlinedVector<T>`.
63
- /// This produces the following on MSVC x64
64
- /// int8_t -> 41
65
- // int16_t -> 21
66
- // int32_t -> 11
67
- // int64_t -> 6
68
- // std::string 40 -> 1
69
- template <typename T>
70
- struct CalculateInlinedVectorDefaultInlinedElements {
71
- // Parameter controlling the default number of inlined elements
72
- // for `InlinedVector<T>`.
73
- //
74
- // The default number of inlined elements ensures that
75
- // 1. There is at least one inlined element.
76
- // 2. `sizeof(InlinedVector<T>) <= kPreferredInlinedVectorSizeof` unless
77
- // it contradicts 1.
78
- static constexpr size_t kPreferredInlinedVectorSizeof = 64;
79
-
80
- // static_assert that sizeof(T) is not "too big".
81
- //
82
- // Because the InlinedVector must have at least one inlined element, it is possible
83
- // for an arbitrarily large inlined element to allocate an arbitrarily large
84
- // amount of inline storage. So we want to call attention to these cases and
85
- // make sure that users are making an intentional decision if they request a lot of inline storage.
86
- //
87
- // We want this assertion to trigger in pathological cases, but otherwise
88
- // not be too easy to hit. To accomplish that, the cutoff is actually somewhat
89
- // larger than kPreferredInlinedVectorSizeof (otherwise,
90
- // `InlinedVector<InlinedVector<T>>` would be one easy way to trip it, and that
91
- // pattern seems useful in practice).
92
- //
93
- // One wrinkle is that this assertion is in theory non-portable, since
94
- // sizeof(absl::InlinedVector<T, 1>) is in general platform-dependent. However, we don't expect this
95
- // to be much of an issue, because most LLVM development happens on 64-bit
96
- // hosts, and therefore sizeof(T) is expected to *decrease* when compiled for
97
- // 32-bit hosts, dodging the issue. The reverse situation, where development
98
- // happens on a 32-bit host and then fails due to sizeof(T) *increasing* on a
99
- // 64-bit host, is expected to be very rare.
100
- static_assert(
101
- sizeof(absl::InlinedVector<T, 1>) <= kPreferredInlinedVectorSizeof,
102
- "You are trying to use a default number of inlined elements for "
103
- "`InlinedVector<T>` but `sizeof(T)` is really big! Please use an "
104
- "explicit number of inlined elements with `InlinedVector<T, N>` to make "
105
- "sure you really want that much inline storage.");
106
-
107
- // Discount the size of the header itself when calculating the maximum inline
108
- // bytes.
109
- static constexpr size_t PreferredInlineBytes =
110
- kPreferredInlinedVectorSizeof - (sizeof(absl::InlinedVector<T, 1>) - sizeof(T));
111
- static constexpr size_t NumElementsThatFit = PreferredInlineBytes / sizeof(T);
112
- static constexpr size_t value =
113
- NumElementsThatFit == 0 ? 1 : NumElementsThatFit;
114
- };
115
-
116
- // Use InlinedVector for small arrays that can fit on a stack with a default
117
- // value pre-calculated.
118
- // Use TensorShapeVector for shapes.
119
- template <typename T,
120
- size_t N = CalculateInlinedVectorDefaultInlinedElements<T>::value,
121
- typename Allocator = std::allocator<T>>
122
- using InlinedVector = absl::InlinedVector<T, N, Allocator>;
123
-
124
- #else
125
-
126
- template <typename T,
127
- size_t N = 0,
128
- typename Allocator = std::allocator<T>>
129
- using InlinedVector = std::vector<T, Allocator>;
130
-
131
- #endif // DISABLE_ABSEIL
132
-
133
- template <typename T,
134
- typename Allocator = std::allocator<T>>
135
- class InlinedHashSet;
136
-
137
- template <typename Key, typename Value,
138
- typename Allocator = std::allocator<std::pair<const Key, Value>>>
139
- class InlinedHashMap;
140
-
141
- template <typename T, typename Allocator = std::allocator<T>>
142
- class NodeHashSet;
143
-
144
- template <typename Key, typename Value,
145
- typename Allocator = std::allocator<std::pair<const Key, Value>>>
146
- class NodeHashMap;
147
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/logging/capture.h DELETED
@@ -1,115 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <cstdarg>
7
- #include "core/common/gsl.h"
8
- #include "core/common/common.h"
9
- #include "core/common/code_location.h"
10
- #include "core/common/logging/severity.h"
11
-
12
- namespace onnxruntime {
13
- namespace logging {
14
-
15
- class Logger;
16
- enum class DataType;
17
-
18
- /**
19
- Class to capture the details of a log message.
20
- */
21
- class Capture {
22
- public:
23
- /**
24
- Initializes a new instance of the Capture class.
25
- @param logger The logger.
26
- @param severity The severity.
27
- @param category The category.
28
- @param dataType Type of the data.
29
- @param location The file location the log message is coming from.
30
- */
31
- Capture(const Logger& logger, logging::Severity severity, const char* category,
32
- logging::DataType dataType, const CodeLocation& location)
33
- : logger_{&logger}, severity_{severity}, category_{category}, data_type_{dataType}, location_{location} {
34
- }
35
-
36
- /**
37
- The stream that can capture the message via operator<<.
38
- @returns Output stream.
39
- */
40
- std::ostream& Stream() noexcept {
41
- return stream_;
42
- }
43
-
44
- #ifdef _MSC_VER
45
- // add SAL annotation for printf format string. requires Code Analysis to run to validate usage.
46
- #define msvc_printf_check _Printf_format_string_
47
- #define __attribute__(x) // Disable for MSVC. Supported by GCC and CLang.
48
- #else
49
- #define msvc_printf_check
50
- #endif
51
-
52
- /**
53
- Captures a printf style log message.
54
- @param name="format">The printf format.
55
- @param name="">Arguments to the printf format if needed.
56
- @remarks
57
- A maximum of 2K of output will be captured currently.
58
- Non-static method, so 'this' is implicit first arg, and we use format(printf(2,3)
59
- */
60
- void CapturePrintf(msvc_printf_check const char* format, ...) __attribute__((format(printf, 2, 3)));
61
-
62
- /**
63
- Process a printf style log message.
64
- @param format The printf format.
65
- @param ... Arguments to the printf format if needed.
66
- @remarks
67
- A maximum of 2K of output will be captured currently.
68
- Note: As va_list is 'char *', we have to disambiguate this from CapturePrintf
69
- so that something like "One string: %s", "the string" does not consider "the string"
70
- to be the va_list.
71
- */
72
- void ProcessPrintf(msvc_printf_check const char* format, va_list args);
73
-
74
- logging::Severity Severity() const noexcept {
75
- return severity_;
76
- }
77
-
78
- char SeverityPrefix() const noexcept {
79
- // Carefully setup so severity_ is a valid index
80
- GSL_SUPPRESS(bounds .2) {
81
- return logging::SEVERITY_PREFIX[static_cast<int>(severity_)];
82
- }
83
- }
84
-
85
- const char* Category() const noexcept {
86
- return category_;
87
- }
88
-
89
- logging::DataType DataType() const noexcept {
90
- return data_type_;
91
- }
92
-
93
- const CodeLocation& Location() const noexcept {
94
- return location_;
95
- }
96
-
97
- std::string Message() const noexcept {
98
- return stream_.str();
99
- }
100
-
101
- ~Capture();
102
-
103
- private:
104
- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Capture);
105
-
106
- const Logger* logger_;
107
- const logging::Severity severity_;
108
- const char* category_;
109
- const logging::DataType data_type_;
110
- const CodeLocation location_;
111
-
112
- std::ostringstream stream_;
113
- };
114
- } // namespace logging
115
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/logging/isink.h DELETED
@@ -1,41 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <string>
7
-
8
- #include "core/common/logging/logging.h"
9
-
10
- namespace onnxruntime {
11
- namespace logging {
12
- class ISink {
13
- public:
14
- ISink() = default;
15
-
16
- /**
17
- Sends the message to the sink.
18
- @param timestamp The timestamp.
19
- @param logger_id The logger identifier.
20
- @param message The captured message.
21
- */
22
- void Send(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) {
23
- SendImpl(timestamp, logger_id, message);
24
- }
25
-
26
- /**
27
- Sends a Profiling Event Record to the sink.
28
- @param Profiling Event Record
29
- */
30
- virtual void SendProfileEvent(profiling::EventRecord&) const {};
31
-
32
- virtual ~ISink() = default;
33
-
34
- private:
35
- // Make Code Analysis happy by disabling all for now. Enable as needed.
36
- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ISink);
37
-
38
- virtual void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) = 0;
39
- };
40
- } // namespace logging
41
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/logging/logging.h DELETED
@@ -1,337 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <atomic>
7
- #include <chrono>
8
- #include <climits>
9
- #include <map>
10
- #include <memory>
11
- #include <mutex>
12
- #include <string>
13
-
14
- #include "core/common/common.h"
15
- #include "core/common/profiler_common.h"
16
- #include "core/common/logging/capture.h"
17
- #include "core/common/logging/severity.h"
18
-
19
- #include "core/common/logging/macros.h"
20
-
21
- /*
22
-
23
- Logging overview and expected usage:
24
-
25
- At program startup:
26
- * Create one or more ISink instances. If multiple, combine using composite_sink.
27
- * Create a LoggingManager instance with the sink/s with is_default_instance set to true
28
- * Only one instance should be created in this way, and it should remain valid for
29
- until the program no longer needs to produce log output.
30
-
31
- You can either use the static default Logger which LoggingManager will create when constructed
32
- via LoggingManager::DefaultLogger(), or separate Logger instances each with different log ids
33
- via LoggingManager::CreateLogger.
34
-
35
- The log id is passed to the ISink instance with the sink determining how the log id is used
36
- in the output.
37
-
38
- LoggingManager
39
- * creates the Logger instances used by the application
40
- * provides a static default logger instance
41
- * owns the log sink instance
42
- * applies checks on severity and output of user data
43
-
44
- The log macros create a Capture instance to capture the information to log.
45
- If the severity and/or user filtering settings would prevent logging, no evaluation
46
- of the log arguments will occur, so no performance cost beyond the severity and user
47
- filtering check.
48
-
49
- A sink can do further filter as needed.
50
-
51
- */
52
-
53
- namespace onnxruntime {
54
-
55
- namespace logging {
56
-
57
- using Timestamp = std::chrono::time_point<std::chrono::system_clock>;
58
-
59
- #ifndef NDEBUG
60
- ORT_ATTRIBUTE_UNUSED static bool vlog_enabled = true; // Set directly based on your needs.
61
- #else
62
- constexpr bool vlog_enabled = false; // no VLOG output
63
- #endif
64
-
65
- enum class DataType {
66
- SYSTEM = 0, ///< System data.
67
- USER = 1 ///< Contains potentially sensitive user data.
68
- };
69
-
70
- // Internal log categories.
71
- // Logging interface takes const char* so arbitrary values can also be used.
72
- struct Category {
73
- static const char* onnxruntime; ///< General output
74
- static const char* System; ///< Log output regarding interactions with the host system
75
- // TODO: What other high level categories are meaningful? Model? Optimizer? Execution?
76
- };
77
-
78
- class ISink;
79
- class Logger;
80
- class Capture;
81
-
82
- /// <summary>
83
- /// The logging manager.
84
- /// Owns the log sink and potentially provides a default Logger instance.
85
- /// Provides filtering based on a minimum LogSeverity level, and of messages with DataType::User if enabled.
86
- /// </summary>
87
- class LoggingManager final {
88
- public:
89
- enum InstanceType {
90
- Default, ///< Default instance of LoggingManager that should exist for the lifetime of the program
91
- Temporal ///< Temporal instance. CreateLogger(...) should be used, however DefaultLogger() will NOT be provided via this instance.
92
- };
93
-
94
- /**
95
- Initializes a new instance of the LoggingManager class.
96
- @param sink The sink to write to. Use CompositeSink if you need to write to multiple places.
97
- @param default_min_severity The default minimum severity. Messages with lower severity will be ignored unless
98
- overridden in CreateLogger.
99
- @param default_filter_user_data If set to true ignore messages with DataType::USER unless overridden in CreateLogger.
100
- @param instance_type If InstanceType::Default, this is the default instance of the LoggingManager
101
- and is expected to exist for the lifetime of the program.
102
- It creates and owns the default logger that calls to the static DefaultLogger method return.
103
- @param default_logger_id Logger Id to use for the default logger. nullptr/ignored if instance_type == Temporal.
104
- @param default_max_vlog_level Default maximum level for VLOG messages to be created unless overridden in CreateLogger.
105
- Requires a severity of kVERBOSE for VLOG messages to be logged.
106
- */
107
- LoggingManager(std::unique_ptr<ISink> sink, Severity default_min_severity, bool default_filter_user_data,
108
- InstanceType instance_type,
109
- const std::string* default_logger_id = nullptr,
110
- int default_max_vlog_level = -1);
111
-
112
- /**
113
- Creates a new logger instance which will use the provided logger_id and default severity and vlog levels.
114
- @param logger_id The log identifier.
115
- @returns A new Logger instance that the caller owns.
116
- */
117
- std::unique_ptr<Logger> CreateLogger(const std::string& logger_id);
118
-
119
- /**
120
- Creates a new logger instance which will use the provided logger_id, severity and vlog levels.
121
- @param logger_id The log identifier.
122
- @param min_severity The minimum severity. Requests to create messages with lower severity will be ignored.
123
- @param filter_user_data If set to true ignore messages with DataType::USER.
124
- @param max_vlog_level Maximum level for VLOG messages to be created.
125
- @returns A new Logger instance that the caller owns.
126
- */
127
- std::unique_ptr<Logger> CreateLogger(const std::string& logger_id,
128
- Severity min_severity, bool filter_user_data, int max_vlog_level = -1);
129
-
130
- /**
131
- Gets the default logger instance if set. Throws if no default logger is currently registered.
132
- @remarks
133
- Creating a LoggingManager instance with is_default_instance == true registers a default logger.
134
- Note that the default logger is only valid until the LoggerManager that registered it is destroyed.
135
- @returns The default logger if available.
136
- */
137
- static const Logger& DefaultLogger();
138
-
139
- /**
140
- Return a boolean indicating if the default logger has been initialized
141
- */
142
- static bool HasDefaultLogger() { return nullptr != s_default_logger_; }
143
-
144
- /**
145
- Change the minimum severity level for log messages to be output by the default logger.
146
- @param severity The severity.
147
- */
148
- static void SetDefaultLoggerSeverity(Severity severity);
149
-
150
- /**
151
- Change the maximum verbosity level for log messages to be output by the default logger.
152
- @remarks
153
- To activate the verbose log, the logger severity must also be set to kVERBOSE.
154
- @param vlog_level The verbosity level.
155
- */
156
- static void SetDefaultLoggerVerbosity(int vlog_level);
157
-
158
- /**
159
- Logs a FATAL level message and creates an exception that can be thrown with error information.
160
- @param category The log category.
161
- @param location The location the log message was generated.
162
- @param format_str The printf format string.
163
- @param ... The printf arguments.
164
- @returns A new Logger instance that the caller owns.
165
- */
166
- static std::exception LogFatalAndCreateException(const char* category,
167
- const CodeLocation& location,
168
- const char* format_str, ...);
169
-
170
- /**
171
- Logs the message using the provided logger id.
172
- @param logger_id The log identifier.
173
- @param message The log message.
174
- */
175
- void Log(const std::string& logger_id, const Capture& message) const;
176
-
177
- /**
178
- Sends a Profiling Event Record to the sink.
179
- @param Profiling Event Record
180
- */
181
- void SendProfileEvent(profiling::EventRecord& eventRecord) const;
182
- ~LoggingManager();
183
-
184
- private:
185
- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(LoggingManager);
186
-
187
- Timestamp GetTimestamp() const noexcept;
188
- void CreateDefaultLogger(const std::string& logger_id);
189
-
190
- std::unique_ptr<ISink> sink_;
191
- const Severity default_min_severity_;
192
- const bool default_filter_user_data_;
193
- const int default_max_vlog_level_;
194
- bool owns_default_logger_;
195
-
196
- static Logger* s_default_logger_;
197
-
198
- struct Epochs {
199
- const std::chrono::time_point<std::chrono::high_resolution_clock> high_res;
200
- const std::chrono::time_point<std::chrono::system_clock> system;
201
- const std::chrono::minutes localtime_offset_from_utc;
202
- };
203
-
204
- static const Epochs& GetEpochs() noexcept;
205
- };
206
-
207
- /**
208
- Logger provides a per-instance log id. Everything else is passed back up to the LoggingManager
209
- */
210
- class Logger {
211
- public:
212
- /**
213
- Initializes a new instance of the Logger class.
214
- @param loggingManager The logging manager.
215
- @param id The identifier for messages coming from this Logger.
216
- @param severity Minimum severity for messages to be created and logged.
217
- @param filter_user_data Should USER data be filtered from output.
218
- @param vlog_level Minimum level for VLOG messages to be created. Note that a severity of kVERBOSE must be provided
219
- for VLOG messages to be logged.
220
- */
221
- Logger(const LoggingManager& loggingManager, std::string id,
222
- Severity severity, bool filter_user_data, int vlog_level)
223
- : logging_manager_{&loggingManager},
224
- id_{id},
225
- min_severity_{severity},
226
- filter_user_data_{filter_user_data},
227
- max_vlog_level_{vlog_level} {
228
- }
229
-
230
- /**
231
- Get the minimum severity level for log messages to be output.
232
- @returns The severity.
233
- */
234
- Severity GetSeverity() const noexcept { return min_severity_; }
235
-
236
- /**
237
- Change the minimum severity level for log messages to be output.
238
- @param severity The severity.
239
- */
240
- void SetSeverity(Severity severity) noexcept { min_severity_ = severity; }
241
-
242
- /**
243
- Change the maximum verbosity level for log messages to be output.
244
- @remarks
245
- To activate the verbose log, the logger severity must also be set to kVERBOSE.
246
- @param vlog_level The verbosity.
247
- */
248
- void SetVerbosity(int vlog_level) noexcept { max_vlog_level_ = vlog_level; }
249
-
250
- /**
251
- Check if output is enabled for the provided LogSeverity and DataType values.
252
- @param severity The severity.
253
- @param data_type Type of the data.
254
- @returns True if a message with these values will be logged.
255
- */
256
- bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept {
257
- return (severity >= min_severity_ && (data_type != DataType::USER || !filter_user_data_));
258
- }
259
-
260
- /**
261
- Return the maximum VLOG level allowed. Disabled unless logging VLOG messages
262
- */
263
- int VLOGMaxLevel() const noexcept {
264
- return min_severity_ > Severity::kVERBOSE ? -1 : max_vlog_level_;
265
- }
266
-
267
- /**
268
- Logs the captured message.
269
- @param message The log message.
270
- */
271
- void Log(const Capture& message) const {
272
- logging_manager_->Log(id_, message);
273
- }
274
-
275
- /**
276
- Sends a Profiling Event Record to the sink.
277
- @param Profiling Event Record
278
- */
279
- void SendProfileEvent(profiling::EventRecord& eventRecord) const {
280
- logging_manager_->SendProfileEvent(eventRecord);
281
- }
282
-
283
- private:
284
- const LoggingManager* logging_manager_;
285
- const std::string id_;
286
- Severity min_severity_;
287
- const bool filter_user_data_;
288
- int max_vlog_level_;
289
- };
290
-
291
- inline const Logger& LoggingManager::DefaultLogger() {
292
- if (s_default_logger_ == nullptr) {
293
- // fail early for attempted misuse. don't use logging macros as we have no logger.
294
- ORT_THROW("Attempt to use DefaultLogger but none has been registered.");
295
- }
296
-
297
- return *s_default_logger_;
298
- }
299
-
300
- inline void LoggingManager::SetDefaultLoggerSeverity(Severity severity) {
301
- if (s_default_logger_ == nullptr) {
302
- // fail early for attempted misuse. don't use logging macros as we have no logger.
303
- ORT_THROW("Attempt to use DefaultLogger but none has been registered.");
304
- }
305
-
306
- s_default_logger_->SetSeverity(severity);
307
- }
308
-
309
- inline void LoggingManager::SetDefaultLoggerVerbosity(int vlog_level) {
310
- if (s_default_logger_ == nullptr) {
311
- // fail early for attempted misuse. don't use logging macros as we have no logger.
312
- ORT_THROW("Attempt to use DefaultLogger but none has been registered.");
313
- }
314
-
315
- s_default_logger_->SetVerbosity(vlog_level);
316
- }
317
-
318
- inline Timestamp LoggingManager::GetTimestamp() const noexcept {
319
- static const Epochs& epochs = GetEpochs();
320
-
321
- const auto high_res_now = std::chrono::high_resolution_clock::now();
322
- return std::chrono::time_point_cast<std::chrono::system_clock::duration>(
323
- epochs.system + (high_res_now - epochs.high_res) + epochs.localtime_offset_from_utc);
324
- }
325
-
326
- /**
327
- Return the current thread id.
328
- */
329
- unsigned int GetThreadId();
330
-
331
- /**
332
- Return the current process id.
333
- */
334
- unsigned int GetProcessId();
335
-
336
- } // namespace logging
337
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/logging/macros.h DELETED
@@ -1,278 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
- // NOTE: Don't include this file directly. Include logging.h
6
-
7
- #define CREATE_MESSAGE(logger, severity, category, datatype) \
8
- ::onnxruntime::logging::Capture(logger, ::onnxruntime::logging::Severity::k##severity, category, datatype, ORT_WHERE)
9
-
10
- /*
11
- Both printf and stream style logging are supported.
12
- Not that printf currently has a 2K limit to the message size.
13
-
14
- LOGS_* macros are for stream style
15
- LOGF_* macros are for printf style
16
-
17
- The Message class captures the log input, and pushes it through the logger in its destructor.
18
-
19
- Use the *FATAL* macros if you want a Severity::kFatal message to also throw.
20
-
21
- There are a few variants to minimize the length of the macro name required in the calling code.
22
- They are optimized so the shortest names are for the (expected) most common usage. This can be
23
- tweaked if needed.
24
-
25
- Explicit logger vs LoggingManager::DefaulLogger()
26
- Default is for a logger instance to be explicitly passed in.
27
- The logger instance provides an identifier so that log messages from different runs can be separated.
28
-
29
- Variants with DEFAULT in the macro name use the default logger provided by logging manager. This is
30
- static so accessible from any code, provided a LoggingManager instance created with InstanceType::Default
31
- exists somewhere. See logging.h for further explanation of the expected setup.
32
-
33
- DataType
34
- Default uses DataType::SYSTEM.
35
-
36
- Variants with USER in the macro name use DataType::USER. This is data that could be PII, and may need to
37
- be filtered from output. LoggingManager applies this filtering.
38
-
39
- Category
40
- Default category is ::onnxruntime::Logging::Category::onnxruntime.
41
-
42
- If you wish to provide a different category, use variants with CATEGORY in the macro name
43
-
44
- */
45
-
46
- /**
47
- * Note:
48
- * The stream style logging macros (something like `LOGS() << message`) are designed to be appended to.
49
- * Normally, we can isolate macro code in a separate scope (e.g., `do {...} while(0)`), but here we need the macro code
50
- * to interact with subsequent code (i.e., the values to log).
51
- *
52
- * When an unisolated conditional is involved, extra care needs to be taken to avoid unexpected parsing behavior.
53
- * For example:
54
- *
55
- * if (enabled)
56
- * Capture().Stream()
57
- *
58
- * is more direct, but
59
- *
60
- * if (!enabled) {
61
- * } else Capture().Stream()
62
- *
63
- * ensures that the `if` does not unintentionally associate with a subsequent `else`.
64
- */
65
-
66
- // Logging with explicit category
67
-
68
- // iostream style logging. Capture log info in Message, and push to the logger in ~Message.
69
- #define LOGS_CATEGORY(logger, severity, category) \
70
- if (!(logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \
71
- ::onnxruntime::logging::DataType::SYSTEM)) { \
72
- /* do nothing */ \
73
- } else \
74
- CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM).Stream()
75
-
76
- #define LOGS_USER_CATEGORY(logger, severity, category) \
77
- if (!(logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \
78
- ::onnxruntime::logging::DataType::USER)) { \
79
- /* do nothing */ \
80
- } else \
81
- CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::USER).Stream()
82
-
83
- // printf style logging. Capture log info in Message, and push to the logger in ~Message.
84
- #define LOGF_CATEGORY(logger, severity, category, format_str, ...) \
85
- do { \
86
- if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \
87
- ::onnxruntime::logging::DataType::SYSTEM)) \
88
- CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM) \
89
- .CapturePrintf(format_str, ##__VA_ARGS__); \
90
- } while (0)
91
-
92
- #define LOGF_USER_CATEGORY(logger, severity, category, format_str, ...) \
93
- do { \
94
- if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \
95
- ::onnxruntime::logging::DataType::USER)) \
96
- CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::USER) \
97
- .CapturePrintf(format_str, ##__VA_ARGS__); \
98
- } while (0)
99
-
100
- // Logging with category of "onnxruntime"
101
-
102
- #define LOGS(logger, severity) \
103
- LOGS_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime)
104
-
105
- #define LOGS_USER(logger, severity) \
106
- LOGS_USER_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime)
107
-
108
- // printf style logging. Capture log info in Message, and push to the logger in ~Message.
109
- #define LOGF(logger, severity, format_str, ...) \
110
- LOGF_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
111
-
112
- #define LOGF_USER(logger, severity, format_str, ...) \
113
- LOGF_USER_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
114
-
115
- /*
116
- Macros that use the default logger.
117
- A LoggingManager instance must be currently valid for the default logger to be available.
118
- */
119
-
120
- // Logging with explicit category
121
-
122
- #define LOGS_DEFAULT_CATEGORY(severity, category) \
123
- LOGS_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category)
124
-
125
- #define LOGS_USER_DEFAULT_CATEGORY(severity, category) \
126
- LOGS_USER_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category)
127
-
128
- #define LOGF_DEFAULT_CATEGORY(severity, category, format_str, ...) \
129
- LOGF_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category, format_str, ##__VA_ARGS__)
130
-
131
- #define LOGF_USER_DEFAULT_CATEGORY(severity, category, format_str, ...) \
132
- LOGF_USER_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category, format_str, ##__VA_ARGS__)
133
-
134
- // Logging with category of "onnxruntime"
135
-
136
- #define LOGS_DEFAULT(severity) \
137
- LOGS_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime)
138
-
139
- #define LOGS_USER_DEFAULT(severity) \
140
- LOGS_USER_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime)
141
-
142
- #define LOGF_DEFAULT(severity, format_str, ...) \
143
- LOGF_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
144
-
145
- #define LOGF_USER_DEFAULT(severity, format_str, ...) \
146
- LOGF_USER_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
147
-
148
- /*
149
- Conditional logging
150
- */
151
-
152
- // Logging with explicit category
153
-
154
- #define LOGS_CATEGORY_IF(boolean_expression, logger, severity, category) \
155
- if (!((boolean_expression) == true)) { \
156
- /* do nothing */ \
157
- } else \
158
- LOGS_CATEGORY(logger, severity, category)
159
-
160
- #define LOGS_DEFAULT_CATEGORY_IF(boolean_expression, severity, category) \
161
- if (!((boolean_expression) == true)) { \
162
- /* do nothing */ \
163
- } else \
164
- LOGS_DEFAULT_CATEGORY(severity, category)
165
-
166
- #define LOGS_USER_CATEGORY_IF(boolean_expression, logger, severity, category) \
167
- if (!((boolean_expression) == true)) { \
168
- /* do nothing */ \
169
- } else \
170
- LOGS_USER_CATEGORY(logger, severity, category)
171
-
172
- #define LOGS_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, category) \
173
- if (!((boolean_expression) == true)) { \
174
- /* do nothing */ \
175
- } else \
176
- LOGS_USER_DEFAULT_CATEGORY(severity, category)
177
-
178
- #define LOGF_CATEGORY_IF(boolean_expression, logger, severity, category, format_str, ...) \
179
- do { \
180
- if ((boolean_expression) == true) LOGF_CATEGORY(logger, severity, category, format_str, ##__VA_ARGS__); \
181
- } while (0)
182
-
183
- #define LOGF_DEFAULT_CATEGORY_IF(boolean_expression, severity, category, format_str, ...) \
184
- do { \
185
- if ((boolean_expression) == true) LOGF_DEFAULT_CATEGORY(severity, category, format_str, ##__VA_ARGS__); \
186
- } while (0)
187
-
188
- #define LOGF_USER_CATEGORY_IF(boolean_expression, logger, severity, category, format_str, ...) \
189
- do { \
190
- if ((boolean_expression) == true) LOGF_USER_CATEGORY(logger, severity, category, format_str, ##__VA_ARGS__); \
191
- } while (0)
192
-
193
- #define LOGF_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, category, format_str, ...) \
194
- do { \
195
- if ((boolean_expression) == true) LOGF_USER_DEFAULT_CATEGORY(severity, category, format_str, ##__VA_ARGS__); \
196
- } while (0)
197
-
198
- // Logging with category of "onnxruntime"
199
-
200
- #define LOGS_IF(boolean_expression, logger, severity) \
201
- LOGS_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime)
202
-
203
- #define LOGS_DEFAULT_IF(boolean_expression, severity) \
204
- LOGS_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime)
205
-
206
- #define LOGS_USER_IF(boolean_expression, logger, severity) \
207
- LOGS_USER_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime)
208
-
209
- #define LOGS_USER_DEFAULT_IF(boolean_expression, severity) \
210
- LOGS_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime)
211
-
212
- #define LOGF_IF(boolean_expression, logger, severity, format_str, ...) \
213
- LOGF_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
214
-
215
- #define LOGF_DEFAULT_IF(boolean_expression, severity, format_str, ...) \
216
- LOGF_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
217
-
218
- #define LOGF_USER_IF(boolean_expression, logger, severity, format_str, ...) \
219
- LOGF_USER_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime, \
220
- format_str, ##__VA_ARGS__)
221
-
222
- #define LOGF_USER_DEFAULT_IF(boolean_expression, severity, format_str, ...) \
223
- LOGF_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime, \
224
- format_str, ##__VA_ARGS__)
225
-
226
- /*
227
- Debug verbose logging of caller provided level.
228
- Disabled in Release builds.
229
- Use the _USER variants for VLOG statements involving user data that may need to be filtered.
230
- */
231
- #ifndef NDEBUG
232
- #define VLOGS(logger, level) \
233
- if (!(::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel())) { \
234
- /* do nothing */ \
235
- } else \
236
- LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level)
237
-
238
- #define VLOGS_USER(logger, level) \
239
- if (!(::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel())) { \
240
- /* do nothing */ \
241
- } else \
242
- LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level)
243
-
244
- #define VLOGF(logger, level, format_str, ...) \
245
- do { \
246
- if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \
247
- LOGF_CATEGORY(logger, VERBOSE, "VLOG" #level, format_str, ##__VA_ARGS__); \
248
- } while (0)
249
-
250
- #define VLOGF_USER(logger, level, format_str, ...) \
251
- do { \
252
- if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \
253
- LOGF_USER_CATEGORY(logger, VERBOSE, "VLOG" #level, format_str, ##__VA_ARGS__); \
254
- } while (0)
255
- #else
256
- // Disabled in Release builds.
257
- #define VLOGS(logger, level) \
258
- if constexpr (true) {} else LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level)
259
- #define VLOGS_USER(logger, level) \
260
- if constexpr (true) {} else LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level)
261
- #define VLOGF(logger, level, format_str, ...)
262
- #define VLOGF_USER(logger, level, format_str, ...)
263
- #endif
264
-
265
-
266
-
267
- // Default logger variants
268
- #define VLOGS_DEFAULT(level) \
269
- VLOGS(::onnxruntime::logging::LoggingManager::DefaultLogger(), level)
270
-
271
- #define VLOGS_USER_DEFAULT(level) \
272
- VLOGS_USER(::onnxruntime::logging::LoggingManager::DefaultLogger(), level)
273
-
274
- #define VLOGF_DEFAULT(level, format_str, ...) \
275
- VLOGF(::onnxruntime::logging::LoggingManager::DefaultLogger(), level, format_str, ##__VA_ARGS__)
276
-
277
- #define VLOGF_USER_DEFAULT(level, format_str, ...) \
278
- VLOGF_USER(::onnxruntime::logging::LoggingManager::DefaultLogger(), level, format_str, ##__VA_ARGS__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/logging/severity.h DELETED
@@ -1,22 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- namespace onnxruntime {
7
- namespace logging {
8
- // mild violation of naming convention. the 'k' lets us use token concatenation in the macro
9
- // ::onnxruntime::Logging::Severity::k##severity. It's not legal to have ::onnxruntime::Logging::Severity::##severity
10
- // the uppercase makes the LOG macro usage look as expected for passing an enum value as it will be LOGS(logger, ERROR)
11
- enum class Severity {
12
- kVERBOSE = 0,
13
- kINFO = 1,
14
- kWARNING = 2,
15
- kERROR = 3,
16
- kFATAL = 4
17
- };
18
-
19
- constexpr const char* SEVERITY_PREFIX = "VIWEF";
20
-
21
- } // namespace logging
22
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/make_string.h DELETED
@@ -1,126 +0,0 @@
1
- /**
2
- * Copyright (c) 2016-present, Facebook, Inc.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
- // Portions Copyright (c) Microsoft Corporation
17
-
18
- #pragma once
19
-
20
- #include <locale>
21
- #include <sstream>
22
- #include <type_traits>
23
-
24
- namespace onnxruntime {
25
-
26
- namespace detail {
27
-
28
- inline void MakeStringImpl(std::ostringstream& /*ss*/) noexcept {
29
- }
30
-
31
- template <typename T>
32
- inline void MakeStringImpl(std::ostringstream& ss, const T& t) noexcept {
33
- ss << t;
34
- }
35
-
36
- template <typename T, typename... Args>
37
- inline void MakeStringImpl(std::ostringstream& ss, const T& t, const Args&... args) noexcept {
38
- MakeStringImpl(ss, t);
39
- MakeStringImpl(ss, args...);
40
- }
41
-
42
- // see MakeString comments for explanation of why this is necessary
43
- template <typename... Args>
44
- inline std::string MakeStringImpl(const Args&... args) noexcept {
45
- std::ostringstream ss;
46
- MakeStringImpl(ss, args...);
47
- return ss.str();
48
- }
49
-
50
- //
51
- // Infrastructure to convert char[n] to char* to reduce binary size
52
- //
53
-
54
- // default is to leave the type as is
55
- template <class T>
56
- struct if_char_array_make_ptr {
57
- using type = T;
58
- };
59
-
60
- // specialization that matches an array reference, which is what the char array from a string literal
61
- // used in a call to MakeString will be.
62
- // if the type is a char[n] array we 'decay' it to a char* so that the usages can be folded.
63
- template <class T, size_t N>
64
- struct if_char_array_make_ptr<T (&)[N]> {
65
- // remove a single extent (T[x] -> T, but T[x][y] -> T[y]) so we only match char[x],
66
- // and get the type name without the 'const' so both 'const char (&)[n]' and 'char (&)[n]' are matched.
67
- using element_type = typename std::remove_const<typename std::remove_extent<T>::type>::type;
68
- using type = typename std::conditional<std::is_same<char, element_type>::value, T*, T (&)[N]>::type;
69
- };
70
-
71
- // helper to make usage simpler in MakeString
72
- template <class T>
73
- using if_char_array_make_ptr_t = typename if_char_array_make_ptr<T>::type;
74
- } // namespace detail
75
-
76
- /**
77
- * Makes a string by concatenating string representations of the arguments.
78
- * This version uses the current locale.
79
- */
80
- template <typename... Args>
81
- std::string MakeString(const Args&... args) {
82
- // We need to update the types from the MakeString template instantiation to decay any char[n] to char*.
83
- // e.g. MakeString("in", "out") goes from MakeString<char[2], char[3]> to MakeStringImpl<char*, char*>
84
- // so that MakeString("out", "in") will also match MakeStringImpl<char*, char*> instead of requiring
85
- // MakeStringImpl<char[3], char[2]>.
86
- //
87
- // We have to do the type processing before any actual work, so this function purely implements the type processing.
88
- // If we do not do it this way we do not get the full binary size reduction.
89
- //
90
- // See https://stackoverflow.com/a/29418212/684911 for overall details of the approach, but note it does not cover
91
- // the need to do the type processing as a separate step.
92
-
93
- return detail::MakeStringImpl(detail::if_char_array_make_ptr_t<Args const&>(args)...);
94
- }
95
-
96
- /**
97
- * Makes a string by concatenating string representations of the arguments.
98
- * This version uses std::locale::classic().
99
- */
100
- template <typename... Args>
101
- std::string MakeStringWithClassicLocale(const Args&... args) {
102
- std::ostringstream ss;
103
- ss.imbue(std::locale::classic());
104
- detail::MakeStringImpl(ss, args...);
105
- return ss.str();
106
- }
107
-
108
- // MakeString versions for already-a-string types.
109
-
110
- inline std::string MakeString(const std::string& str) {
111
- return str;
112
- }
113
-
114
- inline std::string MakeString(const char* cstr) {
115
- return cstr;
116
- }
117
-
118
- inline std::string MakeStringWithClassicLocale(const std::string& str) {
119
- return str;
120
- }
121
-
122
- inline std::string MakeStringWithClassicLocale(const char* cstr) {
123
- return cstr;
124
- }
125
-
126
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/narrow.h DELETED
@@ -1,77 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- // onnxruntime::narrow() is like gsl::narrow() but it is also available when exceptions are disabled.
7
-
8
- #if !defined(ORT_NO_EXCEPTIONS)
9
-
10
- #include "gsl/narrow"
11
-
12
- namespace onnxruntime {
13
- using gsl::narrow;
14
- } // namespace onnxruntime
15
-
16
- #else // ^^ !defined(ORT_NO_EXCEPTIONS) ^^ / vv defined(ORT_NO_EXCEPTIONS) vv
17
-
18
- #include <cstdio> // std::fprintf
19
- #include <exception> // std::terminate
20
- #include <type_traits>
21
-
22
- #include "gsl/util" // gsl::narrow_cast
23
-
24
- namespace onnxruntime {
25
-
26
- namespace detail {
27
- [[noreturn]] inline void OnNarrowingError() noexcept {
28
- std::fprintf(stderr, "%s", "narrowing error\n");
29
- std::terminate();
30
- }
31
- } // namespace detail
32
-
33
- // This implementation of onnxruntime::narrow was copied and adapted from:
34
- // https://github.com/microsoft/GSL/blob/a3534567187d2edc428efd3f13466ff75fe5805c/include/gsl/narrow
35
-
36
- // narrow() : a checked version of narrow_cast() that terminates if the cast changed the value
37
- template <class T, class U, typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
38
- // clang-format off
39
- GSL_SUPPRESS(type.1) // NO-FORMAT: attribute
40
- // clang-format on
41
- constexpr T narrow(U u) noexcept {
42
- constexpr const bool is_different_signedness =
43
- (std::is_signed<T>::value != std::is_signed<U>::value);
44
-
45
- // clang-format off
46
- GSL_SUPPRESS(es.103) // NO-FORMAT: attribute // don't overflow
47
- GSL_SUPPRESS(es.104) // NO-FORMAT: attribute // don't underflow
48
- GSL_SUPPRESS(p.2) // NO-FORMAT: attribute // don't rely on undefined behavior
49
- // clang-format on
50
- const T t = gsl::narrow_cast<T>(u); // While this is technically undefined behavior in some cases (i.e., if the source value is of floating-point type
51
- // and cannot fit into the destination integral type), the resultant behavior is benign on the platforms
52
- // that we target (i.e., no hardware trap representations are hit).
53
-
54
- if (static_cast<U>(t) != u || (is_different_signedness && ((t < T{}) != (u < U{})))) {
55
- detail::OnNarrowingError();
56
- }
57
-
58
- return t;
59
- }
60
-
61
- template <class T, class U, typename std::enable_if<!std::is_arithmetic<T>::value>::type* = nullptr>
62
- // clang-format off
63
- GSL_SUPPRESS(type.1) // NO-FORMAT: attribute
64
- // clang-format on
65
- constexpr T narrow(U u) noexcept {
66
- const T t = gsl::narrow_cast<T>(u);
67
-
68
- if (static_cast<U>(t) != u) {
69
- detail::OnNarrowingError();
70
- }
71
-
72
- return t;
73
- }
74
-
75
- } // namespace onnxruntime
76
-
77
- #endif // defined(ORT_NO_EXCEPTIONS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/optional.h DELETED
@@ -1,23 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
- #include <optional>
6
-
7
- namespace onnxruntime {
8
-
9
- using std::optional;
10
-
11
- #ifndef ORT_NO_EXCEPTIONS
12
- using std::bad_optional_access;
13
- #endif
14
-
15
- using std::nullopt;
16
- using std::nullopt_t;
17
-
18
- using std::in_place;
19
- using std::in_place_t;
20
-
21
- using std::make_optional;
22
-
23
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/parse_string.h DELETED
@@ -1,85 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <locale>
7
- #include <sstream>
8
- #include <string_view>
9
- #include <type_traits>
10
-
11
- #include "core/common/common.h"
12
-
13
- namespace onnxruntime {
14
-
15
- /**
16
- * Tries to parse a value from an entire string.
17
- */
18
- template <typename T>
19
- bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
20
- if constexpr (std::is_integral<T>::value && std::is_unsigned<T>::value) {
21
- // if T is unsigned integral type, reject negative values which will wrap
22
- if (!str.empty() && str[0] == '-') {
23
- return false;
24
- }
25
- }
26
-
27
- // don't allow leading whitespace
28
- if (!str.empty() && std::isspace(str[0], std::locale::classic())) {
29
- return false;
30
- }
31
-
32
- std::istringstream is{std::string{str}};
33
- is.imbue(std::locale::classic());
34
- T parsed_value{};
35
-
36
- const bool parse_successful =
37
- is >> parsed_value &&
38
- is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters
39
- if (!parse_successful) {
40
- return false;
41
- }
42
-
43
- value = std::move(parsed_value);
44
- return true;
45
- }
46
-
47
- inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) {
48
- value = str;
49
- return true;
50
- }
51
-
52
- inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
53
- if (str == "0" || str == "False" || str == "false") {
54
- value = false;
55
- return true;
56
- }
57
-
58
- if (str == "1" || str == "True" || str == "true") {
59
- value = true;
60
- return true;
61
- }
62
-
63
- return false;
64
- }
65
-
66
- /**
67
- * Parses a value from an entire string.
68
- */
69
- template <typename T>
70
- Status ParseStringWithClassicLocale(std::string_view s, T& value) {
71
- ORT_RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\"");
72
- return Status::OK();
73
- }
74
-
75
- /**
76
- * Parses a value from an entire string.
77
- */
78
- template <typename T>
79
- T ParseStringWithClassicLocale(std::string_view s) {
80
- T value{};
81
- ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(s, value));
82
- return value;
83
- }
84
-
85
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/profiler_common.h DELETED
@@ -1,93 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include "core/common/common.h"
7
-
8
- #include <string>
9
- #include <unordered_map>
10
-
11
- namespace onnxruntime {
12
- namespace profiling {
13
-
14
- enum EventCategory {
15
- SESSION_EVENT = 0,
16
- NODE_EVENT,
17
- KERNEL_EVENT,
18
- API_EVENT,
19
- EVENT_CATEGORY_MAX
20
- };
21
-
22
- // Event descriptions for the above session events.
23
- static constexpr const char* event_category_names_[EVENT_CATEGORY_MAX] = {
24
- "Session",
25
- "Node",
26
- "Kernel",
27
- "Api"};
28
-
29
- // Timing record for all events.
30
- struct EventRecord {
31
- EventRecord() = default;
32
- EventRecord(EventCategory category,
33
- int process_id,
34
- int thread_id,
35
- std::string&& event_name,
36
- long long time_stamp,
37
- long long duration,
38
- std::unordered_map<std::string, std::string>&& event_args)
39
- : cat(category),
40
- pid(process_id),
41
- tid(thread_id),
42
- name(std::move(event_name)),
43
- ts(time_stamp),
44
- dur(duration),
45
- args(std::move(event_args)) {}
46
-
47
- EventRecord(EventCategory category,
48
- int process_id,
49
- int thread_id,
50
- const std::string& event_name,
51
- long long time_stamp,
52
- long long duration,
53
- const std::unordered_map<std::string, std::string>& event_args)
54
- : cat(category),
55
- pid(process_id),
56
- tid(thread_id),
57
- name(event_name),
58
- ts(time_stamp),
59
- dur(duration),
60
- args(event_args) {}
61
-
62
- EventRecord(const EventRecord& other) = default;
63
- EventRecord(EventRecord&& other) noexcept = default;
64
- EventRecord& operator=(const EventRecord& other) = default;
65
- EventRecord& operator=(EventRecord&& other) = default;
66
-
67
- EventCategory cat = EventCategory::API_EVENT;
68
- int pid = -1;
69
- int tid = -1;
70
- std::string name{};
71
- long long ts = 0;
72
- long long dur = 0;
73
- std::unordered_map<std::string, std::string> args{};
74
- };
75
-
76
- using Events = std::vector<EventRecord>;
77
-
78
- //Execution Provider Profiler
79
- class EpProfiler {
80
- public:
81
- virtual ~EpProfiler() = default;
82
- virtual bool StartProfiling(TimePoint profiling_start_time) = 0; // called when profiling starts
83
- virtual void EndProfiling(TimePoint start_time, Events& events) = 0; // called when profiling ends, save all captures numbers to "events"
84
- virtual void Start(uint64_t){}; // called before op start, accept an id as argument to identify the op
85
- virtual void Stop(uint64_t){}; // called after op stop, accept an id as argument to identify the op
86
- };
87
-
88
- // Demangle C++ symbols
89
- std::string demangle(const char* name);
90
- std::string demangle(const std::string& name);
91
-
92
- } // namespace profiling
93
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/span_utils.h DELETED
@@ -1,88 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <algorithm>
7
-
8
- #include "core/common/gsl.h"
9
-
10
- namespace onnxruntime {
11
-
12
- // AsSpan inspired by Fekir's Blog https://fekir.info/post/span-the-missing-constructor/
13
- // Used under MIT license
14
-
15
- // Use AsSpan for less typing on any container including initializer list to create a span
16
- // (unnamed, untyped initializer list does not automatically convert to gsl::span).
17
- // {1, 2, 3} as such does not have a type
18
- // (see https://scottmeyers.blogspot.com/2014/03/if-braced-initializers-have-no-type-why.html)
19
- //
20
- // Example: AsSpan({1, 2, 3}) results in gsl::span<const int>
21
- //
22
- // The above would deduce to std::initializer_list<int> and the result is gsl::span<const int>
23
- //
24
- // AsSpan<int64_t>({1, 2, 3}) produces gsl::span<const int64_t>
25
- //
26
- // We can also do std::array<int64_t, 3>{1, 2, 3} that can be automatically converted to span
27
- // without memory allocation.
28
- //
29
- // If type conversion is not required, then for C++17 std::array template parameters are
30
- // auto-deduced. Example: std::array{1, 2, 3}.
31
- // We are aiming at not allocating memory dynamically.
32
-
33
- namespace details {
34
- template <class P>
35
- constexpr auto AsSpanImpl(P* p, size_t s) {
36
- return gsl::span<P>(p, s);
37
- }
38
- } // namespace details
39
-
40
- template <class C>
41
- constexpr auto AsSpan(C& c) {
42
- return details::AsSpanImpl(c.data(), c.size());
43
- }
44
-
45
- template <class C>
46
- constexpr auto AsSpan(const C& c) {
47
- return details::AsSpanImpl(c.data(), c.size());
48
- }
49
-
50
- template <class C>
51
- constexpr auto AsSpan(C&& c) {
52
- return details::AsSpanImpl(c.data(), c.size());
53
- }
54
-
55
- template <class T>
56
- constexpr auto AsSpan(std::initializer_list<T> c) {
57
- return details::AsSpanImpl(c.begin(), c.size());
58
- }
59
-
60
- template <class T, size_t N>
61
- constexpr auto AsSpan(T (&arr)[N]) {
62
- return details::AsSpanImpl(arr, N);
63
- }
64
-
65
- template <class T, size_t N>
66
- constexpr auto AsSpan(const T (&arr)[N]) {
67
- return details::AsSpanImpl(arr, N);
68
- }
69
-
70
- template <class T>
71
- inline gsl::span<const T> EmptySpan() { return gsl::span<const T>(); }
72
-
73
- template <class U, class T>
74
- [[nodiscard]] inline gsl::span<U> ReinterpretAsSpan(gsl::span<T> src) {
75
- // adapted from gsl-lite span::as_span():
76
- // https://github.com/gsl-lite/gsl-lite/blob/4720a2980a30da085b4ddb4a0ea2a71af7351a48/include/gsl/gsl-lite.hpp#L4102-L4108
77
- Expects(src.size_bytes() % sizeof(U) == 0);
78
- return gsl::span<U>(reinterpret_cast<U*>(src.data()), src.size_bytes() / sizeof(U));
79
- }
80
-
81
- template <class T1, size_t Extent1, class T2, size_t Extent2>
82
- [[nodiscard]] inline bool SpanEq(gsl::span<T1, Extent1> a, gsl::span<T2, Extent2> b) {
83
- static_assert(std::is_same_v<std::remove_const_t<T1>, std::remove_const_t<T2>>,
84
- "T1 and T2 should be the same type except for const qualification");
85
- return std::equal(a.begin(), a.end(), b.begin(), b.end());
86
- }
87
-
88
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/spin_pause.h DELETED
@@ -1,28 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #if defined(_M_AMD64)
7
- #include <intrin.h>
8
- #endif
9
-
10
- #if defined(__x86_64__)
11
- #include <xmmintrin.h>
12
- #endif
13
-
14
- namespace onnxruntime {
15
-
16
- namespace concurrency {
17
-
18
- // Intrinsic to use in spin-loops
19
-
20
- inline void SpinPause() {
21
- #if defined(_M_AMD64) || defined(__x86_64__)
22
- _mm_pause();
23
- #endif
24
- }
25
-
26
- } // namespace concurrency
27
-
28
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/status.h DELETED
@@ -1,195 +0,0 @@
1
- /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
- Licensed under the Apache License, Version 2.0 (the "License");
3
- you may not use this file except in compliance with the License.
4
- You may obtain a copy of the License at
5
- http://www.apache.org/licenses/LICENSE-2.0
6
- Unless required by applicable law or agreed to in writing, software
7
- distributed under the License is distributed on an "AS IS" BASIS,
8
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
- See the License for the specific language governing permissions and
10
- limitations under the License.
11
- ==============================================================================*/
12
- // Modifications Copyright (c) Microsoft.
13
-
14
- #pragma once
15
-
16
- #include <memory>
17
- #include <ostream>
18
- #include <string>
19
- #ifdef _WIN32
20
- #include <winerror.h>
21
- #endif
22
- #include "core/common/gsl.h"
23
- namespace onnxruntime {
24
- namespace common {
25
-
26
- enum StatusCategory {
27
- NONE = 0,
28
- SYSTEM = 1,
29
- ONNXRUNTIME = 2,
30
- };
31
-
32
- /**
33
- Error code for ONNXRuntime.
34
- */
35
- enum StatusCode {
36
- OK = 0,
37
- FAIL = 1,
38
- INVALID_ARGUMENT = 2,
39
- NO_SUCHFILE = 3,
40
- NO_MODEL = 4,
41
- ENGINE_ERROR = 5,
42
- RUNTIME_EXCEPTION = 6,
43
- INVALID_PROTOBUF = 7,
44
- MODEL_LOADED = 8,
45
- NOT_IMPLEMENTED = 9,
46
- INVALID_GRAPH = 10,
47
- EP_FAIL = 11
48
- };
49
-
50
- constexpr const char* StatusCodeToString(StatusCode status) noexcept {
51
- switch (status) {
52
- case StatusCode::OK:
53
- return "SUCCESS";
54
- case StatusCode::FAIL:
55
- return "FAIL";
56
- case StatusCode::INVALID_ARGUMENT:
57
- return "INVALID_ARGUMENT";
58
- case StatusCode::NO_SUCHFILE:
59
- return "NO_SUCHFILE";
60
- case StatusCode::NO_MODEL:
61
- return "NO_MODEL";
62
- case StatusCode::ENGINE_ERROR:
63
- return "ENGINE_ERROR";
64
- case StatusCode::RUNTIME_EXCEPTION:
65
- return "RUNTIME_EXCEPTION";
66
- case StatusCode::INVALID_PROTOBUF:
67
- return "INVALID_PROTOBUF";
68
- case StatusCode::MODEL_LOADED:
69
- return "MODEL_LOADED";
70
- case StatusCode::NOT_IMPLEMENTED:
71
- return "NOT_IMPLEMENTED";
72
- case StatusCode::INVALID_GRAPH:
73
- return "INVALID_GRAPH";
74
- case StatusCode::EP_FAIL:
75
- return "EP_FAIL";
76
- default:
77
- return "GENERAL ERROR";
78
- }
79
- }
80
-
81
- #ifdef _WIN32
82
- constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept {
83
- switch (status) {
84
- case StatusCode::OK:
85
- return S_OK;
86
- case StatusCode::FAIL:
87
- return E_FAIL;
88
- case StatusCode::INVALID_ARGUMENT:
89
- return E_INVALIDARG;
90
- case StatusCode::NO_SUCHFILE:
91
- return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
92
- case StatusCode::NO_MODEL:
93
- return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
94
- case StatusCode::ENGINE_ERROR:
95
- return E_FAIL;
96
- case StatusCode::RUNTIME_EXCEPTION:
97
- return E_FAIL;
98
- case StatusCode::INVALID_PROTOBUF:
99
- return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
100
- case StatusCode::MODEL_LOADED:
101
- return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
102
- case StatusCode::NOT_IMPLEMENTED:
103
- return E_NOTIMPL;
104
- case StatusCode::INVALID_GRAPH:
105
- return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
106
- case StatusCode::EP_FAIL:
107
- return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
108
- default:
109
- return E_FAIL;
110
- }
111
- }
112
- #endif
113
-
114
- class [[nodiscard]] Status {
115
- public:
116
- Status() noexcept = default;
117
-
118
- Status(StatusCategory category, int code, const std::string& msg);
119
-
120
- Status(StatusCategory category, int code, const char* msg);
121
-
122
- Status(StatusCategory category, int code);
123
-
124
- GSL_SUPPRESS(r.11)
125
- Status(const Status& other)
126
- : state_((other.state_ == nullptr) ? nullptr : new State(*other.state_)) {}
127
- GSL_SUPPRESS(r.11)
128
- Status& operator=(const Status& other) {
129
- if (state_ != other.state_) {
130
- if (other.state_ == nullptr) {
131
- state_.reset();
132
- } else {
133
- state_.reset(new State(*other.state_));
134
- }
135
- }
136
- return *this;
137
- }
138
-
139
- Status(Status&&) = default;
140
- Status& operator=(Status&&) = default;
141
- ~Status() = default;
142
-
143
- bool IsOK() const {
144
- return (state_ == nullptr);
145
- }
146
-
147
- int Code() const noexcept;
148
-
149
- StatusCategory Category() const noexcept;
150
-
151
- const std::string& ErrorMessage() const noexcept;
152
-
153
- std::string ToString() const;
154
-
155
- bool operator==(const Status& other) const {
156
- return (this->state_ == other.state_) || (ToString() == other.ToString());
157
- }
158
-
159
- bool operator!=(const Status& other) const {
160
- return !(*this == other);
161
- }
162
-
163
- static Status OK() {
164
- return Status();
165
- }
166
-
167
- private:
168
- static const std::string& EmptyString() noexcept;
169
-
170
- struct State {
171
- State(StatusCategory cat0, int code0, const std::string& msg0)
172
- : category(cat0), code(code0), msg(msg0) {}
173
-
174
- State(StatusCategory cat0, int code0, const char* msg0)
175
- : category(cat0), code(code0), msg(msg0) {}
176
-
177
- const StatusCategory category;
178
- const int code;
179
- const std::string msg;
180
- };
181
-
182
- // As long as Code() is OK, state_ == nullptr.
183
- std::unique_ptr<State> state_;
184
- };
185
-
186
- inline std::ostream& operator<<(std::ostream& out, const Status& status) {
187
- return out << status.ToString();
188
- }
189
-
190
- } // namespace common
191
-
192
- // make Status directly available in the onnxruntime namespace as it is widely used
193
- using common::Status;
194
-
195
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/common/string_helper.h DELETED
@@ -1,11 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
- #include <string>
6
-
7
- // forward declaration
8
- struct OrtAllocator;
9
- namespace onnxruntime {
10
- char* StrDup(const std::string& str, OrtAllocator* allocator);
11
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/alloc_kind.h DELETED
@@ -1,36 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
- #include <iosfwd>
6
-
7
- namespace onnxruntime {
8
- // The ml-Values fall into the following categories with respect to their
9
- // memory management:
10
- // - inference inputs: owned (allocated and freed) by caller, and is by
11
- // default read-only by the runtime.
12
- // - inference outputs: allocated by runtime, ownership transferred to
13
- // caller. TODO: Make sure this semantics is clear in InferenceSession API.
14
- // - weights (constant tensors): can be allocated once (statically), and
15
- // reused by all inference calls within an InferenceSession.
16
- // - tensor values: The lifetimes of these tensor-values are statically
17
- // determined, which is used for memory reuse/sharing optimizations. The
18
- // runtime allocates/frees these values at the right time (as determined
19
- // by the static allocation plan). Note that this is simplified since we
20
- // do not try to optimize for "slice" like ops, where we may be able to
21
- // conditionally reuse memory/data in some cases but not others.
22
- // Generalizing this is future work.
23
-
24
- enum class AllocKind {
25
- kNotSet = -1,
26
- kAllocate = 0,
27
- kReuse = 1,
28
- kPreExisting = 2,
29
- kAllocateStatically = 3,
30
- kAllocateOutput = 4,
31
- kShare = 5,
32
- kAllocatedExternally = 6
33
- };
34
-
35
- std::ostream& operator<<(std::ostream& out, AllocKind alloc_kind);
36
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/allocator.h DELETED
@@ -1,194 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include "core/common/common.h"
7
- #include "core/framework/allocator_stats.h"
8
- #include "core/session/onnxruntime_c_api.h"
9
- #include "ortdevice.h"
10
- #include "ortmemoryinfo.h"
11
-
12
- // This configures the arena based allocator used by ORT
13
- // See docs/C_API.md for details on what these mean and how to choose these values
14
- struct OrtArenaCfg {
15
- OrtArenaCfg() : max_mem(0),
16
- arena_extend_strategy(-1),
17
- initial_chunk_size_bytes(-1),
18
- max_dead_bytes_per_chunk(-1),
19
- initial_growth_chunk_size_bytes(-1) {}
20
- OrtArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes,
21
- int max_dead_bytes_per_chunk, int initial_growth_chunk_size_bytes)
22
- : max_mem(max_mem),
23
- arena_extend_strategy(arena_extend_strategy),
24
- initial_chunk_size_bytes(initial_chunk_size_bytes),
25
- max_dead_bytes_per_chunk(max_dead_bytes_per_chunk),
26
- initial_growth_chunk_size_bytes(initial_growth_chunk_size_bytes) {}
27
-
28
- size_t max_mem; // use 0 to allow ORT to choose the default
29
- int arena_extend_strategy; // use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
30
- int initial_chunk_size_bytes; // use -1 to allow ORT to choose the default
31
- int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default
32
- int initial_growth_chunk_size_bytes; // use -1 to allow ORT to choose the default
33
- };
34
-
35
- namespace onnxruntime {
36
- constexpr const char* CPU = "Cpu";
37
- constexpr const char* CUDA = "Cuda";
38
- constexpr const char* CUDA_PINNED = "CudaPinned";
39
- constexpr const char* CANN = "Cann";
40
- constexpr const char* CANN_PINNED = "CannPinned";
41
- constexpr const char* DML = "DML";
42
- constexpr const char* HIP = "Hip";
43
- constexpr const char* HIP_PINNED = "HipPinned";
44
- constexpr const char* OpenVINO_CPU = "OpenVINO_CPU";
45
- constexpr const char* OpenVINO_GPU = "OpenVINO_GPU";
46
-
47
- constexpr size_t kAllocAlignment = 256;
48
-
49
- class IAllocator;
50
- class Stream;
51
- namespace synchronize {
52
- class Notification;
53
- }
54
- using WaitNotificationFn = std::function<void(Stream&, synchronize::Notification&)>;
55
- void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn);
56
-
57
- template <typename T>
58
- using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
59
-
60
- class IAllocator {
61
- public:
62
- IAllocator(const OrtMemoryInfo& info) : memory_info_(info) {}
63
- virtual ~IAllocator() = default;
64
- /**
65
- @remarks Use SafeInt when calculating the size of memory to allocate using Alloc.
66
- */
67
- virtual void* Alloc(size_t size) = 0;
68
-
69
- virtual void Free(void* p) = 0;
70
-
71
- // TODO: Find a better name than Reserve() and update in all places.
72
- // Reserve() is an interface exposed for an implementation of IAllocator
73
- // to optionally implement some allocation logic that by-passes any arena-based
74
- // logic that may be housed in the Alloc() implementation.
75
- // There are SessionOptions config(s) that allow users to allocate some memory
76
- // by-passing arena-based logic.
77
- // By default, the base implementation just calls Alloc().
78
- virtual void* Reserve(size_t size) { return Alloc(size); }
79
-
80
- const OrtMemoryInfo& Info() const { return memory_info_; };
81
-
82
- // Each implementation of IAllocator can override and provide their own implementation
83
- virtual void GetStats(AllocatorStats* /*stats*/) { return; }
84
-
85
- static bool CalcMemSizeForArray(size_t nmemb, size_t size, size_t* out) noexcept {
86
- return CalcMemSizeForArrayWithAlignment(nmemb, size, 0, out);
87
- }
88
-
89
- /**
90
- * Calculate the memory size for an array. The size is bounds checked using SafeInt.
91
- * \tparam alignment must be power of 2
92
- * \param nmemb Number of members or elements in the array
93
- * \param size Size of each element
94
- * \param out Total size required after any alignment is applied
95
- * \return true, successful. false, overflow
96
- */
97
- [[nodiscard]] static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept;
98
-
99
- /**
100
- * https://cwe.mitre.org/data/definitions/190.html
101
- * \param alignment must be power of 2
102
- * \param nmemb Number of members or elements in the array
103
- * \param size Size of each element
104
- * \param out Total size required after any alignment is applied
105
- * \return true, successful. false, overflow
106
- * \remarks This was the original API and was implemented in the header. Replaced with the above version
107
- * implemented in the .cc file so that the SafeInt dependency is internal.
108
- */
109
- template <size_t alignment>
110
- [[nodiscard]] static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept;
111
-
112
- /**
113
- * allocate memory for an array which has nmemb items of data, each size bytes long
114
- */
115
- void* AllocArray(size_t nmemb, size_t size) {
116
- size_t len;
117
- if (!CalcMemSizeForArray(nmemb, size, &len))
118
- return nullptr;
119
- return Alloc(len);
120
- }
121
-
122
- /**
123
- * allocate memory for an array which has nmemb items of data, each size bytes long
124
- */
125
- template <size_t alignment>
126
- void* AllocArrayWithAlignment(size_t nmemb, size_t size) {
127
- size_t len;
128
- if (!CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, &len))
129
- return nullptr;
130
- return Alloc(len);
131
- }
132
-
133
- /**
134
- Create a std::unique_ptr that is allocated and freed by the provided IAllocator.
135
- @param allocator The allocator.
136
- @param count_or_bytes The exact bytes to allocate if T is void, otherwise the number of elements to allocate.
137
- @param use_reserve If true, call Reserve() instead of Alloc() to allocate memory.
138
- @param stream Which stream instance allocated chunk will be used with.
139
- @param wait_fn If the allocator want to dynamic reuse a chunk from another stream, use this wait_fn to sync on
140
- the target stream to make the reuse safe.
141
- @returns std::unique_ptr with allocated memory and deleter.
142
- */
143
- template <typename T>
144
- static IAllocatorUniquePtr<T> MakeUniquePtr(std::shared_ptr<IAllocator> allocator, size_t count_or_bytes,
145
- bool use_reserve = false,
146
- Stream* stream = nullptr, WaitNotificationFn wait_fn = nullptr) {
147
- if (allocator == nullptr) return nullptr;
148
- // for now limit to fundamental types. we could support others, but to do so either we or the caller
149
- // needs to call the dtor for the objects, for buffers allocated on device we don't have destructor
150
- // static_assert(std::is_fundamental<T>::value, "Fundamental type required as no destructors are called.");
151
-
152
- size_t alloc_size = count_or_bytes;
153
-
154
- // if T is not void, 'count_or_bytes' == number of items so allow for that
155
- if constexpr (!std::is_void<T>::value) {
156
- // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't
157
- // reachable if T is void. use std::conditional to 'use' void* in the sizeof call
158
- if (!CalcMemSizeForArray(
159
- count_or_bytes, sizeof(typename std::conditional<std::is_void<T>::value, void*, T>::type), &alloc_size)) {
160
- return nullptr;
161
- }
162
- }
163
-
164
- // allocate
165
- T* p = static_cast<T*>(AllocateBufferWithOptions(*allocator, alloc_size, use_reserve, stream, std::move(wait_fn)));
166
- return IAllocatorUniquePtr<T>{
167
- p,
168
- [allocator = std::move(allocator)](T* p) { allocator->Free(p); }};
169
- }
170
-
171
- private:
172
- OrtMemoryInfo memory_info_;
173
- };
174
-
175
- template <size_t alignment>
176
- bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept {
177
- return CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, out);
178
- }
179
-
180
- class CPUAllocator : public IAllocator {
181
- public:
182
- explicit CPUAllocator(const OrtMemoryInfo& memory_info) : IAllocator(memory_info) {}
183
-
184
- CPUAllocator() : IAllocator(OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)) {}
185
-
186
- void* Alloc(size_t size) override;
187
- void Free(void* p) override;
188
- };
189
-
190
- using AllocatorPtr = std::shared_ptr<IAllocator>;
191
-
192
- void* AllocatorDefaultAlloc(size_t size);
193
- void AllocatorDefaultFree(void* p);
194
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/buffer_deleter.h DELETED
@@ -1,36 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include "core/framework/allocator.h"
7
-
8
- namespace onnxruntime {
9
-
10
- // TODO: Do we need this class or is IAllocator::MakeUniquePtr sufficient/better
11
- class BufferDeleter {
12
- public:
13
- BufferDeleter() = default;
14
- explicit BufferDeleter(AllocatorPtr alloc)
15
- : alloc_(std::move(alloc)) {}
16
-
17
- void operator()(void* p) const {
18
- if (alloc_)
19
- alloc_->Free(p);
20
- }
21
-
22
- private:
23
- // TODO: we may need consider the lifetime of alloc carefully
24
- // The alloc_ here is the allocator that used to allocate the buffer
25
- // And need go with the unique_ptr together. If it is using our internal
26
- // allocator, it is ok as our allocators are global managed. But if it
27
- // is provide by user, user need to be very careful about it.
28
- // A weak_ptr may be a choice to reduce the impact, but that require to
29
- // change our current allocator mgr to use shared_ptr. Will revisit it
30
- // later.
31
- AllocatorPtr alloc_{nullptr};
32
- };
33
-
34
- using BufferUniquePtr = std::unique_ptr<void, BufferDeleter>;
35
- using BufferNakedPtr = void*;
36
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/customregistry.h DELETED
@@ -1,60 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include "core/common/status.h"
7
- #include "core/common/logging/logging.h"
8
- #include "core/framework/op_kernel.h"
9
- #include "core/framework/kernel_def_builder.h"
10
- #include "core/framework/kernel_registry.h"
11
-
12
- #if !defined(ORT_MINIMAL_BUILD)
13
- #include "core/graph/schema_registry.h"
14
- #endif
15
-
16
- namespace onnxruntime {
17
-
18
- /**
19
- Represents a registry that contains both custom kernels and custom schemas.
20
- */
21
- class CustomRegistry final {
22
- public:
23
- CustomRegistry()
24
- : kernel_registry_(std::make_shared<KernelRegistry>())
25
- #if !defined(ORT_MINIMAL_BUILD)
26
- ,
27
- opschema_registry_(std::make_shared<onnxruntime::OnnxRuntimeOpSchemaRegistry>())
28
- #endif
29
- {
30
- }
31
-
32
- /**
33
- * Register a kernel definition together with kernel factory method to this session.
34
- * If any conflict happened between registered kernel def and built-in kernel def,
35
- * registered kernel will have higher priority.
36
- * Call this before invoking Initialize().
37
- * @return OK if success.
38
- */
39
- common::Status RegisterCustomKernel(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator);
40
-
41
- common::Status RegisterCustomKernel(KernelCreateInfo&);
42
-
43
- const std::shared_ptr<KernelRegistry>& GetKernelRegistry();
44
-
45
- #if !defined(ORT_MINIMAL_BUILD)
46
- common::Status RegisterOpSet(std::vector<ONNX_NAMESPACE::OpSchema>& schemas, const std::string& domain,
47
- int baseline_opset_version, int opset_version);
48
-
49
- const std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry>& GetOpschemaRegistry();
50
- #endif
51
-
52
- private:
53
- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomRegistry);
54
- std::shared_ptr<KernelRegistry> kernel_registry_;
55
- #if !defined(ORT_MINIMAL_BUILD)
56
- std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry> opschema_registry_;
57
- #endif
58
- };
59
-
60
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/data_types.h DELETED
@@ -1,1062 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <cstdint>
7
- #include <cstring>
8
- #include <string>
9
- #include <type_traits>
10
- #include <map>
11
- #include <unordered_map>
12
- #include "core/common/gsl.h"
13
- #include "core/common/common.h"
14
- #include "core/common/exceptions.h"
15
- #include "core/framework/endian.h"
16
- #include "core/framework/float16.h"
17
- #include "core/framework/to_tensor_proto_element_type.h"
18
- #if !defined(ORT_MINIMAL_BUILD)
19
- #include "onnx/defs/schema.h"
20
- #else
21
- #include "onnx/defs/data_type_utils.h"
22
- #endif
23
- #include "onnx/onnx_pb.h"
24
- #include "onnx/onnx-operators_pb.h"
25
-
26
- struct OrtValue;
27
-
28
- namespace ONNX_NAMESPACE {
29
- class TypeProto;
30
- } // namespace ONNX_NAMESPACE
31
-
32
- namespace onnxruntime {
33
- /// Predefined registered types
34
-
35
- #if !defined(DISABLE_ML_OPS)
36
-
37
- // maps (only used by ML ops)
38
- using MapStringToString = std::map<std::string, std::string>;
39
- using MapStringToInt64 = std::map<std::string, int64_t>;
40
- using MapStringToFloat = std::map<std::string, float>;
41
- using MapStringToDouble = std::map<std::string, double>;
42
- using MapInt64ToString = std::map<int64_t, std::string>;
43
- using MapInt64ToInt64 = std::map<int64_t, int64_t>;
44
- using MapInt64ToFloat = std::map<int64_t, float>;
45
- using MapInt64ToDouble = std::map<int64_t, double>;
46
-
47
- // vectors/sequences
48
- using VectorMapStringToFloat = std::vector<MapStringToFloat>;
49
- using VectorMapInt64ToFloat = std::vector<MapInt64ToFloat>;
50
-
51
- #endif
52
-
53
- using VectorString = std::vector<std::string>;
54
- using VectorInt64 = std::vector<int64_t>;
55
-
56
- // Forward declarations
57
- class DataTypeImpl;
58
- class TensorTypeBase;
59
- #if !defined(DISABLE_SPARSE_TENSORS)
60
- class SparseTensorTypeBase;
61
- #endif
62
- class SequenceTensorTypeBase;
63
- class NonTensorTypeBase;
64
- #if !defined(DISABLE_OPTIONAL_TYPE)
65
- class OptionalTypeBase;
66
- #endif
67
- class PrimitiveDataTypeBase;
68
- class Tensor;
69
- class TensorSeq;
70
-
71
- // DataTypeImpl pointer as unique DataTypeImpl identifier.
72
- using MLDataType = const DataTypeImpl*;
73
- // be used with class MLValue
74
- using DeleteFunc = void (*)(void*);
75
- using CreateFunc = void* (*)();
76
-
77
- /**
78
- * \brief Base class for MLDataType
79
- *
80
- */
81
- class DataTypeImpl {
82
- public:
83
- enum class GeneralType {
84
- kInvalid = 0,
85
- kNonTensor = 1,
86
- kTensor = 2,
87
- kTensorSequence = 3,
88
- kSparseTensor = 4,
89
- kOptional = 5,
90
- kPrimitive = 6,
91
- };
92
-
93
- const GeneralType type_;
94
- const size_t size_;
95
-
96
- protected:
97
- DataTypeImpl(GeneralType type, size_t size) : type_{type}, size_{size} {}
98
-
99
- public:
100
- virtual ~DataTypeImpl() = default;
101
-
102
- /**
103
- * \brief this API will be used to check type compatibility at runtime
104
- *
105
- * \param type_proto a TypeProto instance that is constructed for a specific type
106
- * will be checked against a TypeProto instance contained within a corresponding
107
- * MLDataType instance.
108
- */
109
- virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const = 0;
110
-
111
- size_t Size() const { return size_; }
112
-
113
- virtual DeleteFunc GetDeleteFunc() const = 0;
114
-
115
- /**
116
- * \brief Retrieves an instance of TypeProto for
117
- * a given MLDataType
118
- * \returns optional TypeProto. Only ONNX types
119
- has type proto, non-ONNX types will return nullptr.
120
- */
121
- virtual const ONNX_NAMESPACE::TypeProto* GetTypeProto() const = 0;
122
-
123
- bool IsTensorType() const {
124
- return type_ == GeneralType::kTensor;
125
- }
126
-
127
- bool IsTensorSequenceType() const {
128
- return type_ == GeneralType::kTensorSequence;
129
- }
130
-
131
- bool IsSparseTensorType() const {
132
- return type_ == GeneralType::kSparseTensor;
133
- }
134
-
135
- bool IsOptionalType() const {
136
- return type_ == GeneralType::kOptional;
137
- }
138
-
139
- bool IsNonTensorType() const {
140
- return type_ == GeneralType::kNonTensor;
141
- }
142
-
143
- bool IsPrimitiveDataType() const {
144
- return type_ == GeneralType::kPrimitive;
145
- }
146
-
147
- // Returns this if this is of tensor-type and null otherwise
148
- const TensorTypeBase* AsTensorType() const;
149
-
150
- const SequenceTensorTypeBase* AsSequenceTensorType() const;
151
-
152
- #if !defined(DISABLE_SPARSE_TENSORS)
153
- // Returns this if this is of sparse-tensor-type and null otherwise
154
- const SparseTensorTypeBase* AsSparseTensorType() const;
155
- #endif
156
-
157
- #if !defined(DISABLE_OPTIONAL_TYPE)
158
- const OptionalTypeBase* AsOptionalType() const;
159
- #endif
160
-
161
- const NonTensorTypeBase* AsNonTensorType() const;
162
-
163
- // Returns this if this is one of the primitive data types (specialization of PrimitiveDataTypeBase)
164
- // and null otherwise
165
- const PrimitiveDataTypeBase* AsPrimitiveDataType() const;
166
-
167
- // Return the type meta that we are using in the runtime.
168
- template <typename T>
169
- static MLDataType GetType();
170
-
171
- // Return the types for a concrete tensor type, like Tensor_Float
172
- template <typename elemT>
173
- static MLDataType GetTensorType();
174
-
175
- template <typename elemT>
176
- static MLDataType GetSequenceTensorType();
177
-
178
- #if !defined(DISABLE_SPARSE_TENSORS)
179
- // Return the MLDataType for a concrete sparse tensor type.
180
- template <typename elemT>
181
- static MLDataType GetSparseTensorType();
182
- #endif
183
-
184
- template <typename T, typename elemT>
185
- static MLDataType GetOptionalType();
186
-
187
- /**
188
- * Convert an ONNX TypeProto to onnxruntime DataTypeImpl.
189
- * However, this conversion is lossy. Don't try to use 'this->GetTypeProto()' converting it back.
190
- * Even though GetTypeProto() will not have the original information, it will still have enough to correctly
191
- * map to MLDataType.
192
- * \param proto
193
- */
194
- static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto);
195
-
196
- static const TensorTypeBase* TensorTypeFromONNXEnum(int type);
197
- static const SequenceTensorTypeBase* SequenceTensorTypeFromONNXEnum(int type);
198
- #if !defined(DISABLE_SPARSE_TENSORS)
199
- static const SparseTensorTypeBase* SparseTensorTypeFromONNXEnum(int type);
200
- #endif
201
-
202
- static const char* ToString(MLDataType type);
203
- static std::vector<std::string> ToString(const std::vector<MLDataType>& types);
204
- // Registers ONNX_NAMESPACE::DataType (internalized string) with
205
- // MLDataType. DataType is produced by internalizing an instance of
206
- // TypeProto contained within MLDataType
207
- static void RegisterDataType(MLDataType);
208
- static MLDataType GetDataType(const std::string&);
209
-
210
- static const std::vector<MLDataType>& AllTensorTypes();
211
- static const std::vector<MLDataType>& AllFixedSizeTensorTypes();
212
- static const std::vector<MLDataType>& AllSequenceTensorTypes();
213
- static const std::vector<MLDataType>& AllFixedSizeSequenceTensorTypes();
214
- static const std::vector<MLDataType>& AllNumericTensorTypes();
215
- static const std::vector<MLDataType>& AllIEEEFloatTensorTypes();
216
- static const std::vector<MLDataType>& AllFixedSizeTensorExceptHalfTypes();
217
- static const std::vector<MLDataType>& AllIEEEFloatTensorExceptHalfTypes();
218
- static const std::vector<MLDataType>& AllTensorAndSequenceTensorTypes();
219
- static const std::vector<MLDataType>& AllFixedSizeTensorAndSequenceTensorTypes();
220
- static const std::vector<MLDataType>& AllOptionalTypes();
221
- static const std::vector<MLDataType>& AllTensorAndSequenceTensorAndOptionalTypes();
222
- };
223
-
224
- std::ostream& operator<<(std::ostream& out, MLDataType data_type);
225
-
226
- /*
227
- * Type registration helpers
228
- */
229
- namespace data_types_internal {
230
- /// TensorType helpers
231
- ///
232
-
233
- /// Is a given type on the list of types?
234
- /// Accepts a list of types and the first argument is the type
235
- /// We are checking if it is listed among those that follow
236
- template <typename T, typename... Types>
237
- struct IsAnyOf;
238
-
239
- /// Two types remaining, end of the list
240
- template <typename T, typename Tail>
241
- struct IsAnyOf<T, Tail> : public std::is_same<T, Tail> {
242
- };
243
-
244
- template <typename T, typename H, typename... Tail>
245
- struct IsAnyOf<T, H, Tail...> {
246
- static constexpr bool value = (std::is_same<T, H>::value ||
247
- IsAnyOf<T, Tail...>::value);
248
- };
249
-
250
- /// Tells if the specified type is one of fundamental types
251
- /// that can be contained within a tensor.
252
- /// We do not have raw fundamental types, rather a subset
253
- /// of fundamental types is contained within tensors.
254
- template <typename T>
255
- struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
256
- int32_t, int64_t, std::string, bool, MLFloat16,
257
- double, uint32_t, uint64_t, BFloat16> {
258
- };
259
-
260
- #if !defined(DISABLE_SPARSE_TENSORS)
261
- /// Use "IsSparseTensorContainedType<T>::value" to test if a type T
262
- /// is permitted as the element-type of a sparse-tensor.
263
-
264
- template <typename T>
265
- struct IsSparseTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
266
- int32_t, int64_t, std::string, bool, MLFloat16,
267
- double, uint32_t, uint64_t, BFloat16> {
268
- };
269
- #endif
270
-
271
- #if !defined(DISABLE_OPTIONAL_TYPE)
272
- /// Tells if the specified type is one of ORT types
273
- /// that can be contained within an optional struct.
274
- template <typename T>
275
- struct IsOptionalOrtType : public IsAnyOf<T, Tensor, TensorSeq> {
276
- };
277
- #endif
278
-
279
- /// This template's Get() returns a corresponding MLDataType
280
- /// It dispatches the call to either GetTensorType<>() or
281
- /// GetType<>()
282
- template <typename T, bool TensorContainedType>
283
- struct GetMLDataType;
284
-
285
- template <typename T>
286
- struct GetMLDataType<T, true> {
287
- static MLDataType Get() {
288
- return DataTypeImpl::GetTensorType<T>();
289
- }
290
- };
291
-
292
- template <typename T>
293
- struct GetMLDataType<T, false> {
294
- static MLDataType Get() {
295
- return DataTypeImpl::GetType<T>();
296
- }
297
- };
298
-
299
- struct TensorTypeHelper {
300
- static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
301
- ONNX_NAMESPACE::TypeProto& proto) {
302
- proto.mutable_tensor_type()->set_elem_type(element_type);
303
- }
304
- };
305
-
306
- #if !defined(DISABLE_SPARSE_TENSORS)
307
- struct SparseTensorTypeHelper {
308
- static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
309
- ONNX_NAMESPACE::TypeProto& proto) {
310
- proto.mutable_sparse_tensor_type()->set_elem_type(element_type);
311
- }
312
- };
313
- #endif // !defined(DISABLE_SPARSE_TENSORS)
314
-
315
- #if !defined(DISABLE_ML_OPS)
316
- /// Map helpers
317
-
318
- void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto&,
319
- ONNX_NAMESPACE::TypeProto&);
320
-
321
- struct MapTypeHelper {
322
- // V can be either a primitive type (in which case it is a tensor)
323
- // or other preregistered types
324
- template <typename V>
325
- static MLDataType GetValueType() {
326
- return GetMLDataType<V, IsTensorContainedType<V>::value>::Get();
327
- }
328
-
329
- static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type, const ONNX_NAMESPACE::TypeProto* value_proto,
330
- ONNX_NAMESPACE::TypeProto& proto) {
331
- ORT_ENFORCE(value_proto != nullptr, "expected a registered ONNX type");
332
- proto.mutable_map_type()->set_key_type(key_type);
333
- CopyMutableMapValue(*value_proto, proto);
334
- }
335
- };
336
- #endif
337
-
338
- /// Sequence helpers
339
-
340
- // Element type is a primitive type so we set it to a tensor<elemT>
341
- void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto&,
342
- ONNX_NAMESPACE::TypeProto&);
343
-
344
- // helper to create TypeProto with minimal binary size impact
345
- struct SequenceTypeHelper {
346
- template <typename T>
347
- static MLDataType GetElemType() {
348
- return GetMLDataType<T, IsTensorContainedType<T>::value>::Get();
349
- }
350
-
351
- static void Set(const ONNX_NAMESPACE::TypeProto* elem_proto,
352
- ONNX_NAMESPACE::TypeProto& proto) {
353
- ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
354
- CopyMutableSeqElement(*elem_proto, proto);
355
- }
356
- };
357
-
358
- /// Optional helpers
359
-
360
- void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto&,
361
- ONNX_NAMESPACE::TypeProto&);
362
-
363
- // helper to create TypeProto with minimal binary size impact
364
- struct OptionalTypeHelper {
365
- template <typename T, typename elemT>
366
- static MLDataType GetElemType() {
367
- if constexpr (std::is_same<T, Tensor>::value) {
368
- return DataTypeImpl::GetTensorType<elemT>();
369
- } else {
370
- static_assert(std::is_same<T, TensorSeq>::value, "Unsupported element type for optional type");
371
- return DataTypeImpl::GetSequenceTensorType<elemT>();
372
- }
373
- }
374
-
375
- static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) {
376
- ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
377
- CopyMutableOptionalElement(*elem_proto, proto);
378
- }
379
- };
380
-
381
- /// OpaqueTypes helpers
382
-
383
- void AssignOpaqueDomainName(const char* domain, const char* name,
384
- ONNX_NAMESPACE::TypeProto& proto);
385
-
386
- } // namespace data_types_internal
387
-
388
- //The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
389
- //However, we do not allocate this type on heap.
390
- #if defined(_MSC_VER) && !defined(__clang__)
391
- #pragma warning(push)
392
- #pragma warning(disable : 26436)
393
- #endif
394
- /// All tensors base
395
- class TensorTypeBase : public DataTypeImpl {
396
- public:
397
- static MLDataType Type();
398
-
399
- /// We first compare type_proto pointers and then
400
- /// if they do not match try to account for the case
401
- /// where TypeProto was created ad-hoc and not queried from MLDataType
402
- bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
403
-
404
- DeleteFunc GetDeleteFunc() const override;
405
-
406
- const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
407
-
408
- virtual MLDataType GetElementType() const {
409
- // should never reach here.
410
- ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
411
- }
412
-
413
- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorTypeBase);
414
-
415
- protected:
416
- ONNX_NAMESPACE::TypeProto& MutableTypeProto();
417
-
418
- TensorTypeBase();
419
- ~TensorTypeBase() override;
420
-
421
- private:
422
- struct Impl;
423
- Impl* impl_;
424
- };
425
-
426
- /**
427
- * \brief Tensor type. This type does not have a C++ type associated with
428
- * it at registration time except the element type. One of the types mentioned
429
- * above at IsTensorContainedType<> list is acceptable.
430
- *
431
- * \details
432
- * Usage:
433
- * ORT_REGISTER_TENSOR(ELEMENT_TYPE)
434
- * Currently all of the Tensors irrespective of the dimensions are mapped to Tensor<type>
435
- * type. IsCompatible() currently ignores shape.
436
- */
437
-
438
- template <typename elemT>
439
- class TensorType : public TensorTypeBase {
440
- public:
441
- static_assert(data_types_internal::IsTensorContainedType<elemT>::value,
442
- "Requires one of the tensor fundamental types");
443
-
444
- static MLDataType Type();
445
-
446
- /// Tensors only can contain basic data types
447
- /// that have been previously registered with ONNXRuntime
448
- MLDataType GetElementType() const override {
449
- return DataTypeImpl::GetType<elemT>();
450
- }
451
-
452
- private:
453
- TensorType() {
454
- using namespace data_types_internal;
455
- TensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(), MutableTypeProto());
456
- }
457
- };
458
-
459
- #if defined(DISABLE_OPTIONAL_TYPE)
460
-
461
- // TODO is this still needed after removing kernel def hashes?
462
- /// Common base-class for all disabled types. We need DataTypeImpl::ToString to work in a minimal build
463
- /// with disabled types to keep the ORT format model kernel hashes stable.
464
- class DisabledTypeBase : public DataTypeImpl {
465
- public:
466
- static MLDataType Type();
467
-
468
- bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
469
- // We always want to return false for the IsCompatible() for a disabled type
470
- // because this will ensure that no kernel supporting the disabled type will
471
- // be matched to a model node requiring that type and the model load will
472
- // result in failure.
473
- return false;
474
- }
475
-
476
- DeleteFunc GetDeleteFunc() const override {
477
- ORT_THROW("Type is disabled in this build.");
478
- }
479
-
480
- // This must work
481
- const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
482
-
483
- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DisabledTypeBase);
484
-
485
- protected:
486
- // This must work
487
- ONNX_NAMESPACE::TypeProto& MutableTypeProto();
488
-
489
- DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size);
490
- ~DisabledTypeBase() override;
491
-
492
- private:
493
- struct Impl;
494
- Impl* impl_;
495
- };
496
-
497
- #endif
498
-
499
- #if !defined(DISABLE_SPARSE_TENSORS)
500
- /// Common base-class for all sparse-tensors (with different element types).
501
- class SparseTensorTypeBase : public DataTypeImpl {
502
- public:
503
- static MLDataType Type();
504
-
505
- bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
506
-
507
- DeleteFunc GetDeleteFunc() const override;
508
-
509
- const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
510
-
511
- virtual MLDataType GetElementType() const {
512
- // should never reach here.
513
- ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
514
- }
515
-
516
- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SparseTensorTypeBase);
517
-
518
- protected:
519
- ONNX_NAMESPACE::TypeProto& MutableTypeProto();
520
-
521
- SparseTensorTypeBase();
522
- ~SparseTensorTypeBase() override;
523
-
524
- private:
525
- struct Impl;
526
- Impl* impl_;
527
- };
528
-
529
- template <typename elemT>
530
- class SparseTensorType : public SparseTensorTypeBase {
531
- public:
532
- static_assert(data_types_internal::IsSparseTensorContainedType<elemT>::value,
533
- "Requires one of the sparse-tensor fundamental types");
534
-
535
- static MLDataType Type();
536
-
537
- /// Return a MLDataType representing the element-type
538
- MLDataType GetElementType() const override {
539
- return DataTypeImpl::GetType<elemT>();
540
- }
541
-
542
- private:
543
- SparseTensorType() {
544
- using namespace data_types_internal;
545
- SparseTensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(), MutableTypeProto());
546
- }
547
- };
548
-
549
- #endif // !defined(DISABLE_SPARSE_TENSORS)
550
-
551
- /// Common base-class for all optional types.
552
-
553
- #if !defined(DISABLE_OPTIONAL_TYPE)
554
- class OptionalTypeBase : public DataTypeImpl {
555
- public:
556
- static MLDataType Type();
557
-
558
- bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
559
-
560
- DeleteFunc GetDeleteFunc() const override {
561
- // should never reach here.
562
- ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
563
- }
564
-
565
- const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
566
-
567
- virtual MLDataType GetElementType() const {
568
- // should never reach here.
569
- ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
570
- }
571
-
572
- OptionalTypeBase(const OptionalTypeBase&) = delete;
573
- OptionalTypeBase& operator=(const OptionalTypeBase&) = delete;
574
-
575
- protected:
576
- ONNX_NAMESPACE::TypeProto& MutableTypeProto();
577
-
578
- OptionalTypeBase();
579
- ~OptionalTypeBase() override;
580
-
581
- private:
582
- struct Impl;
583
- Impl* impl_;
584
- };
585
- #endif
586
-
587
- // Derive from OptionalTypeBase if the Optional type support is enabled,
588
- // else derive from DisabledTypeBase
589
- template <typename T, typename elemT>
590
- class OptionalType :
591
- #if !defined(DISABLE_OPTIONAL_TYPE)
592
- public OptionalTypeBase
593
- #else
594
- public DisabledTypeBase
595
- #endif
596
- {
597
- public:
598
- static MLDataType Type();
599
-
600
- #if !defined(DISABLE_OPTIONAL_TYPE)
601
- static_assert(data_types_internal::IsOptionalOrtType<T>::value,
602
- "Requires one of the supported types: Tensor or TensorSeq");
603
-
604
- static_assert(data_types_internal::IsTensorContainedType<elemT>::value,
605
- "Requires one of the tensor fundamental types");
606
-
607
- MLDataType GetElementType() const override {
608
- return data_types_internal::OptionalTypeHelper::GetElemType<T, elemT>();
609
- }
610
- #endif
611
-
612
- private:
613
- #if !defined(DISABLE_OPTIONAL_TYPE)
614
- OptionalType()
615
- #else
616
- OptionalType() : DisabledTypeBase { DataTypeImpl::GeneralType::kOptional, 0 }
617
- #endif
618
- {
619
- using namespace data_types_internal;
620
- OptionalTypeHelper::Set(OptionalTypeHelper::GetElemType<T, elemT>()->GetTypeProto(), MutableTypeProto());
621
- }
622
- }; // namespace onnxruntime
623
-
624
- /**
625
- * \brief Provide a specialization for your C++ Non-tensor type
626
- * so your implementation FromDataTypeContainer/ToDataTypeContainer
627
- * functions correctly. Otherwise you get a default implementation
628
- * which may not be what you need/want.
629
- *
630
- * This class is used to create OrtValue, fetch data from OrtValue via
631
- * C/C++ APIs
632
- */
633
- template <class T>
634
- struct NonTensorTypeConverter {
635
- static void FromContainer(MLDataType /*dtype*/, const void* /*data*/, size_t /*data_size*/, OrtValue& /*output*/) {
636
- ORT_THROW("Not implemented");
637
- }
638
- static void ToContainer(const OrtValue& /*input*/, size_t /*data_size*/, void* /*data*/) {
639
- ORT_THROW("Not implemented");
640
- }
641
- };
642
-
643
- /**
644
- * \brief Base type for all non-tensors, maps, sequences and opaques
645
- */
646
- class NonTensorTypeBase : public DataTypeImpl {
647
- public:
648
- DeleteFunc GetDeleteFunc() const override = 0;
649
-
650
- virtual CreateFunc GetCreateFunc() const = 0;
651
-
652
- const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
653
-
654
- // \brief Override for Non-tensor types to initialize non-tensor CPP
655
- // data representation from data. The caller of the interface
656
- // should have a shared definition of the data which is used to initialize
657
- // CPP data representation. This is used from C API.
658
- //
659
- // \param data - pointer to a data container structure non_tensor type specific
660
- // \param data_size - size of the data container structure, used for rudimentary checks
661
- // \param output - reference to a default constructed non-tensor type
662
- // \returns OrtValue
663
- // \throw if there is an error
664
- virtual void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const;
665
-
666
- // \brief Override for Non-tensor types to fetch data from the internal CPP data representation
667
- // The caller of the interface should have a shared definition of the data which is used to initialize
668
- // CPP data representation. This is used from C API.
669
- //
670
- // \param input - OrtValue containing data
671
- // \param data_size - size of the structure that is being passed for receiving data, used for
672
- // validation
673
- // \param data - pointer to receiving data structure
674
- virtual void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const;
675
-
676
- NonTensorTypeBase(const NonTensorTypeBase&) = delete;
677
- NonTensorTypeBase& operator=(const NonTensorTypeBase&) = delete;
678
-
679
- protected:
680
- NonTensorTypeBase(size_t size);
681
- ~NonTensorTypeBase() override;
682
-
683
- ONNX_NAMESPACE::TypeProto& MutableTypeProto();
684
-
685
- bool IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
686
-
687
- bool IsSequenceCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
688
-
689
- bool IsOpaqueCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
690
-
691
- private:
692
- struct Impl;
693
- Impl* impl_;
694
- };
695
-
696
- // This is where T is the actual CPPRuntimeType
697
- template <typename T>
698
- class NonTensorType : public NonTensorTypeBase {
699
- private:
700
- static void Delete(void* p) {
701
- delete static_cast<T*>(p);
702
- }
703
-
704
- public:
705
- DeleteFunc GetDeleteFunc() const override {
706
- return &Delete;
707
- }
708
-
709
- CreateFunc GetCreateFunc() const override {
710
- return []() -> void* { return new T(); };
711
- }
712
-
713
- protected:
714
- NonTensorType() : NonTensorTypeBase(sizeof(T)) {}
715
- };
716
-
717
- #if !defined(DISABLE_ML_OPS)
718
- /**
719
- * \brief MapType. Use this type to register
720
- * mapping types.
721
- *
722
- * \param T - cpp type that you wish to register as runtime MapType
723
- *
724
- * \details Usage: ORT_REGISTER_MAP(C++Type)
725
- * The type is required to have mapped_type and
726
- * key_type defined
727
- */
728
- template <typename CPPType>
729
- class MapType : public NonTensorType<CPPType> {
730
- public:
731
- static_assert(data_types_internal::IsTensorContainedType<typename CPPType::key_type>::value,
732
- "Requires one of the tensor fundamental types as key");
733
-
734
- static MLDataType Type();
735
-
736
- bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
737
- return this->IsMapCompatible(type_proto);
738
- }
739
-
740
- private:
741
- MapType() {
742
- using namespace data_types_internal;
743
- MapTypeHelper::Set(utils::ToTensorProtoElementType<typename CPPType::key_type>(),
744
- MapTypeHelper::GetValueType<typename CPPType::mapped_type>()->GetTypeProto(),
745
- this->MutableTypeProto());
746
- }
747
- };
748
- #endif
749
-
750
- /**
751
- * \brief SequenceType. Use to register sequence for non-tensor types.
752
- *
753
- * \param T - CPP type that you wish to register as Sequence
754
- * runtime type.
755
- *
756
- * \details Usage: ORT_REGISTER_SEQ(C++Type)
757
- * The type is required to have value_type defined
758
- */
759
- template <typename CPPType>
760
- class SequenceType : public NonTensorType<CPPType> {
761
- public:
762
- static MLDataType Type();
763
-
764
- bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
765
- return this->IsSequenceCompatible(type_proto);
766
- }
767
-
768
- private:
769
- SequenceType() {
770
- using namespace data_types_internal;
771
- SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<typename CPPType::value_type>()->GetTypeProto(),
772
- this->MutableTypeProto());
773
- }
774
- };
775
-
776
- /**
777
- * \brief SequenceTensorTypeBase serves as a base type class for
778
- * Tensor sequences. Akin to TensorTypeBase.
779
- * Runtime representation is always TensorSeq.
780
- */
781
- class SequenceTensorTypeBase : public DataTypeImpl {
782
- public:
783
- static MLDataType Type();
784
-
785
- bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
786
-
787
- virtual MLDataType GetElementType() const {
788
- // should never reach here.
789
- ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
790
- }
791
-
792
- DeleteFunc GetDeleteFunc() const override;
793
-
794
- const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
795
-
796
- SequenceTensorTypeBase(const SequenceTensorTypeBase&) = delete;
797
- SequenceTensorTypeBase& operator=(const SequenceTensorTypeBase&) = delete;
798
-
799
- protected:
800
- SequenceTensorTypeBase();
801
- ~SequenceTensorTypeBase();
802
-
803
- ONNX_NAMESPACE::TypeProto& MutableTypeProto();
804
-
805
- private:
806
- struct Impl;
807
- Impl* impl_;
808
- };
809
- #if defined(_MSC_VER) && !defined(__clang__)
810
- #pragma warning(pop)
811
- #endif
812
- /**
813
- * \brief SequenceTensorType. Use to register sequence for non-tensor types.
814
- *
815
- * \param CPPRuntime - We always use TensorSeq
816
- *
817
- * \param TensorElemType - one of the primitive types
818
- *
819
- * \details Usage: ORT_REGISTER_SEQ_TENSOR_TYPE()
820
- * The type is required to have value_type defined
821
- */
822
- template <typename TensorElemType>
823
- class SequenceTensorType : public SequenceTensorTypeBase {
824
- public:
825
- static_assert(data_types_internal::IsTensorContainedType<TensorElemType>::value,
826
- "Requires one of the tensor fundamental types");
827
-
828
- static MLDataType Type();
829
-
830
- /// Return a MLDataType representing the element-type
831
- MLDataType GetElementType() const override {
832
- return DataTypeImpl::GetType<TensorElemType>();
833
- }
834
-
835
- private:
836
- SequenceTensorType() {
837
- using namespace data_types_internal;
838
- SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<TensorElemType>()->GetTypeProto(),
839
- MutableTypeProto());
840
- }
841
- };
842
-
843
- /**
844
- * \brief OpaqueType
845
- *
846
- * \tparam T - cpp runtume that implements the Opaque type
847
- *
848
- * \tparam const char D[] - domain must be extern to be unique
849
- *
850
- * \tparam const char N[] - name must be extern to be unique
851
- *
852
- * \details Only one CPP type can be associated with a particular
853
- * OpaqueType registration
854
- *
855
- */
856
- template <typename T, const char D[], const char N[]>
857
- class OpaqueType : public NonTensorType<T> {
858
- public:
859
- static MLDataType Type();
860
-
861
- bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
862
- return this->IsOpaqueCompatible(type_proto);
863
- }
864
-
865
- void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const override {
866
- NonTensorTypeConverter<T>::FromContainer(this, data, data_size, output);
867
- }
868
-
869
- void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const override {
870
- NonTensorTypeConverter<T>::ToContainer(input, data_size, data);
871
- }
872
-
873
- private:
874
- OpaqueType() {
875
- data_types_internal::AssignOpaqueDomainName(D, N, this->MutableTypeProto());
876
- }
877
- };
878
-
879
- /**
880
- * \brief PrimitiveDataTypeBase
881
- * Base class for primitive Tensor contained types
882
- *
883
- * \details This class contains an integer constant that can be
884
- * used for input data type dispatching
885
- *
886
- */
887
- class PrimitiveDataTypeBase : public DataTypeImpl {
888
- public:
889
- bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
890
- return false;
891
- }
892
-
893
- const ONNX_NAMESPACE::TypeProto* GetTypeProto() const final {
894
- return nullptr;
895
- }
896
-
897
- int32_t GetDataType() const {
898
- return data_type_;
899
- }
900
-
901
- protected:
902
- PrimitiveDataTypeBase(size_t size, int32_t data_type)
903
- : DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type} {}
904
-
905
- private:
906
- const int32_t data_type_;
907
- };
908
-
909
- /**
910
- * \brief PrimitiveDataType
911
- * Typed specialization for primitive types.
912
- * Concrete instances of this class are used by Tensor.
913
- *
914
- * \param T - primitive data type
915
- *
916
- */
917
- template <typename T>
918
- class PrimitiveDataType : public PrimitiveDataTypeBase {
919
- private:
920
- static void Delete(void* p) {
921
- delete static_cast<T*>(p);
922
- }
923
-
924
- public:
925
- static MLDataType Type();
926
-
927
- DeleteFunc GetDeleteFunc() const override {
928
- return &Delete;
929
- }
930
-
931
- private:
932
- PrimitiveDataType()
933
- : PrimitiveDataTypeBase{sizeof(T),
934
- utils::ToTensorProtoElementType<T>()} {
935
- }
936
- };
937
-
938
- inline const TensorTypeBase* DataTypeImpl::AsTensorType() const {
939
- return IsTensorType() ? static_cast<const TensorTypeBase*>(this) : nullptr;
940
- }
941
-
942
- inline const SequenceTensorTypeBase* DataTypeImpl::AsSequenceTensorType() const {
943
- return IsTensorSequenceType() ? static_cast<const SequenceTensorTypeBase*>(this) : nullptr;
944
- }
945
-
946
- #if !defined(DISABLE_SPARSE_TENSORS)
947
- inline const SparseTensorTypeBase* DataTypeImpl::AsSparseTensorType() const {
948
- return IsSparseTensorType() ? static_cast<const SparseTensorTypeBase*>(this) : nullptr;
949
- }
950
- #endif
951
-
952
- #if !defined(DISABLE_OPTIONAL_TYPE)
953
- inline const OptionalTypeBase* DataTypeImpl::AsOptionalType() const {
954
- return IsOptionalType() ? static_cast<const OptionalTypeBase*>(this) : nullptr;
955
- }
956
- #endif
957
-
958
- inline const NonTensorTypeBase* DataTypeImpl::AsNonTensorType() const {
959
- return IsNonTensorType() ? static_cast<const NonTensorTypeBase*>(this) : nullptr;
960
- }
961
-
962
- inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const {
963
- return IsPrimitiveDataType() ? static_cast<const PrimitiveDataTypeBase*>(this) : nullptr;
964
- }
965
-
966
- // Explicit specialization of base class template function
967
- // is only possible within the enclosing namespace scope,
968
- // thus a simple way to pre-instantiate a given template
969
- // at a registration time does not currently work and the macro
970
- // is needed.
971
- #define ORT_REGISTER_TENSOR_TYPE(ELEM_TYPE) \
972
- template <> \
973
- MLDataType TensorType<ELEM_TYPE>::Type() { \
974
- static TensorType<ELEM_TYPE> tensor_type; \
975
- return &tensor_type; \
976
- } \
977
- template <> \
978
- MLDataType DataTypeImpl::GetTensorType<ELEM_TYPE>() { \
979
- return TensorType<ELEM_TYPE>::Type(); \
980
- }
981
-
982
- #if !defined(DISABLE_SPARSE_TENSORS)
983
- #define ORT_REGISTER_SPARSE_TENSOR_TYPE(ELEM_TYPE) \
984
- template <> \
985
- MLDataType SparseTensorType<ELEM_TYPE>::Type() { \
986
- static SparseTensorType<ELEM_TYPE> tensor_type; \
987
- return &tensor_type; \
988
- } \
989
- template <> \
990
- MLDataType DataTypeImpl::GetSparseTensorType<ELEM_TYPE>() { \
991
- return SparseTensorType<ELEM_TYPE>::Type(); \
992
- }
993
- #endif
994
-
995
- #define ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, TYPE) \
996
- template <> \
997
- MLDataType OptionalType<ORT_TYPE, TYPE>::Type() { \
998
- static OptionalType<ORT_TYPE, TYPE> optional_type; \
999
- return &optional_type; \
1000
- } \
1001
- template <> \
1002
- MLDataType DataTypeImpl::GetOptionalType<ORT_TYPE, TYPE>() { \
1003
- return OptionalType<ORT_TYPE, TYPE>::Type(); \
1004
- }
1005
-
1006
- #if !defined(DISABLE_ML_OPS)
1007
- #define ORT_REGISTER_MAP(TYPE) \
1008
- template <> \
1009
- MLDataType MapType<TYPE>::Type() { \
1010
- static MapType<TYPE> map_type; \
1011
- return &map_type; \
1012
- } \
1013
- template <> \
1014
- MLDataType DataTypeImpl::GetType<TYPE>() { \
1015
- return MapType<TYPE>::Type(); \
1016
- }
1017
- #endif
1018
-
1019
- #define ORT_REGISTER_SEQ(TYPE) \
1020
- template <> \
1021
- MLDataType SequenceType<TYPE>::Type() { \
1022
- static SequenceType<TYPE> sequence_type; \
1023
- return &sequence_type; \
1024
- } \
1025
- template <> \
1026
- MLDataType DataTypeImpl::GetType<TYPE>() { \
1027
- return SequenceType<TYPE>::Type(); \
1028
- }
1029
-
1030
- #define ORT_REGISTER_SEQ_TENSOR_TYPE(ELEM_TYPE) \
1031
- template <> \
1032
- MLDataType SequenceTensorType<ELEM_TYPE>::Type() { \
1033
- static SequenceTensorType<ELEM_TYPE> sequence_tensor_type; \
1034
- return &sequence_tensor_type; \
1035
- } \
1036
- template <> \
1037
- MLDataType DataTypeImpl::GetSequenceTensorType<ELEM_TYPE>() { \
1038
- return SequenceTensorType<ELEM_TYPE>::Type(); \
1039
- }
1040
-
1041
- #define ORT_REGISTER_PRIM_TYPE(TYPE) \
1042
- template <> \
1043
- MLDataType PrimitiveDataType<TYPE>::Type() { \
1044
- static PrimitiveDataType<TYPE> prim_data_type; \
1045
- return &prim_data_type; \
1046
- } \
1047
- template <> \
1048
- MLDataType DataTypeImpl::GetType<TYPE>() { \
1049
- return PrimitiveDataType<TYPE>::Type(); \
1050
- }
1051
-
1052
- #define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name) \
1053
- template <> \
1054
- MLDataType OpaqueType<CPPType, Domain, Name>::Type() { \
1055
- static OpaqueType<CPPType, Domain, Name> opaque_type; \
1056
- return &opaque_type; \
1057
- } \
1058
- template <> \
1059
- MLDataType DataTypeImpl::GetType<CPPType>() { \
1060
- return OpaqueType<CPPType, Domain, Name>::Type(); \
1061
- }
1062
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/data_types_internal.h DELETED
@@ -1,569 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <array>
7
- #include <cassert>
8
- #include <cstdint>
9
- #include <string>
10
- #include <type_traits>
11
- #include <vector>
12
-
13
- #include "boost/mp11.hpp"
14
-
15
- #include "core/common/common.h"
16
- #include "core/framework/to_tensor_proto_element_type.h"
17
- #ifndef SHARED_PROVIDER
18
- #include "core/common/type_list.h"
19
- #include "core/framework/data_types.h"
20
- #if !defined(ORT_MINIMAL_BUILD)
21
- #include "onnx/defs/schema.h"
22
- #else
23
- #include "onnx/defs/data_type_utils.h"
24
- #endif
25
- #include "onnx/onnx_pb.h"
26
- #include "onnx/onnx-operators_pb.h"
27
- #endif
28
-
29
- namespace onnxruntime {
30
- namespace utils {
31
-
32
- // The following primitives are strongly recommended for switching on tensor input datatypes for
33
- // kernel implementations.
34
- //
35
- // 1) If you need to handle all of the primitive tensor contained datatypes, the best choice would be macros
36
- // DispatchOnTensorType or DispatchOnTensorTypeWithReturn. Use inline wrappers so your function can be invoked as function<T>().
37
- // 2) if you have a few types, use Tensor.IsDataType<T>()/IsDataTypeString() or use utils::IsPrimitiveDataType<T>()
38
- // if you have a standalone MLDatatType with a sequence of if/else statements.
39
- // 3) For something in between, we suggest to use CallDispatcher pattern.
40
- //
41
- // Invoking DataTypeImpl::GetType<T>() for switching on input types is discouraged and should be avoided.
42
- // Every primitive type carries with it an integer constant that can be used for quick switching on types.
43
-
44
- #define DispatchOnTensorType(tensor_type, function, ...) \
45
- switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
46
- case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
47
- function<float>(__VA_ARGS__); \
48
- break; \
49
- case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
50
- function<bool>(__VA_ARGS__); \
51
- break; \
52
- case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
53
- function<double>(__VA_ARGS__); \
54
- break; \
55
- case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
56
- function<std::string>(__VA_ARGS__); \
57
- break; \
58
- case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
59
- function<int8_t>(__VA_ARGS__); \
60
- break; \
61
- case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
62
- function<uint8_t>(__VA_ARGS__); \
63
- break; \
64
- case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
65
- function<int16_t>(__VA_ARGS__); \
66
- break; \
67
- case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
68
- function<uint16_t>(__VA_ARGS__); \
69
- break; \
70
- case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
71
- function<int32_t>(__VA_ARGS__); \
72
- break; \
73
- case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
74
- function<uint32_t>(__VA_ARGS__); \
75
- break; \
76
- case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
77
- function<int64_t>(__VA_ARGS__); \
78
- break; \
79
- case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
80
- function<uint64_t>(__VA_ARGS__); \
81
- break; \
82
- case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
83
- function<MLFloat16>(__VA_ARGS__); \
84
- break; \
85
- case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
86
- function<BFloat16>(__VA_ARGS__); \
87
- break; \
88
- default: \
89
- ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
90
- }
91
-
92
- #define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
93
- switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
94
- case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
95
- retval = function<float>(__VA_ARGS__); \
96
- break; \
97
- case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
98
- retval = function<bool>(__VA_ARGS__); \
99
- break; \
100
- case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
101
- retval = function<double>(__VA_ARGS__); \
102
- break; \
103
- case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
104
- retval = function<std::string>(__VA_ARGS__); \
105
- break; \
106
- case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
107
- retval = function<int8_t>(__VA_ARGS__); \
108
- break; \
109
- case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
110
- retval = function<uint8_t>(__VA_ARGS__); \
111
- break; \
112
- case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
113
- retval = function<uint16_t>(__VA_ARGS__); \
114
- break; \
115
- case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
116
- retval = function<int16_t>(__VA_ARGS__); \
117
- break; \
118
- case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
119
- retval = function<int32_t>(__VA_ARGS__); \
120
- break; \
121
- case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
122
- retval = function<uint32_t>(__VA_ARGS__); \
123
- break; \
124
- case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
125
- retval = function<int64_t>(__VA_ARGS__); \
126
- break; \
127
- case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
128
- retval = function<uint64_t>(__VA_ARGS__); \
129
- break; \
130
- case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
131
- retval = function<MLFloat16>(__VA_ARGS__); \
132
- break; \
133
- case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
134
- retval = function<BFloat16>(__VA_ARGS__); \
135
- break; \
136
- default: \
137
- ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
138
- }
139
-
140
- ////////////////////////////////////////////////////////////////////////////////
141
- /// Use the following primitives if you have a few types to switch on so you
142
- // can write a short sequence of if/else statements.
143
-
144
- // This is a frequently used check so we make a separate utility function.
145
- inline bool IsDataTypeString(MLDataType dt_type) {
146
- auto prim_type = dt_type->AsPrimitiveDataType();
147
- return (prim_type != nullptr && prim_type->GetDataType() == ONNX_NAMESPACE::TensorProto_DataType_STRING);
148
- }
149
-
150
- // Test if MLDataType is a concrete type of PrimitiveDataTypeBase
151
- // and it is T
152
- template <class T>
153
- inline bool IsPrimitiveDataType(MLDataType dt_type) {
154
- auto prim_type = dt_type->AsPrimitiveDataType();
155
- return (prim_type != nullptr && prim_type->GetDataType() == ToTensorProtoElementType<T>());
156
- }
157
-
158
- // Use after AsPrimitiveDataType() is successful
159
- // Check if PrimitiveDataTypeBase is of type T
160
- template <class T>
161
- inline bool IsPrimitiveDataType(const PrimitiveDataTypeBase* prim_type) {
162
- assert(prim_type != nullptr);
163
- return prim_type->GetDataType() == ToTensorProtoElementType<T>();
164
- }
165
-
166
- // This implementation contains a workaround for GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=47226
167
- // GCC until very recently does not support template parameter pack expansion within lambda context.
168
- namespace mltype_dispatcher_internal {
169
-
170
- // T - type handled by this helper
171
- class CallableDispatchableHelper {
172
- int32_t dt_type_; // Type currently dispatched
173
- size_t called_;
174
-
175
- public:
176
- explicit CallableDispatchableHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0) {}
177
-
178
- // Must return integer to be in a expandable context
179
- template <class T, class Fn, class... Args>
180
- int Invoke(Fn&& fn, Args&&... args) {
181
- if (utils::ToTensorProtoElementType<T>() == dt_type_) {
182
- std::forward<Fn>(fn)(std::forward<Args>(args)...);
183
- ++called_;
184
- }
185
- return 0;
186
- }
187
-
188
- void CheckCalledOnce() {
189
- ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
190
- }
191
- };
192
-
193
- // Default policy is to throw an exception.
194
- // Other policies may set the second result argument accordingly.
195
- template <class Ret>
196
- struct UnsupportedTypeDefaultPolicy {
197
- void operator()(int32_t dt_type, Ret& /*result*/) const {
198
- ORT_THROW("Unsupported data type: ", dt_type);
199
- }
200
- };
201
-
202
- // Helper with the result type
203
- template <class Ret, class UnsupportedPolicy>
204
- class CallableDispatchableRetHelper {
205
- int32_t dt_type_; // Type currently dispatched
206
- size_t called_;
207
- Ret result_;
208
-
209
- public:
210
- explicit CallableDispatchableRetHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0), result_() {}
211
-
212
- Ret Get() {
213
- // No type was invoked
214
- if (called_ == 0) {
215
- UnsupportedPolicy()(dt_type_, result_);
216
- }
217
- return result_;
218
- }
219
-
220
- // Must return integer to be in a expandable context
221
- template <class T, class Fn, class... Args>
222
- int Invoke(Fn&& fn, Args&&... args) {
223
- if (utils::ToTensorProtoElementType<T>() == dt_type_) {
224
- result_ = std::forward<Fn>(fn)(std::forward<Args>(args)...);
225
- ++called_;
226
- }
227
- return 0;
228
- }
229
- };
230
-
231
- template <typename T>
232
- using TensorProtoElementTypeConstant =
233
- std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType<T>()>;
234
-
235
- using UndefinedTensorProtoElementTypeConstant =
236
- std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED>;
237
-
238
- } // namespace mltype_dispatcher_internal
239
-
240
- /**
241
- * This class helps to efficiently dispatch calls to implementation function
242
- * objects with a tensor element type template argument.
243
- *
244
- * The constructor accepts a value corresponding to a tensor element type.
245
- * For example, it can be obtained from:
246
- * input_tensor->GetElementType()
247
- *
248
- * The Invoke member functions will instantiate and invoke the provided
249
- * function object template, Fn. Fn must be default constructible. Fn must also
250
- * have a tensor element type template argument. This type template argument
251
- * will be the type that corresponds to the value given in the constructor.
252
- * These functions accept and forward arbitrary function arguments. They ensure
253
- * that Fn is called once with the type specified in the constructor.
254
- *
255
- * @tparam Types The types supported by the implementation. This should be a
256
- * set of ONNX tensor element types that are supported by ORT.
257
- */
258
- template <typename... Types>
259
- class MLTypeCallDispatcher {
260
- using SupportedTypeList = TypeList<Types...>;
261
- using SupportedTensorProtoElementTypeList =
262
- boost::mp11::mp_transform<
263
- mltype_dispatcher_internal::TensorProtoElementTypeConstant, SupportedTypeList>;
264
-
265
- static_assert(
266
- boost::mp11::mp_and<
267
- boost::mp11::mp_is_set<SupportedTensorProtoElementTypeList>,
268
- boost::mp11::mp_not<
269
- boost::mp11::mp_set_contains<
270
- SupportedTensorProtoElementTypeList,
271
- mltype_dispatcher_internal::UndefinedTensorProtoElementTypeConstant>>>::value,
272
- "Types must map to a unique set of ONNX tensor element data types supported by ORT.");
273
-
274
- int32_t dt_type_;
275
-
276
- public:
277
- /**
278
- * Constructor.
279
- * @param dt_type The value corresponding to the tensor element type to be
280
- * dispatched to. This can be obtained from
281
- * input_tensor->GetElementType() or
282
- * utils::ToTensorProtoElementType<T>().
283
- */
284
- explicit MLTypeCallDispatcher(int32_t dt_type) noexcept : dt_type_(dt_type) {}
285
-
286
- /**
287
- * Invokes Fn<T> with the specified arguments.
288
- *
289
- * @tparam Fn The function object template.
290
- * @tparam Args The argument types.
291
- */
292
- template <template <typename...> class Fn, typename... Args>
293
- void Invoke(Args&&... args) const {
294
- InvokeWithLeadingTemplateArgs<Fn, TypeList<>>(std::forward<Args>(args)...);
295
- }
296
-
297
- /**
298
- * Invokes Fn<..., T> with leading template arguments and the specified
299
- * arguments.
300
- *
301
- * @tparam Fn The function object template.
302
- * @tparam LeadingTemplateArgTypeList A type list of the leading template
303
- * arguments.
304
- * @tparam Args The argument types.
305
- */
306
- template <template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
307
- void InvokeWithLeadingTemplateArgs(Args&&... args) const {
308
- static_assert(
309
- boost::mp11::mp_is_list<LeadingTemplateArgTypeList>::value,
310
- "LeadingTemplateArgTypeList must be a type list (e.g., onnxruntime::TypeList<T1, T2, ...>).");
311
-
312
- mltype_dispatcher_internal::CallableDispatchableHelper helper(dt_type_);
313
-
314
- // given LeadingTemplateArgTypeList is a type list L<U1, U2, ...>,
315
- // call helper.Invoke() with Fn<U1, U2, ..., T> for each T in Types
316
- static_cast<void>(std::array<int, sizeof...(Types)>{
317
- helper.template Invoke<Types>(
318
- boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
319
- std::forward<Args>(args)...)...});
320
-
321
- // avoid "unused parameter" warning for the case where Types is empty
322
- static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
323
-
324
- helper.CheckCalledOnce();
325
- }
326
-
327
- /**
328
- * Invokes Fn<T> with the specified arguments and returns the result.
329
- *
330
- * @tparam Ret The return type. Fn should return a type convertible to Ret.
331
- * @tparam Fn The function object template.
332
- * @tparam Args The argument types.
333
- */
334
- template <class Ret, template <typename...> class Fn, typename... Args>
335
- Ret InvokeRet(Args&&... args) const {
336
- return InvokeRetWithUnsupportedPolicy<
337
- Ret, Fn, mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy<Ret>>(
338
- std::forward<Args>(args)...);
339
- }
340
-
341
- /**
342
- * Invokes Fn<T> with the specified arguments and returns the result.
343
- *
344
- * @tparam Ret The return type. Fn should return a type convertible to Ret.
345
- * @tparam Fn The function object template.
346
- * @tparam UnsupportedPolicy The policy used to handle unsupported types.
347
- * See mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy
348
- * for an example.
349
- * @tparam Args The argument types.
350
- */
351
- template <class Ret, template <typename...> class Fn, class UnsupportedPolicy, typename... Args>
352
- Ret InvokeRetWithUnsupportedPolicy(Args&&... args) const {
353
- return InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs<
354
- Ret, Fn, UnsupportedPolicy, TypeList<>>(
355
- std::forward<Args>(args)...);
356
- }
357
-
358
- /**
359
- * Invokes Fn<..., T> with leading template arguments and the specified
360
- * arguments and returns the result.
361
- *
362
- * @tparam Ret The return type. Fn should return a type convertible to Ret.
363
- * @tparam Fn The function object template.
364
- * @tparam LeadingTemplateArgTypeList A type list of the leading template
365
- * arguments.
366
- * @tparam Args The argument types.
367
- */
368
- template <class Ret, template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
369
- Ret InvokeRetWithLeadingTemplateArgs(Args&&... args) const {
370
- return InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs<
371
- Ret, Fn, mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy<Ret>, LeadingTemplateArgTypeList>(
372
- std::forward<Args>(args)...);
373
- }
374
-
375
- /**
376
- * Invokes Fn<..., T> with leading template arguments and the specified
377
- * arguments and returns the result.
378
- *
379
- * @tparam Ret The return type. Fn should return a type convertible to Ret.
380
- * @tparam Fn The function object template.
381
- * @tparam UnsupportedPolicy The policy used to handle unsupported types.
382
- * See mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy
383
- * for an example.
384
- * @tparam LeadingTemplateArgTypeList A type list of the leading template
385
- * arguments.
386
- * @tparam Args The argument types.
387
- */
388
- template <class Ret,
389
- template <typename...> class Fn,
390
- class UnsupportedPolicy,
391
- typename LeadingTemplateArgTypeList,
392
- typename... Args>
393
- Ret InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs(Args&&... args) const {
394
- mltype_dispatcher_internal::CallableDispatchableRetHelper<Ret, UnsupportedPolicy> helper(dt_type_);
395
-
396
- // given LeadingTemplateArgTypeList is a type list L<U1, U2, ...>,
397
- // call helper.Invoke() with Fn<U1, U2, ..., T> for each T in Types
398
- static_cast<void>(std::array<int, sizeof...(Types)>{
399
- helper.template Invoke<Types>(
400
- boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
401
- std::forward<Args>(args)...)...});
402
-
403
- // avoid "unused parameter" warning for the case where Types is empty
404
- static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
405
-
406
- return helper.Get();
407
- }
408
- };
409
-
410
- // the type MLTypeCallDispatcher<T...> given a type list L<T...>
411
- template <typename L>
412
- using MLTypeCallDispatcherFromTypeList = boost::mp11::mp_apply<MLTypeCallDispatcher, L>;
413
-
414
- namespace data_types_internal {
415
-
416
- enum class ContainerType : uint16_t {
417
- kUndefined = 0,
418
- kTensor = 1,
419
- kMap = 2,
420
- kSequence = 3,
421
- kOpaque = 4,
422
- kOptional = 5
423
- };
424
-
425
- class TypeNode {
426
- // type_ is a TypeProto value case enum
427
- // that may be a kTypeTensor, kTypeMap, kTypeSequence
428
- // prim_type_ is a TypeProto_DataType enum that has meaning
429
- // - for Tensor then prim_type_ is the contained type
430
- // - for Map prim_type_ is the key type. Next entry describes map value
431
- // - For sequence prim_type_ is not used and has no meaning. Next entry
432
- // describes the value for the sequence
433
- // Tensor is always the last entry as it describes a contained primitive type.
434
- ContainerType type_;
435
- uint16_t prim_type_;
436
-
437
- public:
438
- TypeNode(ContainerType type, int32_t prim_type) noexcept {
439
- type_ = type;
440
- prim_type_ = static_cast<uint16_t>(prim_type);
441
- }
442
-
443
- bool IsType(ContainerType type) const noexcept {
444
- return type_ == type;
445
- }
446
-
447
- bool IsPrimType(int32_t prim_type) const noexcept {
448
- return prim_type_ == static_cast<uint16_t>(prim_type);
449
- }
450
- };
451
-
452
- } // namespace data_types_internal
453
-
454
- ////////////////////////////////////////////////////////////////////
455
- /// Provides generic interface to test whether MLDataType is a Sequence,
456
- /// Map or an Opaque type including arbitrary recursive definitions
457
- /// without querying DataTypeImpl::GetType<T> for all known complex types
458
-
459
- // T is a sequence contained element type
460
- // If returns true then we know that the runtime
461
- // representation is std::vector<T>
462
- // T itself can be a runtime representation of another
463
- // sequence, map, opaque type or a tensor
464
- //
465
- // That is it can be std::vector or a std::map
466
- // If T is a primitive type, the sequence is tested for whether it contains
467
- // tensors of that type
468
- //
469
- // If T is an opaque type, then it is only tested to be opaque but not exactly
470
- // a specific opaque type. To test for a specific opaque type, use IsOpaqueType() below
471
- //
472
- // This class examines the supplied MLDataType and records
473
- // its information in a vector so any subsequent checks for Sequences and Maps
474
- // are quick.
475
- class ContainerChecker {
476
- using Cont = std::vector<data_types_internal::TypeNode>;
477
- Cont types_;
478
-
479
- // Default IsContainerOfType is for Opaque type
480
- template <class T>
481
- struct IsContainerOfType {
482
- static bool check(const Cont& c, size_t index) {
483
- if (index >= c.size()) {
484
- return false;
485
- }
486
- return c[index].IsType(data_types_internal::ContainerType::kOpaque);
487
- }
488
- };
489
-
490
- // Handles the case where sequence element is also a sequence
491
- template <class T>
492
- struct IsContainerOfType<std::vector<T>> {
493
- static bool check(const Cont& c, size_t index) {
494
- if (index >= c.size()) {
495
- return false;
496
- }
497
- if (c[index].IsType(data_types_internal::ContainerType::kSequence)) {
498
- ORT_ENFORCE(++index < c.size(), "Sequence is missing type entry for its element");
499
- constexpr int32_t prim_type = ToTensorProtoElementType<T>();
500
- // Check if this is a primitive type and it matches
501
- if constexpr(prim_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
502
- return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
503
- c[index].IsPrimType(prim_type);
504
- }
505
- else {
506
- // T is not primitive, check next entry for non-primitive proto
507
- return IsContainerOfType<T>::check(c, index);
508
- }
509
- }
510
- return false;
511
- }
512
- };
513
-
514
- template <class K, class V>
515
- struct IsContainerOfType<std::map<K, V>> {
516
- static bool check(const Cont& c, size_t index) {
517
- static_assert(ToTensorProtoElementType<K>() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED,
518
- "Map Key can not be a non-primitive type");
519
- if (index >= c.size()) {
520
- return false;
521
- }
522
- if (!c[index].IsType(data_types_internal::ContainerType::kMap)) {
523
- return false;
524
- }
525
- constexpr int32_t key_type = ToTensorProtoElementType<K>();
526
- if (!c[index].IsPrimType(key_type)) {
527
- return false;
528
- }
529
- ORT_ENFORCE(++index < c.size(), "Map is missing type entry for its value");
530
- constexpr int32_t val_type = ToTensorProtoElementType<V>();
531
- if constexpr(val_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
532
- return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
533
- c[index].IsPrimType(val_type);
534
- }
535
- else return IsContainerOfType<V>::check(c, index);
536
- }
537
- };
538
-
539
- public:
540
- explicit ContainerChecker(MLDataType);
541
- ~ContainerChecker() = default;
542
-
543
- bool IsMap() const noexcept {
544
- assert(!types_.empty());
545
- return types_[0].IsType(data_types_internal::ContainerType::kMap);
546
- }
547
-
548
- bool IsSequence() const noexcept {
549
- assert(!types_.empty());
550
- return types_[0].IsType(data_types_internal::ContainerType::kSequence);
551
- }
552
-
553
- template <class T>
554
- bool IsSequenceOf() const {
555
- assert(!types_.empty());
556
- return IsContainerOfType<std::vector<T>>::check(types_, 0);
557
- }
558
-
559
- template <class K, class V>
560
- bool IsMapOf() const {
561
- assert(!types_.empty());
562
- return IsContainerOfType<std::map<K, V>>::check(types_, 0);
563
- }
564
- };
565
-
566
- bool IsOpaqueType(MLDataType ml_type, const char* domain, const char* name);
567
-
568
- } // namespace utils
569
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/endian.h DELETED
@@ -1,27 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- namespace onnxruntime {
7
-
8
- // the semantics of this enum should match std::endian from C++20
9
- enum class endian {
10
- #if defined(_WIN32)
11
- little = 0,
12
- big = 1,
13
- native = little,
14
- #elif defined(__GNUC__) || defined(__clang__)
15
- little = __ORDER_LITTLE_ENDIAN__,
16
- big = __ORDER_BIG_ENDIAN__,
17
- native = __BYTE_ORDER__,
18
- #else
19
- #error onnxruntime::endian is not implemented in this environment.
20
- #endif
21
- };
22
-
23
- static_assert(
24
- endian::native == endian::little || endian::native == endian::big,
25
- "Only little-endian or big-endian native byte orders are supported.");
26
-
27
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/execution_provider.h DELETED
@@ -1,340 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #ifndef SHARED_PROVIDER
7
- #include <memory>
8
- #include <unordered_map>
9
- #include <unordered_set>
10
-
11
- #include "core/common/logging/logging.h"
12
- #include "core/common/status.h"
13
- #include "core/framework/data_transfer.h"
14
- #include "core/framework/tensor.h"
15
-
16
- namespace onnxruntime {
17
- class GraphViewer;
18
- struct ComputeCapability;
19
- class KernelRegistry;
20
- struct KernelCreateInfo;
21
- class Node;
22
- } // namespace onnxruntime
23
- #else
24
- #include <memory>
25
- #endif
26
-
27
- #include "core/common/basic_types.h"
28
- #include "core/common/profiler_common.h"
29
- #include "core/framework/allocatormgr.h"
30
- #include "core/framework/func_api.h"
31
- #include "core/framework/provider_options.h"
32
- #include "core/framework/stream_handles.h"
33
- #include "core/framework/tuning_context.h"
34
-
35
- namespace onnxruntime {
36
-
37
- /**
38
- Logical device representation.
39
- */
40
-
41
- // if we are exporting the fused function to a dll, the function will still be in the same binary as onnxruntime
42
- // use std function to give execution provider some chance to capture some state.
43
- using CreateFunctionStateFunc = std::function<int(ComputeContext*, FunctionState*)>;
44
- using ComputeFunc = std::function<Status(FunctionState, const OrtApi*, OrtKernelContext*)>;
45
- using DestroyFunctionStateFunc = std::function<void(FunctionState)>;
46
-
47
- struct NodeComputeInfo {
48
- CreateFunctionStateFunc create_state_func;
49
- ComputeFunc compute_func;
50
- DestroyFunctionStateFunc release_state_func;
51
- };
52
-
53
- enum class DataLayout {
54
- NCHW,
55
- NHWC,
56
- NCHWC,
57
- };
58
-
59
- class IExecutionProvider {
60
- protected:
61
- IExecutionProvider(const std::string& type, bool use_metadef_id_creator = false)
62
- : type_{type} {
63
- if (use_metadef_id_creator) {
64
- metadef_id_generator_ = std::make_unique<ModelMetadefIdGenerator>();
65
- }
66
- }
67
-
68
- public:
69
- virtual ~IExecutionProvider() = default;
70
-
71
- /**
72
- Get all IAllocators for <*this> execution provider.
73
- */
74
- const std::vector<AllocatorPtr>& GetAllocators() const {
75
- return allocator_list_;
76
- }
77
-
78
- /**
79
- * Get an allocator with specified device id and MemType. Return nullptr if it doesn't exist
80
- */
81
- virtual AllocatorPtr GetAllocator(OrtMemType mem_type) const;
82
-
83
- /**
84
- * Returns a data transfer object that implements methods to copy to and
85
- * from this device.
86
- * If no copy is required for the successful operation of this provider,
87
- * return a nullptr.
88
- */
89
- virtual std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const {
90
- return nullptr;
91
- }
92
-
93
- /**
94
- * Interface for performing kernel lookup within kernel registries.
95
- * Abstracts away lower-level details about kernel registries and kernel matching.
96
- */
97
- class IKernelLookup {
98
- public:
99
- /**
100
- * Given `node`, try to find a matching kernel for this EP.
101
- * The return value is non-null if and only if a matching kernel was found.
102
- */
103
- virtual const KernelCreateInfo* LookUpKernel(const Node& node) const = 0;
104
-
105
- protected:
106
- ~IKernelLookup() = default;
107
- };
108
-
109
- /**
110
- Get execution provider's capability for the specified <graph>.
111
- Return a bunch of IndexedSubGraphs <*this> execution provider can run if
112
- the sub-graph contains only one node or can fuse to run if the sub-graph
113
- contains more than one node. The node indexes contained in sub-graphs may
114
- have overlap, and it's ONNXRuntime's responsibility to do the partition
115
- and decide whether a node will be assigned to <*this> execution provider.
116
- For kernels registered in a kernel registry, `kernel_lookup` must be used
117
- to find a matching kernel for this EP.
118
- */
119
- virtual std::vector<std::unique_ptr<ComputeCapability>>
120
- GetCapability(const onnxruntime::GraphViewer& graph_viewer,
121
- const IKernelLookup& kernel_lookup) const;
122
-
123
- /**
124
- Get kernel registry per execution provider type.
125
- The KernelRegistry shared pointer returned is shared across sessions.
126
-
127
- NOTE: this approach was taken to achieve the following goals,
128
- 1. The execution provider type based kernel registry should be shared
129
- across sessions.
130
- Only one copy of this kind of kernel registry exists in ONNXRuntime
131
- with multiple sessions/models.
132
- 2. Adding an execution provider into ONNXRuntime does not need to touch ONNXRuntime
133
- framework/session code.
134
- 3. onnxruntime (framework/session) does not depend on any specific
135
- execution provider lib.
136
- */
137
- virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const { return nullptr; }
138
-
139
- /**
140
- Get the device id of current execution provider
141
- */
142
- virtual int GetDeviceId() const { return 0; };
143
-
144
- /**
145
- Get execution provider's configuration options.
146
- */
147
- virtual ProviderOptions GetProviderOptions() const { return {}; }
148
-
149
- /**
150
- Returns an opaque handle whose exact type varies based on the provider
151
- and is interpreted accordingly by the corresponding kernel implementation.
152
- For Direct3D operator kernels, this may return an IUnknown supporting
153
- QueryInterface to ID3D12GraphicsCommandList1.
154
- */
155
- virtual const void* GetExecutionHandle() const noexcept {
156
- return nullptr;
157
- }
158
-
159
- /**
160
- @return type of the execution provider; should match that set in the node
161
- through the SetExecutionProvider API. Example valid return values are:
162
- kCpuExecutionProvider, kCudaExecutionProvider
163
- */
164
- const std::string& Type() const { return type_; }
165
-
166
- /**
167
- Blocks until the device has completed all preceding requested tasks.
168
- Currently this is primarily used by the IOBinding object to ensure that all
169
- inputs have been copied to the device before execution begins.
170
- */
171
- virtual common::Status Sync() const { return Status::OK(); }
172
-
173
- /**
174
- Called when InferenceSession::Run started
175
- NOTE that due to async execution in provider, the actual work of previous
176
- Run may not be finished on device. This function should be regarded as the
177
- point after which a new Run would start to submit commands from CPU
178
- */
179
- virtual common::Status OnRunStart() { return Status::OK(); }
180
-
181
- /**
182
- Called when InferenceSession::Run ended
183
- NOTE that due to async execution in provider, the actual work of this Run
184
- may not be finished on device. This function should be regarded as the point
185
- that all commands of the current Run have been submitted by the CPU
186
- */
187
- virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
188
-
189
- /**
190
- Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
191
- the provider. Currently only CUDA execution provider supports it.
192
- */
193
- virtual bool IsGraphCaptureEnabled() const { return false; }
194
-
195
- /**
196
- Indicate whether the graph has been captured and instantiated. Currently
197
- only CUDA execution provider supports it.
198
- */
199
- virtual bool IsGraphCaptured() const { return false; }
200
-
201
- /**
202
- Run the instantiated graph. Currently only CUDA execution provider supports
203
- it.
204
- */
205
- virtual common::Status ReplayGraph() { return Status::OK(); }
206
-
207
- /**
208
- Called when session creation is complete
209
- This provides an opportunity for execution providers to optionally synchronize and
210
- clean up its temporary resources to reduce memory and ensure the first run is fast.
211
- */
212
- virtual common::Status OnSessionInitializationEnd() { return Status::OK(); }
213
-
214
- void InsertAllocator(AllocatorPtr allocator);
215
- void ReplaceAllocator(AllocatorPtr allocator);
216
-
217
- struct FusedNodeAndGraph {
218
- const std::reference_wrapper<onnxruntime::Node> fused_node;
219
- // GraphViewer that filters the full graph to the nodes that are covered by 'node'
220
- const std::reference_wrapper<GraphViewer> filtered_graph;
221
- };
222
-
223
- // Fusion approach that is supported
224
- // !!! The "Function" FusionStyle is deprecated.
225
- // !!! If your EP is using this fusion style, please migrate it to "FilteredGraphViewer" style.
226
- enum class FusionStyle {
227
- // The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance
228
- // in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body().
229
- // A GraphProto can be produced from the Node body.
230
- Function,
231
-
232
- // The node fusion will create a new Node that defines the inputs and outputs using the IndexedSubGraph
233
- // that GetCapability returned. The Node will not be onnxruntime::Function based so will have no Body().
234
- // Instead a GraphViewer that filters the full Graph to the fused Nodes will be created.
235
- // This is significantly cheaper as it doesn't incur the cost of creating a new Graph instance,
236
- // and can be supported in a minimal build.
237
- FilteredGraphViewer
238
- };
239
-
240
- virtual FusionStyle GetFusionStyle() const {
241
- // All the ORT built-in EPs have migrated to the FilteredGraphViewer style.
242
- // For newer EPs, please avoid using the Function style as it is deprecated.
243
- return FusionStyle::FilteredGraphViewer;
244
- }
245
-
246
- #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
247
- /**
248
- Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
249
- return create_state/compute/release_state func for each node.
250
- @remarks This is now the default interface when execution provider wants to compile nodes
251
- for both minimal build and complete ort build.
252
-
253
- Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
254
- as it is only valid for the duration of the call to Compile.
255
- */
256
- virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
257
- std::vector<NodeComputeInfo>& node_compute_funcs);
258
-
259
- #endif
260
-
261
- void SetLogger(const logging::Logger* logger) {
262
- logger_ = logger;
263
- }
264
-
265
- const logging::Logger* GetLogger() const {
266
- return logger_;
267
- }
268
-
269
- /** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance.
270
- The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models.
271
- @param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or nested graph.
272
- @param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model.
273
- This is created using the model path if available,
274
- or the model input names and the output names from all nodes in the main graph.
275
- @remarks e.g. the TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches
276
- compiled kernels, so the name must be unique and deterministic across models and sessions.
277
- NOTE: Ideally this would be a protected method, but to work across the EP bridge it has to be public and
278
- virtual, and ModelMetadefIdGenerator must be defined in the header as well.
279
- */
280
- virtual int GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const;
281
-
282
- /**
283
- Register allocators for EP, potentially re-using existing allocators for a device from allocator_manager.
284
- If the EP implements this it should generally delay creating any allocators until this is called.
285
- */
286
- virtual void RegisterAllocator(AllocatorManager& /*allocator_manager*/);
287
-
288
- virtual std::unique_ptr<profiling::EpProfiler> GetProfiler() {
289
- return {};
290
- }
291
-
292
- virtual DataLayout GetPreferredLayout() const {
293
- // NCHW is the default ONNX standard data layout. So default to it.
294
- // EPs which prefer a different layout should override to return their preferred layout.
295
- return DataLayout::NCHW;
296
- }
297
-
298
- virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/) const {}
299
-
300
- /** Does the EP support concurrent calls to InferenceSession::Run to execute the model.
301
- */
302
- virtual bool ConcurrentRunSupported() const { return true; }
303
-
304
- /**
305
- * Return the tuning context which holds all TunableOp state.
306
- */
307
- virtual ITuningContext* GetTuningContext() const {
308
- return nullptr;
309
- }
310
-
311
- private:
312
- const std::string type_;
313
-
314
- // allocator lookup is done by combining the device id and OrtMemType.
315
- // there's also an implicit connection to the underlying OrtDevice involved that is dependent on the EP.
316
- // e.g. for a CPU based EP, 'default' memory is a CPU device, and for a GPU based EP 'default' memory is a
317
- // GPU device.
318
- using AllocatorMap = std::unordered_map<int, AllocatorPtr>;
319
- AllocatorMap allocators_;
320
-
321
- // It will be set when this object is registered to a session
322
- const logging::Logger* logger_ = nullptr;
323
- // convenience list of the allocators so GetAllocators doesn't have to build a new vector each time
324
- // contains the same instances as allocators_
325
- std::vector<AllocatorPtr> allocator_list_;
326
-
327
- // helper to generate ids that are unique to model and deterministic, even if the execution provider is shared across
328
- // multiple sessions.
329
- class ModelMetadefIdGenerator {
330
- public:
331
- int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash);
332
-
333
- private:
334
- std::unordered_map<HashValue, HashValue> main_graph_hash_; // map graph instance hash to model contents hash
335
- std::unordered_map<HashValue, int> model_metadef_id_; // current unique id for model
336
- };
337
-
338
- std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
339
- };
340
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/float16.h DELETED
@@ -1,159 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
- #pragma once
4
-
5
- #include "endian.h"
6
- #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
7
- #include "cuda_bf16.h"
8
- #endif
9
-
10
- #if !defined(__CUDACC__) && !defined(__HIPCC__)
11
- #include "core/common/narrow.h"
12
- #endif
13
-
14
- #include "core/common/common.h"
15
-
16
- namespace onnxruntime {
17
-
18
- #if defined(__CUDACC__) || defined(__HIPCC__)
19
- #define ORT_HOST_DEVICE __host__ __device__
20
- #else
21
- #define ORT_HOST_DEVICE
22
- #endif
23
-
24
- // MLFloat16
25
- struct MLFloat16 {
26
- uint16_t val{0};
27
-
28
- MLFloat16() = default;
29
- explicit constexpr MLFloat16(uint16_t x) : val(x) {}
30
- explicit MLFloat16(float f);
31
-
32
- float ToFloat() const;
33
-
34
- operator float() const { return ToFloat(); }
35
- };
36
-
37
- inline bool operator==(const MLFloat16& left, const MLFloat16& right) { return left.val == right.val; }
38
- inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { return left.val != right.val; }
39
- inline bool operator<(const MLFloat16& left, const MLFloat16& right) { return left.val < right.val; }
40
-
41
- // BFloat16
42
- struct BFloat16 {
43
- uint16_t val{0};
44
- #if defined(__HIP__)
45
- ORT_HOST_DEVICE BFloat16() = default;
46
- #else
47
- BFloat16() = default;
48
- #endif
49
-
50
- struct FromBitsT {};
51
- static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
52
- constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) : val(bits) {}
53
-
54
- inline ORT_HOST_DEVICE BFloat16(float v) {
55
- #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
56
- val = __bfloat16_as_ushort(__float2bfloat16(v));
57
- #elif defined(__HIP__)
58
- // We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment.
59
- if (v != v) { // isnan
60
- val = UINT16_C(0x7FC0);
61
- } else {
62
- union {
63
- uint32_t U32;
64
- float F32;
65
- };
66
-
67
- F32 = v;
68
- uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
69
- val = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
70
- }
71
- #else
72
- if constexpr(endian::native == endian::little) {
73
- std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
74
- }
75
- else {
76
- std::memcpy(&val, &v, sizeof(uint16_t));
77
- }
78
- #endif
79
- }
80
-
81
- inline ORT_HOST_DEVICE float ToFloat() const {
82
- #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
83
- return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&val));
84
- #elif defined(__HIP__)
85
- // We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment.
86
- float result = 0;
87
- uint32_t tmp = val;
88
- tmp <<= 16;
89
- float* tempRes = reinterpret_cast<float*>(&tmp);
90
- result = *tempRes;
91
- return result;
92
- #else
93
- float result;
94
- char* const first = reinterpret_cast<char*>(&result);
95
- char* const second = first + sizeof(uint16_t);
96
- if constexpr(endian::native == endian::little) {
97
- std::memset(first, 0, sizeof(uint16_t));
98
- std::memcpy(second, &val, sizeof(uint16_t));
99
- }
100
- else {
101
- std::memcpy(first, &val, sizeof(uint16_t));
102
- std::memset(second, 0, sizeof(uint16_t));
103
- }
104
- return result;
105
- #endif
106
- }
107
-
108
- inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
109
-
110
- #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
111
- ORT_HOST_DEVICE BFloat16(const __nv_bfloat16& value) { val = *reinterpret_cast<const unsigned short*>(&value); }
112
- explicit ORT_HOST_DEVICE operator __nv_bfloat16() const { return *reinterpret_cast<const __nv_bfloat16*>(&val); }
113
- #endif
114
- };
115
-
116
- inline ORT_HOST_DEVICE bool operator==(const BFloat16& left, const BFloat16& right) { return left.val == right.val; }
117
- inline ORT_HOST_DEVICE bool operator!=(const BFloat16& left, const BFloat16& right) { return left.val != right.val; }
118
- inline ORT_HOST_DEVICE bool operator<(const BFloat16& left, const BFloat16& right) { return left.val < right.val; }
119
-
120
-
121
- // User defined suffixes to make it easier to declare
122
- // initializers with MLFloat16 and BFloat16 from unsigned short
123
- // E.g 10_f16 or 10_b16
124
- #if !defined(__CUDACC__) && !defined(__HIPCC__)
125
- inline MLFloat16 operator"" _f16(unsigned long long int v) {
126
- return MLFloat16(narrow<uint16_t>(v));
127
- }
128
-
129
- inline MLFloat16 operator"" _fp16(long double v) {
130
- return MLFloat16(static_cast<float>(v));
131
- }
132
-
133
- inline BFloat16 operator"" _b16(unsigned long long int v) {
134
- return BFloat16(narrow<uint16_t>(v), BFloat16::FromBits());
135
- }
136
-
137
- inline BFloat16 operator"" _bfp16(long double v) {
138
- return BFloat16(static_cast<float>(v));
139
- }
140
-
141
- #endif
142
-
143
- inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
144
- auto src = blf;
145
- auto d = flt;
146
- for (; size != 0; ++src, ++d, --size) {
147
- *d = src->ToFloat();
148
- }
149
- }
150
-
151
- inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
152
- auto src = flt;
153
- auto d = blf;
154
- for (; size != 0; ++src, ++d, --size) {
155
- new (d) BFloat16(*src);
156
- }
157
- }
158
-
159
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/framework_common.h DELETED
@@ -1,22 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <string>
7
- #include <unordered_map>
8
- #include <vector>
9
- #include "run_options.h"
10
-
11
- namespace onnxruntime { // forward declarations
12
- class Model;
13
- class GraphTransformer;
14
- class NodeArg;
15
- } // namespace onnxruntime
16
-
17
- namespace onnxruntime {
18
- using InputDefList = std::vector<const onnxruntime::NodeArg*>;
19
- using OutputDefList = std::vector<const onnxruntime::NodeArg*>;
20
-
21
- using NameMLValMap = std::unordered_map<std::string, OrtValue>;
22
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/func_api.h DELETED
@@ -1,27 +0,0 @@
1
- #pragma once
2
- #include "core/common/status.h"
3
- using onnxruntime::common::Status; // TODO: Needed by WinML, but shouldn't be put into the global namespace like this
4
-
5
- namespace onnxruntime {
6
-
7
- // AllocateFunc(void* handle, size_t alignment, size_t size)
8
- using AllocateFunc = void* (*)(void*, size_t, size_t);
9
- using DestroyFunc = void (*)(void*, void*);
10
- using AllocatorHandle = void*;
11
-
12
- typedef struct {
13
- //right now we only include allocation for host memory
14
- AllocateFunc allocate_func;
15
- DestroyFunc release_func;
16
- AllocatorHandle allocator_handle;
17
- const char* node_name;
18
- } ComputeContext;
19
-
20
- using FunctionState = void*;
21
- // take the ComputeContext, and create a function state.
22
- using CreateFunctionStateC = int (*)(ComputeContext*, FunctionState*);
23
- // pass in the function state and input/output tensors, perform compute and return status
24
- using ComputeFuncC = common::Status (*)(FunctionState, const OrtApi*, OrtKernelContext*);
25
- // release the function state.
26
- using DestroyFunctionStateC = void (*)(FunctionState);
27
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/kernel_def_builder.h DELETED
@@ -1,353 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <limits.h>
7
- #include <memory>
8
- #include <optional>
9
- #include <string>
10
- #include <unordered_map>
11
- #include <vector>
12
-
13
- #include "core/common/common.h"
14
- #include "core/framework/allocator.h"
15
- #include "core/framework/data_types.h"
16
- #include "core/graph/basic_types.h"
17
-
18
- namespace onnxruntime {
19
- class KernelDefBuilder;
20
-
21
- typedef std::map<size_t, OrtMemType> MemTypeMap;
22
-
23
- class KernelDef {
24
- private:
25
- // note that input/output might be on CPU implicitly when the node is from CPU execution provider
26
- constexpr static inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
27
- return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
28
- }
29
-
30
- public:
31
- explicit KernelDef() = default;
32
-
33
- const std::string& OpName() const {
34
- return op_name_;
35
- }
36
-
37
- const std::string& Domain() const {
38
- return op_domain_;
39
- }
40
-
41
- void SinceVersion(/*out*/ int* start, /*out*/ int* end) const {
42
- *start = op_since_version_start_;
43
- *end = op_since_version_end_;
44
- }
45
-
46
- #ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
47
- const std::pair<int, int> SinceVersion() const {
48
- return std::pair<int, int>(op_since_version_start_, op_since_version_end_);
49
- }
50
- #endif
51
-
52
- onnxruntime::ProviderType Provider() const {
53
- return provider_type_;
54
- }
55
-
56
- // type constraints with types supported in this build
57
- const std::unordered_map<std::string, std::vector<MLDataType>>& TypeConstraints() const {
58
- return type_constraints_;
59
- }
60
-
61
- const std::vector<std::pair<int, int>>& MayInplace() const {
62
- return inplace_map_;
63
- }
64
-
65
- const std::vector<std::pair<int, int>>& Alias() const {
66
- return alias_map_;
67
- }
68
-
69
- const std::optional<std::pair<int, int>>& VariadicAlias() const {
70
- return variadic_alias_offsets_;
71
- }
72
-
73
- OrtMemType InputMemoryType(size_t input_index) const {
74
- auto it = input_memory_type_args_.find(input_index);
75
- if (it == input_memory_type_args_.end())
76
- return default_inputs_mem_type_;
77
- return it->second;
78
- }
79
-
80
- bool IsInputOnCpu(size_t input_index) const { return MemTypeOnCpuExplicitly(InputMemoryType(input_index)); }
81
-
82
- bool IsOutputOnCpu(size_t output_index) const { return MemTypeOnCpuExplicitly(OutputMemoryType(output_index)); }
83
-
84
- bool AllocateInputsContiguously() const { return allocate_inputs_contiguously_; }
85
-
86
- bool HasExternalOutputs() const { return external_outputs_; }
87
-
88
- #ifdef ENABLE_STRIDED_TENSORS
89
- const std::vector<int>& MayStridedInput() const { return may_strided_inputs_; }
90
- const std::vector<std::pair<int, int>>& MayStridedOutput() const { return may_strided_output_map_; }
91
- #endif
92
-
93
- OrtMemType OutputMemoryType(size_t output_index) const {
94
- auto it = output_memory_type_args_.find(output_index);
95
- if (it == output_memory_type_args_.end())
96
- return default_outputs_mem_type_;
97
- return it->second;
98
- }
99
-
100
- int ExecQueueId() const {
101
- return exec_queue_id_;
102
- }
103
-
104
- bool IsConflict(const KernelDef& other) const;
105
-
106
- private:
107
- friend class KernelDefBuilder;
108
-
109
- // The operator name supported by <*this> kernel..
110
- std::string op_name_;
111
-
112
- // The operator since_version range supported by <*this> kernel.
113
- // A kernel could support an operator definition between <op_since_version_start>
114
- // and <op_since_version_end> (inclusive).
115
- int op_since_version_start_ = 1;
116
- int op_since_version_end_ = INT_MAX;
117
-
118
- // The operator domain supported by <*this> kernel.
119
- // Default to 'onnxruntime::kOnnxDomain'.
120
- // Please note the behavior of std::string("") and std::string() are different
121
- std::string op_domain_;
122
-
123
- // The type of the execution provider.
124
- std::string provider_type_;
125
-
126
- // The data types that are supported in this build (enabled) for inputs/outputs.
127
- // Key is input/output/type constraint name defined in op schema, Value is supported types.
128
- std::unordered_map<std::string, std::vector<MLDataType>> type_constraints_;
129
-
130
- // An element <i, j> means that output j reuses the memory of input i.
131
- std::vector<std::pair<int, int>> inplace_map_;
132
-
133
- // An element <i, j> means that output j is an alias of input i.
134
- std::vector<std::pair<int, int>> alias_map_;
135
-
136
- // This variable stores <input_offset, output_offset> for the variadic alias mapping
137
- // output 'i + output_offset' is an alias of input 'i + input_offset' for all i >= 0
138
- std::optional<std::pair<int, int>> variadic_alias_offsets_;
139
-
140
- // Require input tensors to be allocated contiguously.
141
- bool allocate_inputs_contiguously_ = false;
142
-
143
- // Whether the outputs are from external.
144
- bool external_outputs_ = false;
145
-
146
- #ifdef ENABLE_STRIDED_TENSORS
147
- // An element i means i-th input can be strided tensor.
148
- std::vector<int> may_strided_inputs_;
149
-
150
- // An element <i, j> means j-th output can be a strided tensor, which share the data from i-th input.
151
- std::vector<std::pair<int, int>> may_strided_output_map_;
152
- #endif
153
-
154
- // The memory types of inputs/outputs of this kernel
155
- MemTypeMap input_memory_type_args_;
156
- MemTypeMap output_memory_type_args_;
157
-
158
- // execution command queue id, 0 for default queue in execution provider
159
- int exec_queue_id_ = 0;
160
- // Default memory type for all inputs
161
- OrtMemType default_inputs_mem_type_{OrtMemTypeDefault};
162
- // Default memory type for all outputs
163
- OrtMemType default_outputs_mem_type_{OrtMemTypeDefault};
164
- };
165
-
166
- class KernelDefBuilder {
167
- public:
168
- static std::unique_ptr<KernelDefBuilder> Create() { return std::make_unique<KernelDefBuilder>(); }
169
-
170
- explicit KernelDefBuilder()
171
- : kernel_def_(std::make_unique<KernelDef>()) {}
172
-
173
- KernelDefBuilder& SetName(const std::string& op_name);
174
- KernelDefBuilder& SetName(const char* op_name);
175
-
176
- KernelDefBuilder& SetDomain(const std::string& domain);
177
- KernelDefBuilder& SetDomain(const char* domain);
178
-
179
- /**
180
- This kernel supports operator definition since <since_version> (to latest).
181
- */
182
- KernelDefBuilder& SinceVersion(int since_version) {
183
- kernel_def_->op_since_version_start_ = since_version;
184
- return *this;
185
- }
186
-
187
- /**
188
- The start and end version should be set accordingly per version range for
189
- each domain registered in OpSchemaRegistry::DomainToVersionRange in
190
- \onnxruntime\onnxruntime\core\graph\op.h as below.
191
- Key: domain. Value: <lowest version, highest version> pair.
192
- std::unordered_map<std::string, std::pair<int, int>> map_;
193
- */
194
- KernelDefBuilder& SinceVersion(int since_version_start, int since_version_end) {
195
- kernel_def_->op_since_version_start_ = since_version_start;
196
- kernel_def_->op_since_version_end_ = since_version_end;
197
- return *this;
198
- }
199
-
200
- /**
201
- The execution provider type of the kernel.
202
- */
203
- KernelDefBuilder& Provider(ProviderType provider_type);
204
- KernelDefBuilder& Provider(const char* provider_type);
205
-
206
- /**
207
- Specify the set of types that this kernel supports. A further restriction
208
- of the set of types specified in the op schema.
209
-
210
- @param arg_name The arg name can be either op formal parameter name, say "X", or type
211
- argument name specified in op schema, say "T".
212
- @param types The types that are supported in this build.
213
- */
214
- KernelDefBuilder& TypeConstraint(const std::string& arg_name, std::vector<MLDataType> types);
215
- KernelDefBuilder& TypeConstraint(const char* arg_name, std::vector<MLDataType> types);
216
-
217
- /**
218
- Like TypeConstraint but supports just a single type.
219
- */
220
- KernelDefBuilder& TypeConstraint(const std::string& arg_name, MLDataType type);
221
- KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType type);
222
-
223
- /**
224
- Inplace mapping from inputs to outputs allowed.
225
- It means that uplayer runtime could do memory in-place optimization
226
- as it will not impact the correctness of this kernel.
227
- */
228
- KernelDefBuilder& MayInplace(const std::vector<std::pair<int, int>>& inplaces);
229
- KernelDefBuilder& MayInplace(int input_index, int output_index);
230
-
231
- /**
232
- Alias mapping from inputs to outputs. Different from Inplace that the
233
- content of the tensor is not changed. This is to take care of operators
234
- such as Identity and Reshape.
235
- */
236
- KernelDefBuilder& Alias(const std::vector<std::pair<int, int>>& aliases);
237
- KernelDefBuilder& Alias(int input_index, int output_index);
238
-
239
- /**
240
- Apply variadic number of alias mapping from inputs to outputs.
241
- This is effectively applying Alias(i + input_offset, i + output_offset) for i >= 0
242
- */
243
- KernelDefBuilder& VariadicAlias(int input_offset, int output_offset);
244
-
245
- /**
246
- Specify that this kernel requires input tensors to be allocated
247
- contiguously. This allows kernels to execute as a single large
248
- computation, rather than numerous smaller computations.
249
- */
250
- KernelDefBuilder& AllocateInputsContiguously() {
251
- kernel_def_->allocate_inputs_contiguously_ = true;
252
- return *this;
253
- }
254
-
255
- /**
256
- Specify that this kernel's output buffers are passed from external,
257
- i.e. not created or managed by ORT's memory allocator.
258
- */
259
- KernelDefBuilder& ExternalOutputs() {
260
- kernel_def_->external_outputs_ = true;
261
- return *this;
262
- }
263
-
264
- #ifdef ENABLE_STRIDED_TENSORS
265
- /**
266
- Specify that the input_index-th input can be strided tensor.
267
- */
268
- KernelDefBuilder& MayStridedInput(int input_index);
269
-
270
- /**
271
- Specify that the output_index-th output can be strided tensor, and share the data
272
- from input_index-th input.
273
- */
274
- KernelDefBuilder& MayStridedOutput(int input_index, int output_index);
275
- #endif
276
-
277
- /**
278
- Specify that this kernel requires an input arg
279
- in certain memory type (instead of the default, device memory).
280
- */
281
- KernelDefBuilder& InputMemoryType(OrtMemType type, int input_index) {
282
- kernel_def_->input_memory_type_args_.insert(std::make_pair(input_index, type));
283
- return *this;
284
- }
285
-
286
- /**
287
- Specify that this kernel requires input arguments
288
- in certain memory type (instead of the default, device memory).
289
- */
290
- KernelDefBuilder& InputMemoryType(OrtMemType type, const std::vector<int>& input_indexes) {
291
- for (auto input_index : input_indexes) {
292
- kernel_def_->input_memory_type_args_.insert(std::make_pair(input_index, type));
293
- }
294
- return *this;
295
- }
296
-
297
- /**
298
- Specify that this kernel provides an output arg
299
- in certain memory type (instead of the default, device memory).
300
- */
301
- KernelDefBuilder& OutputMemoryType(OrtMemType type, int output_index) {
302
- kernel_def_->output_memory_type_args_.insert(std::make_pair(output_index, type));
303
- return *this;
304
- }
305
-
306
- /**
307
- Specify that this kernel provides an output arguments
308
- in certain memory type (instead of the default, device memory).
309
- */
310
- KernelDefBuilder& OutputMemoryType(OrtMemType type, const std::vector<int>& output_indexes) {
311
- for (auto output_index : output_indexes) {
312
- kernel_def_->output_memory_type_args_.insert(std::make_pair(output_index, type));
313
- }
314
- return *this;
315
- }
316
-
317
- /**
318
- Specify that this kernel runs on which execution queue in the provider
319
- */
320
- KernelDefBuilder& ExecQueueId(int queue_id) {
321
- kernel_def_->exec_queue_id_ = queue_id;
322
- return *this;
323
- }
324
-
325
- /**
326
- Specify the default inputs memory type, if not specified, it is DefaultMemory
327
- */
328
- KernelDefBuilder& SetDefaultInputsMemoryType(OrtMemType mem_type) {
329
- kernel_def_->default_inputs_mem_type_ = mem_type;
330
- return *this;
331
- }
332
-
333
- /**
334
- Specify the default outputs memory type, if not specified, it is DefaultMemory
335
- */
336
- KernelDefBuilder& SetDefaultOutputMemoryType(OrtMemType mem_type) {
337
- kernel_def_->default_outputs_mem_type_ = mem_type;
338
- return *this;
339
- }
340
-
341
- /**
342
- Return the kernel definition, passing ownership of the KernelDef to the caller
343
- */
344
- std::unique_ptr<KernelDef> Build() {
345
- return std::move(kernel_def_);
346
- }
347
-
348
- private:
349
- // we own the KernelDef until Build() is called.
350
- std::unique_ptr<KernelDef> kernel_def_;
351
- };
352
-
353
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/kernel_registry.h DELETED
@@ -1,91 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <string_view>
7
-
8
- #include "core/framework/op_kernel.h"
9
-
10
- namespace onnxruntime {
11
-
12
- using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
13
- using KernelDefHashes = std::vector<std::pair<std::string, HashValue>>;
14
-
15
- class IKernelTypeStrResolver;
16
-
17
- /**
18
- * Each provider has a KernelRegistry. Often, the KernelRegistry only belongs to that specific provider.
19
- */
20
- class KernelRegistry {
21
- public:
22
- KernelRegistry() = default;
23
-
24
- // Register a kernel with kernel definition and function to create the kernel.
25
- Status Register(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator);
26
-
27
- Status Register(KernelCreateInfo&& create_info);
28
-
29
- // TODO(edgchen1) for TryFindKernel(), consider using `out` != nullptr as indicator of whether kernel was found and
30
- // Status as an indication of failure
31
-
32
- // Check if an execution provider can create kernel for a node and return the kernel if so
33
- Status TryFindKernel(const Node& node, ProviderType exec_provider,
34
- const IKernelTypeStrResolver& kernel_type_str_resolver,
35
- const KernelCreateInfo** out) const;
36
-
37
- static bool HasImplementationOf(const KernelRegistry& r, const Node& node,
38
- ProviderType exec_provider,
39
- const IKernelTypeStrResolver& kernel_type_str_resolver) {
40
- const KernelCreateInfo* info;
41
- Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info);
42
- return st.IsOK();
43
- }
44
-
45
- #if !defined(ORT_MINIMAL_BUILD)
46
- // Find KernelCreateInfo in instant mode
47
- Status TryFindKernel(const std::string& op_name, const std::string& domain, const int& version,
48
- const std::unordered_map<std::string, MLDataType>& type_constraints,
49
- ProviderType exec_provider, const KernelCreateInfo** out) const;
50
- #endif // !defined(ORT_MINIMAL_BUILD)
51
-
52
- bool IsEmpty() const { return kernel_creator_fn_map_.empty(); }
53
-
54
- #ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
55
- // This is used by the opkernel doc generator to enlist all registered operators for a given provider's opkernel
56
- const KernelCreateMap& GetKernelCreateMap() const {
57
- return kernel_creator_fn_map_;
58
- }
59
- #endif
60
-
61
- private:
62
- // Check whether the types of inputs/outputs of the given node match the extra
63
- // type-constraints of the given kernel. This serves two purposes: first, to
64
- // select the right kernel implementation based on the types of the arguments
65
- // when we have multiple kernels, e.g., Clip<float> and Clip<int>; second, to
66
- // accommodate (and check) mapping of ONNX (specification) type to the onnxruntime
67
- // implementation type (e.g., if we want to implement ONNX's float16 as a regular
68
- // float in onnxruntime). (The second, however, requires a globally uniform mapping.)
69
- //
70
- // Note that this is not intended for type-checking the node against the ONNX
71
- // type specification of the corresponding op, which is done before this check.
72
- static bool VerifyKernelDef(const Node& node,
73
- const KernelDef& kernel_def,
74
- const IKernelTypeStrResolver& kernel_type_str_resolver,
75
- std::string& error_str);
76
-
77
- static std::string GetMapKey(std::string_view op_name, std::string_view domain, std::string_view provider) {
78
- std::string key(op_name);
79
- // use the kOnnxDomainAlias of 'ai.onnx' instead of kOnnxDomain's empty string
80
- key.append(1, ' ').append(domain.empty() ? kOnnxDomainAlias : domain).append(1, ' ').append(provider);
81
- return key;
82
- }
83
-
84
- static std::string GetMapKey(const KernelDef& kernel_def) {
85
- return GetMapKey(kernel_def.OpName(), kernel_def.Domain(), kernel_def.Provider());
86
- }
87
- // Kernel create function map from op name to kernel creation info.
88
- // key is opname+domain_name+provider_name
89
- KernelCreateMap kernel_creator_fn_map_;
90
- };
91
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/op_kernel.h DELETED
@@ -1,387 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include "boost/mp11.hpp"
7
-
8
- // It is safe to include the below header even if SHARED_PROVIDER macro is enabled
9
- // as it doesn't include any pb headers.
10
- #include "core/framework/prepacked_weights_container.h"
11
-
12
- #ifndef SHARED_PROVIDER
13
- #include <functional>
14
- #include "core/common/exceptions.h"
15
- #include "core/common/logging/logging.h"
16
- #include "core/common/status.h"
17
- #include "core/framework/execution_provider.h"
18
- #include "core/framework/kernel_def_builder.h"
19
- #include "core/framework/ort_value.h"
20
- #include "core/framework/op_kernel_info.h"
21
- #include "core/framework/op_node_proto_helper.h"
22
- #include "core/framework/tensor.h"
23
- #include "core/framework/sparse_tensor.h"
24
- #include "core/graph/constants.h"
25
- #include "core/graph/graph_viewer.h"
26
- #if !defined(ORT_MINIMAL_BUILD)
27
- #include "onnx/defs/schema.h"
28
- #else
29
- #include "onnx/defs/data_type_utils.h"
30
- #endif
31
- #include "onnx/onnx_pb.h"
32
- #include "onnx/onnx-operators_pb.h"
33
- #include "core/common/gsl.h"
34
- namespace onnxruntime {
35
- class OpKernelContext;
36
- }
37
- #endif
38
-
39
- namespace onnxruntime {
40
-
41
- std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info);
42
-
43
- class OpKernel {
44
- public:
45
- using DoneCallback = std::function<void()>;
46
-
47
- explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(CopyOpKernelInfo(info)) {}
48
- virtual ~OpKernel() = default;
49
-
50
- const onnxruntime::Node& Node() const;
51
- const onnxruntime::KernelDef& KernelDef() const;
52
-
53
- [[nodiscard]] virtual Status Compute(_Inout_ OpKernelContext* context) const = 0;
54
-
55
- [[nodiscard]] virtual bool IsAsync() const {
56
- // by default all kernels are sync version.
57
- return false;
58
- }
59
-
60
- [[nodiscard]] virtual Status ComputeAsync(_Inout_ OpKernelContext*, DoneCallback) const {
61
- ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
62
- }
63
-
64
- // Override this function to PrePack initialized constant tensor to the format as needed.
65
- // For example, MatMul kernel can pack the input B if it is constant like code below.
66
- // Status PrePack(const Tensor& tensor, int input_idx, /*out*/ bool& is_packed,
67
- // /*out*/ PrePackedWeights* prepacked_weight_for_caching,
68
- // AllocatorPtr alloc) override {
69
- // is_packed = false;
70
- // if (input_idx == 1) {
71
- // is_packed = true;
72
- // this.Pack(tensor, this.buffer_, alloc);
73
- // if (prepacked_weight_for_caching) {
74
- // // LOGIC TO CACHE `this.buffer_` SINCE THE KERNEL DOESN"T OWN THE PACKED WEIGHT
75
- // }
76
- // }
77
- // return Status::OK();
78
- // }
79
- // Please refer to MatMulIntegerToFloatBase for a complete example
80
- // @param tensor: The initialized constant tensor
81
- // @param input_idx: The input index of the tensor in this kernel
82
- // @param alloc: The kernel's PrePack() method MUST use this allocator for allocating the pre-packed
83
- // weights' buffers. The alloc that the PrePack() method will receive will be either
84
- // the allocator tied to the session if the kernel owns the pre-packed buffer or an
85
- // allocator shared between sessions if the pre-packed buffer is to be shared across sessions
86
- // (i.e.) the kernel does not own the buffer.
87
- // @param is_packed: Set it to true if the kernel packed the tensor or to false
88
- // The kernel is responsible for keeping the packed data and related metadata if is_packed is true,
89
- // and the original initialized constant tensor will be released and not accessible anymore in
90
- // the Compute function.
91
- // @param prepacked_weights: A PrePackedWeights instance will be provided to the kernel IF the pre-packed weights
92
- // are meant to be stored in a shared container.
93
-
94
- virtual Status
95
- PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
96
- /*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) {
97
- is_packed = false;
98
- return Status::OK();
99
- }
100
-
101
- // Override this function to use provided pre-packed weight.
102
- // Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
103
- // int input_idx,
104
- // /*out*/ bool& used_shared_buffers) {
105
- // used_shared_buffers = true;
106
- // this.buffer_ = std::move(prepacked_buffers[0]);
107
- // return Status::OK();
108
- // }
109
- // Please refer to MatMulIntegerToFloatBase for a complete example
110
- // @param prepacked_buffers: The pre-packed buffers to be used by this kernel for the provided input index
111
- // (Sometimes a single constant initializer may have multiple pre-packed buffers associated
112
- // with it and it upto the kernel developer to store it in any order of their choice in PrePack()
113
- // and must use the same order for retrieval in UseSharedPrePackedBuffers().
114
- // @param input_idx: The input index of the tensor in this kernel
115
- // @param used_shared_buffers: Boolean flag set by the kernel implementation indicating
116
- // that the provided weight has been used by the kernel.
117
- virtual Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& /*prepacked_buffers*/,
118
- int /*input_idx*/,
119
- /*out*/ bool& used_shared_buffers) {
120
- used_shared_buffers = false;
121
- return Status::OK();
122
- }
123
-
124
- const OrtMemoryInfo& Allocator(OrtMemType mem_type) const;
125
- const OpKernelInfo& Info() const {
126
- return *op_kernel_info_;
127
- }
128
-
129
- private:
130
- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel);
131
- std::unique_ptr<OpKernelInfo> op_kernel_info_;
132
- };
133
- class FuncManager;
134
- using KernelCreateFn = std::function<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>;
135
- using KernelCreatePtrFn = std::add_pointer<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>::type;
136
-
137
- struct KernelCreateInfo {
138
- std::unique_ptr<KernelDef> kernel_def; // Owned and stored in the global kernel registry.
139
- KernelCreateFn kernel_create_func;
140
- Status status;
141
-
142
- KernelCreateInfo(std::unique_ptr<KernelDef> definition,
143
- KernelCreateFn create_func)
144
- : kernel_def(std::move(definition)),
145
- kernel_create_func(create_func) {}
146
-
147
- KernelCreateInfo(KernelCreateInfo&& other) noexcept
148
- : kernel_def(std::move(other.kernel_def)),
149
- kernel_create_func(std::move(other.kernel_create_func)) {}
150
-
151
- KernelCreateInfo() = default;
152
- };
153
-
154
- // Forward declarations for the non-specialized BuildKernelCreateInfo method.
155
- template <typename T>
156
- KernelCreateInfo BuildKernelCreateInfo();
157
-
158
- namespace ml {
159
- template <typename T>
160
- KernelCreateInfo BuildKernelCreateInfo();
161
- } // namespace ml
162
-
163
- namespace contrib {
164
- template <typename T>
165
- KernelCreateInfo BuildKernelCreateInfo();
166
- } // namespace contrib
167
-
168
- namespace contrib {
169
- namespace cuda {
170
- template <typename T>
171
- KernelCreateInfo BuildKernelCreateInfo();
172
- } // namespace cuda
173
- } // namespace contrib
174
-
175
- namespace contrib {
176
- namespace rocm {
177
- template <typename T>
178
- KernelCreateInfo BuildKernelCreateInfo();
179
- } // namespace rocm
180
- } // namespace contrib
181
-
182
- namespace contrib {
183
- namespace snpe {
184
- template <typename T>
185
- KernelCreateInfo BuildKernelCreateInfo();
186
- } // namespace snpe
187
- } // namespace contrib
188
-
189
- using BuildKernelCreateInfoFn = KernelCreateInfo (*)();
190
-
191
- // Naming convention for operator kernel classes
192
- #define ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) \
193
- provider##_##name##_##domain##_ver##ver
194
-
195
- #define ONNX_CPU_OPERATOR_KERNEL(name, ver, builder, ...) \
196
- ONNX_OPERATOR_KERNEL_EX(name, kOnnxDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
197
-
198
- #define ONNX_CPU_OPERATOR_ML_KERNEL(name, ver, builder, ...) \
199
- ONNX_OPERATOR_KERNEL_EX(name, kMLDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
200
-
201
- #define ONNX_CPU_OPERATOR_MS_KERNEL(name, ver, builder, ...) \
202
- ONNX_OPERATOR_KERNEL_EX(name, kMSDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
203
-
204
- #define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \
205
- class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \
206
- template <> \
207
- KernelCreateInfo \
208
- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name)>() { \
209
- return KernelCreateInfo( \
210
- builder.SetName(#name) \
211
- .SetDomain(domain) \
212
- .SinceVersion(ver) \
213
- .Provider(provider) \
214
- .Build(), \
215
- static_cast<KernelCreatePtrFn>( \
216
- [](FuncManager&, \
217
- const OpKernelInfo& info, \
218
- std::unique_ptr<OpKernel>& out) -> Status { \
219
- out = std::make_unique<__VA_ARGS__>(info); \
220
- return Status::OK(); \
221
- })); \
222
- }
223
-
224
- #define ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name) \
225
- provider##_##name##_##domain##_ver##startver##_##endver
226
-
227
- #define ONNX_CPU_OPERATOR_VERSIONED_KERNEL(name, startver, endver, builder, ...) \
228
- ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kOnnxDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)
229
-
230
- #define ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL(name, startver, endver, builder, ...) \
231
- ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kMLDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)
232
-
233
- #define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \
234
- class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \
235
- template <> \
236
- KernelCreateInfo \
237
- BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name)>() { \
238
- return KernelCreateInfo( \
239
- builder.SetName(#name) \
240
- .SetDomain(domain) \
241
- .SinceVersion(startver, endver) \
242
- .Provider(provider) \
243
- .Build(), \
244
- static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
245
- }
246
-
247
- #define ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name) \
248
- provider##_##name##_##domain##_ver##ver##_##type
249
-
250
- #define ONNX_CPU_OPERATOR_TYPED_KERNEL(name, ver, type, builder, ...) \
251
- ONNX_OPERATOR_TYPED_KERNEL_EX(name, kOnnxDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
252
-
253
- #define ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(name, ver, type, builder, ...) \
254
- ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMLDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
255
-
256
- #define ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(name, ver, type, builder, ...) \
257
- ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
258
-
259
- #define ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) \
260
- class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \
261
- template <> \
262
- KernelCreateInfo \
263
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name)>() { \
264
- return KernelCreateInfo( \
265
- builder.SetName(#name) \
266
- .SetDomain(domain) \
267
- .SinceVersion(ver) \
268
- .Provider(provider) \
269
- .Build(), \
270
- static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
271
- }
272
-
273
- #define ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name) \
274
- provider##_##name##_##domain##_ver##ver##_##type1##_##type2
275
-
276
- #define ONNX_OPERATOR_TWO_TYPED_KERNEL_EX(name, domain, ver, type1, type2, provider, builder, ...) \
277
- class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name); \
278
- template <> \
279
- KernelCreateInfo \
280
- BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name)>() { \
281
- return KernelCreateInfo( \
282
- builder.SetName(#name) \
283
- .SetDomain(domain) \
284
- .SinceVersion(ver) \
285
- .Provider(provider) \
286
- .Build(), \
287
- static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
288
- }
289
-
290
- #define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \
291
- provider##_##name##_##domain##_ver##startver##_##endver##_##type
292
-
293
- #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(name, startver, endver, type, builder, ...) \
294
- ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kOnnxDomain, startver, endver, type, kCpuExecutionProvider, builder, \
295
- __VA_ARGS__)
296
-
297
- #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL(name, startver, endver, type, builder, ...) \
298
- ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMLDomain, startver, endver, type, kCpuExecutionProvider, builder, \
299
- __VA_ARGS__)
300
-
301
- #define ONNX_CPU_OPERATOR_VERSIONED_TYPED_MS_KERNEL(name, startver, endver, type, builder, ...) \
302
- ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMSDomain, startver, endver, type, kCpuExecutionProvider, builder, \
303
- __VA_ARGS__)
304
-
305
- #define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \
306
- class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \
307
- template <> \
308
- KernelCreateInfo \
309
- BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
310
- type, name)>() { \
311
- return KernelCreateInfo( \
312
- builder.SetName(#name) \
313
- .SetDomain(domain) \
314
- .SinceVersion(startver, endver) \
315
- .Provider(provider) \
316
- .Build(), \
317
- static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
318
- }
319
-
320
- #define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name) \
321
- provider##_##name##_##domain##_ver##startver##_##endver##_##type1##_##type2
322
-
323
- #define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX(name, domain, startver, endver, type1, type2, \
324
- provider, builder, ...) \
325
- class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name); \
326
- template <> \
327
- KernelCreateInfo \
328
- BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
329
- type1, type2, name)>() { \
330
- return KernelCreateInfo( \
331
- builder.SetName(#name) \
332
- .SetDomain(domain) \
333
- .SinceVersion(startver, endver) \
334
- .Provider(provider) \
335
- .Build(), \
336
- static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
337
- }
338
-
339
- template <typename... Types>
340
- struct BuildKernelDefConstraintsImpl {
341
- std::vector<MLDataType> operator()() const {
342
- return {DataTypeImpl::GetTensorType<Types>()...};
343
- }
344
- };
345
-
346
- #if !defined(DISABLE_SPARSE_TENSORS)
347
- template <typename... Types>
348
- struct BuildKernelDefSparseConstraintsImpl {
349
- std::vector<MLDataType> operator()() const {
350
- return {DataTypeImpl::GetSparseTensorType<Types>()...};
351
- }
352
- };
353
- #endif
354
-
355
- // Use within macro definitions to create a custom vector of constraints.
356
- // Example: #define REG_KERNEL(OP, VERSION, KERNEL_CLASS, Type, ...)
357
- // .TypeConstraint("T", BuildKernelDefConstraints<Type, __VA_ARGS_>())
358
- template <typename... Types>
359
- inline std::vector<MLDataType> BuildKernelDefConstraints() {
360
- return BuildKernelDefConstraintsImpl<Types...>{}();
361
- }
362
-
363
- #if !defined(DISABLE_SPARSE_TENSORS)
364
- template <typename... Types>
365
- inline std::vector<MLDataType> BuildKernelDefSparseConstraints() {
366
- return BuildKernelDefSparseConstraintsImpl<Types...>{}();
367
- }
368
- #endif
369
-
370
- // version of BuildKernelDefConstraints() which takes a type list
371
- template <typename L>
372
- inline std::vector<MLDataType> BuildKernelDefConstraintsFromTypeList() {
373
- return boost::mp11::mp_apply<BuildKernelDefConstraintsImpl, L>{}();
374
- }
375
-
376
- #if !defined(DISABLE_SPARSE_TENSORS)
377
- template <typename L>
378
- inline std::vector<MLDataType> BuildKernelDefSparseConstraintsFromTypeList() {
379
- return boost::mp11::mp_apply<BuildKernelDefSparseConstraintsImpl, L>{}();
380
- }
381
- #endif
382
-
383
- } // namespace onnxruntime
384
-
385
- #ifndef SHARED_PROVIDER
386
- #include "core/framework/op_kernel_context.h"
387
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/op_kernel_context.h DELETED
@@ -1,237 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- namespace onnxruntime {
5
- class IExecutionFrame;
6
- class Stream;
7
- namespace concurrency {
8
- class ThreadPool;
9
- }
10
-
11
- class OpKernelContext {
12
- public:
13
- using ArgMap = std::unordered_map<std::string, size_t>;
14
-
15
- OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel,
16
- _In_ Stream* stream,
17
- _In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger);
18
-
19
- virtual ~OpKernelContext() = default;
20
-
21
- /**
22
- Return the number of inputs for a variadic argument.
23
- @param arg_num The operator argument number.
24
- @returns Number of inputs the argument has.
25
- */
26
- virtual int NumVariadicInputs(size_t arg_num) const;
27
-
28
- virtual MLDataType InputType(int index) const;
29
- virtual MLDataType OutputType(int index) const;
30
-
31
- const OrtValue* GetInputOrtValue(int index) const {
32
- return GetInputMLValue(index);
33
- }
34
-
35
- template <typename T>
36
- const T* Input(int index) const {
37
- const OrtValue* p_ml_value = GetInputMLValue(index);
38
- ORT_TRY {
39
- return p_ml_value ? &(p_ml_value->Get<T>()) : nullptr;
40
- }
41
- ORT_CATCH(const std::exception& /*e*/) {
42
- ORT_THROW("Missing Input: " + kernel_->Node().InputDefs()[index]->Name());
43
- }
44
- }
45
-
46
- // Fetch a required input, enforcing that it is present.
47
- template <typename T>
48
- const T& RequiredInput(int index) const {
49
- const T* input_ptr = Input<T>(index);
50
- ORT_ENFORCE(input_ptr, "Required input at index ", index, " is not present.");
51
- return *input_ptr;
52
- }
53
-
54
- // Fetch output (non-tensor) with specified index.
55
- template <typename T>
56
- T* Output(int index) {
57
- if (index < 0 || index >= OutputCount())
58
- return nullptr;
59
-
60
- OrtValue* p_ml_value = GetOrCreateOutputMLValue(index);
61
- return p_ml_value ? p_ml_value->GetMutable<T>() : nullptr;
62
- }
63
-
64
- // In the case that memory allocation has not been done for an output tensor,
65
- // The memory allocation will be done on-the-fly with given tensor shape.
66
- // Return nullptr if the output is an unused optional output.
67
- Tensor* Output(int index, const TensorShape& shape);
68
- Tensor* Output(int index, const std::vector<int64_t>& shape);
69
- Tensor* Output(int index, const std::initializer_list<int64_t>& shape);
70
-
71
- // Fetch a required tensor output, enforcing that it is present.
72
- Tensor& RequiredOutput(int index, const TensorShape& shape) {
73
- Tensor* output_ptr = Output(index, shape);
74
- ORT_ENFORCE(output_ptr, "Required output at index ", index, " is not present.");
75
- return *output_ptr;
76
- }
77
-
78
- #if !defined(DISABLE_SPARSE_TENSORS)
79
- // Fetch a sparse-tensor output corresponding to the specified index.
80
- // shape must specify the shape of the underlying dense-tensor.
81
- // Memory allocation for the output may happen when this method is invoked,
82
- // unless static optimization pre-allocates it.
83
- SparseTensor* OutputSparse(int index, const TensorShape& shape);
84
- #endif
85
-
86
- #if !defined(DISABLE_OPTIONAL_TYPE)
87
- // Use this API to output a "None" of a specific type (e.g. Tensor) at specified index
88
- template <typename T>
89
- void OutputOptionalWithoutData(int index) {
90
- auto* output_ort_value = GetOutputMLValue(index);
91
-
92
- auto type = DataTypeImpl::GetType<T>();
93
-
94
- output_ort_value->Init(nullptr, // This OrtValue is "None" and has no data
95
- type,
96
- type->GetDeleteFunc());
97
- }
98
- #endif
99
-
100
- // Retrieve indexed shape obtained from memory planning before actual
101
- // computation. If the indexed shape cannot be inferred, this function returns
102
- // false.
103
- virtual bool TryGetInferredInputShape(int index, TensorShape& shape) const;
104
-
105
- // Retrieve indexed shape obtained from memory planning before actual
106
- // computation. If the indexed shape cannot be inferred, this function returns
107
- // false.
108
- virtual bool TryGetInferredOutputShape(int index, TensorShape& shape) const;
109
-
110
- const logging::Logger& Logger() const {
111
- return *logger_;
112
- }
113
-
114
- // always >= 0
115
- virtual int InputCount() const {
116
- return static_cast<int>(kernel_->Node().InputDefs().size());
117
- }
118
-
119
- // always >= 0
120
- virtual int ImplicitInputCount() const {
121
- return static_cast<int>(kernel_->Node().ImplicitInputDefs().size());
122
- }
123
-
124
- // always >= 0
125
- virtual int OutputCount() const {
126
- return static_cast<int>(kernel_->Node().OutputDefs().size());
127
- }
128
-
129
- /**
130
- Return an allocator on device 0, with memtype of OrtMemTypeDefault.
131
- @remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
132
- */
133
- [[nodiscard]] virtual Status GetTempSpaceAllocator(AllocatorPtr* output) const;
134
-
135
- /**
136
- Return the allocator associated with the CPU EP with memtype of OrtMemTypeDefault.
137
- @remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
138
- */
139
- [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const;
140
-
141
- /**
142
- Return the device id that current kernel runs on.
143
- */
144
- virtual int GetDeviceId() const {
145
- return kernel_->Info().GetExecutionProvider()->GetDeviceId();
146
- }
147
-
148
- /**
149
- Return the compute stream associated with the EP that the kernel is partitioned to.
150
- For EPs that do not have a compute stream (e.g. CPU EP), a nullptr is returned.
151
- */
152
- [[nodiscard]] virtual Stream* GetComputeStream() const {
153
- return stream_;
154
- }
155
-
156
- /**
157
- Returns the opset domain of the underlying kernel
158
- **/
159
- const std::string& GetOpDomain() const;
160
-
161
- /**
162
- Returns the optype of the underlying kernel
163
- **/
164
- const std::string& GetOpType() const;
165
-
166
- /**
167
- Returns the node name of the underlying kernel
168
- **/
169
- const std::string& GetNodeName() const;
170
-
171
- /**
172
- Returns the intra-op threadpool, if available.
173
- */
174
- _Ret_maybenull_ onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return threadpool_; }
175
-
176
- /**
177
- Returns whether deterministic computation is preferred.
178
- */
179
- virtual bool GetUseDeterministicCompute() const {
180
- return true;
181
- }
182
-
183
- protected:
184
- OpKernelContext(concurrency::ThreadPool* threadpool, const logging::Logger& logger, Stream* stream);
185
-
186
- onnxruntime::NodeIndex GetNodeIndex() const;
187
-
188
- virtual const OrtValue* GetInputMLValue(int index) const;
189
- const OrtValue* GetImplicitInputMLValue(int index) const;
190
- OrtValue* GetOutputMLValue(int index);
191
-
192
- #ifdef ENABLE_ATEN
193
- Status SetOutputMLValue(int index, const OrtValue& ort_value);
194
- #endif
195
-
196
- // Creates the OrtValue* based on the shape, if it does not exist
197
- virtual OrtValue* OutputMLValue(int index, const TensorShape& shape);
198
-
199
- virtual OrtValue* GetOrCreateOutputMLValue(int index);
200
-
201
- private:
202
- ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
203
- int GetInputArgIndex(int index) const;
204
- int GetImplicitInputArgIndex(int index) const;
205
- int GetOutputArgIndex(int index) const;
206
-
207
- IExecutionFrame* const execution_frame_{};
208
- const OpKernel* const kernel_{};
209
- concurrency::ThreadPool* const threadpool_{};
210
- const logging::Logger* const logger_{};
211
-
212
- // The argument starting index in ExecutionFrame.
213
- int node_input_start_index_{-1};
214
- int node_implicit_input_start_index_{-1};
215
- int node_output_start_index_{-1};
216
-
217
- Stream* stream_;
218
- };
219
-
220
- // Fetching output tensor without shape is not allowed except when it already exists
221
- template <>
222
- inline Tensor* OpKernelContext::Output<Tensor>(int index) {
223
- OrtValue* p_ml_value = GetOutputMLValue(index);
224
- ORT_ENFORCE(p_ml_value, "Please fetch output tensor with specified shape.");
225
- return p_ml_value->GetMutable<Tensor>();
226
- }
227
-
228
- #if !defined(DISABLE_SPARSE_TENSORS)
229
- template <>
230
- inline SparseTensor* OpKernelContext::Output<SparseTensor>(int index) {
231
- OrtValue* p_ml_value = GetOutputMLValue(index);
232
- ORT_ENFORCE(p_ml_value, "Please fetch output sparse tensor with specified shape.");
233
- return p_ml_value->GetMutable<SparseTensor>();
234
- }
235
- #endif
236
-
237
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/op_kernel_info.h DELETED
@@ -1,63 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include "core/framework/execution_provider.h"
7
- #include "core/framework/kernel_def_builder.h"
8
- #include "core/framework/ort_value.h"
9
- #include "core/framework/op_node_proto_helper.h"
10
- #include "core/graph/graph_viewer.h"
11
- #include "core/common/gsl.h"
12
-
13
- namespace onnxruntime {
14
-
15
- class OrtValueNameIdxMap;
16
- class FuncManager;
17
- class DataTransferManager;
18
- struct AllocPlanPerValue;
19
-
20
- // A very light-weight class, which works as an aggregated
21
- // view of all data needed for constructing a Kernel instance.
22
- // NOTE: it does not own/hold any objects.
23
- class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
24
- public:
25
- explicit OpKernelInfo(const onnxruntime::Node& node,
26
- const KernelDef& kernel_def,
27
- const IExecutionProvider& execution_provider,
28
- const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
29
- const OrtValueNameIdxMap& mlvalue_name_idx_map,
30
- const DataTransferManager& data_transfer_mgr);
31
-
32
- OpKernelInfo(const OpKernelInfo& other);
33
-
34
- const OrtMemoryInfo& GetMemoryInfo(OrtMemType mem_type) const;
35
-
36
- AllocatorPtr GetAllocator(OrtMemType mem_type) const;
37
-
38
- const KernelDef& GetKernelDef() const;
39
-
40
- const IExecutionProvider* GetExecutionProvider() const noexcept;
41
-
42
- const DataTransferManager& GetDataTransferManager() const noexcept;
43
-
44
- const onnxruntime::Node& node() const noexcept;
45
-
46
- bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const;
47
-
48
- private:
49
- ORT_DISALLOW_MOVE(OpKernelInfo);
50
- ORT_DISALLOW_ASSIGNMENT(OpKernelInfo);
51
-
52
- const onnxruntime::Node& node_;
53
- const KernelDef& kernel_def_;
54
- // For non cpu/cuda case, this pointer should be set so that function kernel
55
- // will delegate kernel compute call to <execution_provider> compute call.
56
- gsl::not_null<const ::onnxruntime::IExecutionProvider*> execution_provider_;
57
- const std::unordered_map<int, OrtValue>& constant_initialized_tensors_;
58
- const OrtValueNameIdxMap& ort_value_name_idx_map_;
59
- const DataTransferManager& data_transfer_mgr_;
60
- ProtoHelperNodeContext proto_helper_context_;
61
- };
62
-
63
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/op_node_proto_helper.h DELETED
@@ -1,167 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #ifndef SHARED_PROVIDER
7
- #include "core/common/status.h"
8
- #include "core/framework/tensor_shape.h"
9
- #include "core/graph/graph_viewer.h"
10
- #include "core/common/gsl.h"
11
- #endif
12
-
13
- #ifdef __has_attribute
14
- #define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x)
15
- #else
16
- #define ORT_HAVE_ATTRIBUTE(x) 0
17
- #endif
18
-
19
- #if ORT_HAVE_ATTRIBUTE(nodiscard)
20
- #define MUST_USE_RESULT [[nodiscard]]
21
- #elif defined(__clang__) && ORT_HAVE_ATTRIBUTE(warn_unused_result)
22
- #define MUST_USE_RESULT __attribute__((warn_unused_result))
23
- #else
24
- #define MUST_USE_RESULT
25
- #endif
26
-
27
- class IMLOpKernel;
28
-
29
- namespace onnxruntime {
30
-
31
- /**
32
- A set of wrappers with common signatures for use with both OpKernelInfo
33
- (as its base class) and InferenceContext. Used by ABI kernels for both
34
- shape / type inference and kernel construction
35
- */
36
- template <class Impl_t>
37
- class OpNodeProtoHelper {
38
- public:
39
- explicit OpNodeProtoHelper(const Impl_t* impl) : impl_(impl) {}
40
-
41
- /**
42
- Get a single attribute
43
- Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema
44
- */
45
- template <typename T>
46
- MUST_USE_RESULT Status GetAttr(const std::string& name, T* value) const;
47
-
48
- /**
49
- Get a single attribute
50
- Call this function only when a default value for an optional attribute isn't specified in the op schema
51
- */
52
- template <typename T>
53
- T GetAttrOrDefault(const std::string& name, const T& default_value) const {
54
- T tmp;
55
- return GetAttr<T>(name, &tmp).IsOK() ? tmp : default_value;
56
- }
57
-
58
- /**
59
- Get a single attribute
60
- Call this function only when a default value for an optional attribute isn't specified in the op schema
61
- */
62
- template <typename T>
63
- void GetAttrOrDefault(const std::string& name, T* value, const T& default_value) const {
64
- if (!GetAttr<T>(name, value).IsOK())
65
- *value = default_value;
66
- }
67
-
68
- /**
69
- Get repeated attributes
70
- Call this function only when a default value for an optional attribute isn't specified in the op schema
71
- */
72
- template <typename T>
73
- MUST_USE_RESULT std::vector<T> GetAttrsOrDefault(const std::string& name, const std::vector<T>& default_value = std::vector<T>{}) const {
74
- std::vector<T> tmp;
75
- return GetAttrs<T>(name, tmp).IsOK() ? tmp : default_value;
76
- }
77
-
78
- /// <summary>
79
- /// Return a gsl::span that points to an array of primitive types held by AttributeProto
80
- /// This function allows to avoid copying big attributes locally into a kernel and operate on
81
- /// AttributeProto data directly.
82
- ///
83
- /// Does not apply to strings, Tensors and Sparse Tensors that require special treatment.
84
- /// </summary>
85
- /// <typeparam name="T">Primitive type contained in the array</typeparam>
86
- /// <param name="name">Attribute name</param>
87
- /// <param name="values">Attribute data in a span, out parameter</param>
88
- /// <returns>Status</returns>
89
- template <typename T>
90
- MUST_USE_RESULT Status GetAttrsAsSpan(const std::string& name, gsl::span<const T>& values) const;
91
-
92
- MUST_USE_RESULT Status GetAttrs(const std::string& name, TensorShapeVector& out) const;
93
-
94
- MUST_USE_RESULT TensorShapeVector GetAttrsOrDefault(const std::string& name, const TensorShapeVector& default_value = TensorShapeVector{}) const {
95
- TensorShapeVector tmp;
96
- return GetAttrs(name, tmp).IsOK() ? tmp : default_value;
97
- }
98
-
99
- /**
100
- Get repeated attributes
101
- */
102
- template <typename T>
103
- MUST_USE_RESULT Status GetAttrs(const std::string& name, std::vector<T>& values) const;
104
-
105
- template <typename T>
106
- MUST_USE_RESULT Status GetAttrs(const std::string& name, gsl::span<T> values) const;
107
-
108
- MUST_USE_RESULT Status GetAttrsStringRefs(const std::string& name,
109
- std::vector<std::reference_wrapper<const std::string>>& refs) const;
110
-
111
- uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type,
112
- const std::string& name) const noexcept;
113
-
114
- bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type,
115
- const std::string& name) const noexcept;
116
-
117
- uint32_t GetInputCount() const {
118
- return gsl::narrow_cast<uint32_t>(impl_->getNumInputs());
119
- }
120
-
121
- uint32_t GetOutputCount() const {
122
- return gsl::narrow_cast<uint32_t>(impl_->getNumOutputs());
123
- }
124
-
125
- const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const {
126
- return impl_->getInputType(index);
127
- }
128
-
129
- const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const {
130
- // Work around lack of a const method from the onnx InferenceContext interface
131
- return const_cast<Impl_t*>(impl_)->getOutputType(index);
132
- }
133
-
134
- // Try to query an attribute, returning nullptr if it doesn't exist
135
- const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const {
136
- return impl_->getAttribute(name);
137
- }
138
-
139
- const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const {
140
- const ONNX_NAMESPACE::AttributeProto* attr = TryGetAttribute(name);
141
- ORT_ENFORCE(attr != nullptr);
142
- return attr;
143
- }
144
-
145
- private:
146
- OpNodeProtoHelper() = delete;
147
- const Impl_t* impl_ = nullptr;
148
- };
149
-
150
- // The methods on the following class are called by OpNodeProtoHelper, implementing
151
- // the same signatures as InferenceContext other than const-ness.
152
- class ProtoHelperNodeContext {
153
- public:
154
- explicit ProtoHelperNodeContext(const onnxruntime::Node& node) : node_(node) {}
155
- ProtoHelperNodeContext() = delete;
156
-
157
- const ONNX_NAMESPACE::AttributeProto* getAttribute(const std::string& name) const;
158
- size_t getNumInputs() const;
159
- const ONNX_NAMESPACE::TypeProto* getInputType(size_t index) const;
160
- size_t getNumOutputs() const;
161
- const ONNX_NAMESPACE::TypeProto* getOutputType(size_t index) const;
162
-
163
- private:
164
- const onnxruntime::Node& node_;
165
- };
166
-
167
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/ort_value.h DELETED
@@ -1,123 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <string>
7
- #ifndef SHARED_PROVIDER
8
- #include "core/common/common.h"
9
- #include "core/common/exceptions.h"
10
- #include "core/framework/allocator.h"
11
- #include "core/framework/data_types.h"
12
- #include "core/framework/tensor.h"
13
-
14
- namespace onnxruntime {
15
- #if !defined(DISABLE_SPARSE_TENSORS)
16
- class SparseTensor;
17
- #endif
18
- class TensorSeq;
19
- } // namespace onnxruntime
20
-
21
- #endif
22
-
23
- /**
24
- Represents both tensors and non-tensors.
25
- */
26
- struct OrtValue {
27
- public:
28
- OrtValue() : data_(nullptr) {}
29
- ~OrtValue() = default;
30
-
31
- OrtValue(void* pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter) {
32
- Init(pData, type, deleter);
33
- }
34
-
35
- void Init(void* pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter) {
36
- data_.reset(pData, deleter);
37
- type_ = type;
38
- }
39
-
40
- void Init(void* pData, onnxruntime::MLDataType type, const std::function<void(void*)>& deleter) {
41
- data_.reset(pData, deleter);
42
- type_ = type;
43
- }
44
-
45
- bool IsAllocated() const {
46
- return data_ && type_;
47
- }
48
-
49
- template <typename T>
50
- const T& Get() const {
51
- ORT_ENFORCE(onnxruntime::DataTypeImpl::GetType<T>() == type_, onnxruntime::DataTypeImpl::GetType<T>(), " != ", type_);
52
- return *static_cast<T*>(data_.get());
53
- }
54
-
55
- // May return nullptr, if this OrtValue is an optional type and it is "None".
56
- template <typename T>
57
- T* GetMutable() {
58
- ORT_ENFORCE(onnxruntime::DataTypeImpl::GetType<T>() == type_, onnxruntime::DataTypeImpl::GetType<T>(), " != ", type_);
59
- return static_cast<T*>(data_.get());
60
- }
61
-
62
- bool IsTensor() const noexcept {
63
- return (type_ != nullptr && type_->IsTensorType());
64
- }
65
-
66
- bool IsTensorSequence() const noexcept {
67
- return (type_ != nullptr && type_->IsTensorSequenceType());
68
- }
69
-
70
- bool IsSparseTensor() const {
71
- #if !defined(DISABLE_SPARSE_TENSORS)
72
- return (type_ != nullptr && type_->IsSparseTensorType());
73
- #else
74
- ORT_THROW("Sparse tensor is not supported in this build.");
75
- #endif
76
- }
77
-
78
- onnxruntime::MLDataType Type() const {
79
- return type_;
80
- }
81
-
82
- private:
83
- std::shared_ptr<void> data_;
84
- onnxruntime::MLDataType type_{nullptr};
85
- };
86
-
87
- template <>
88
- inline const onnxruntime::Tensor& OrtValue::Get<onnxruntime::Tensor>() const {
89
- ORT_ENFORCE(IsTensor(), "Trying to get a Tensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
90
- return *static_cast<onnxruntime::Tensor*>(data_.get());
91
- }
92
-
93
- template <>
94
- inline onnxruntime::Tensor* OrtValue::GetMutable<onnxruntime::Tensor>() {
95
- ORT_ENFORCE(IsTensor(), "Trying to get a Tensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
96
- return static_cast<onnxruntime::Tensor*>(data_.get());
97
- }
98
-
99
- template <>
100
- inline const onnxruntime::TensorSeq& OrtValue::Get<onnxruntime::TensorSeq>() const {
101
- ORT_ENFORCE(IsTensorSequence(), "Trying to get a TensorSeq, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
102
- return *static_cast<onnxruntime::TensorSeq*>(data_.get());
103
- }
104
-
105
- template <>
106
- inline onnxruntime::TensorSeq* OrtValue::GetMutable<onnxruntime::TensorSeq>() {
107
- ORT_ENFORCE(IsTensorSequence(), "Trying to get a TensorSeq, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
108
- return static_cast<onnxruntime::TensorSeq*>(data_.get());
109
- }
110
-
111
- #if !defined(DISABLE_SPARSE_TENSORS)
112
- template <>
113
- inline const onnxruntime::SparseTensor& OrtValue::Get<onnxruntime::SparseTensor>() const {
114
- ORT_ENFORCE(IsSparseTensor(), "Trying to get a SparseTensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
115
- return *static_cast<onnxruntime::SparseTensor*>(data_.get());
116
- }
117
-
118
- template <>
119
- inline onnxruntime::SparseTensor* OrtValue::GetMutable<onnxruntime::SparseTensor>() {
120
- ORT_ENFORCE(IsSparseTensor(), "Trying to get a SparseTensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
121
- return static_cast<onnxruntime::SparseTensor*>(data_.get());
122
- }
123
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/ortdevice.h DELETED
@@ -1,74 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <sstream>
7
-
8
- // Struct to represent a physical device.
9
- struct OrtDevice {
10
- using DeviceType = int8_t;
11
- using MemoryType = int8_t;
12
- using DeviceId = int16_t;
13
-
14
- // Pre-defined device types.
15
- static const DeviceType CPU = 0;
16
- static const DeviceType GPU = 1; // Nvidia or AMD
17
- static const DeviceType FPGA = 2;
18
- static const DeviceType NPU = 3; // Ascend
19
-
20
- struct MemType {
21
- // Pre-defined memory types.
22
- static const MemoryType DEFAULT = 0;
23
- static const MemoryType CUDA_PINNED = 1;
24
- static const MemoryType HIP_PINNED = 2;
25
- static const MemoryType CANN_PINNED = 3;
26
- };
27
-
28
- constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)
29
- : device_type(device_type_),
30
- memory_type(memory_type_),
31
- device_id(device_id_) {}
32
-
33
- constexpr OrtDevice() : OrtDevice(CPU, MemType::DEFAULT, 0) {}
34
-
35
- DeviceType Type() const {
36
- return device_type;
37
- }
38
-
39
- MemoryType MemType() const {
40
- return memory_type;
41
- }
42
-
43
- DeviceId Id() const {
44
- return device_id;
45
- }
46
-
47
- std::string ToString() const {
48
- std::ostringstream ostr;
49
- ostr << "Device:["
50
- << "DeviceType:" << static_cast<int>(device_type)
51
- << " MemoryType:" << static_cast<int>(memory_type)
52
- << " DeviceId:" << device_id
53
- << "]";
54
- return ostr.str();
55
- }
56
-
57
- private:
58
- // Device type.
59
- DeviceType device_type;
60
-
61
- // Memory type.
62
- MemoryType memory_type;
63
-
64
- // Device index.
65
- DeviceId device_id;
66
- };
67
-
68
- inline bool operator==(const OrtDevice& left, const OrtDevice& other) {
69
- return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type();
70
- }
71
-
72
- inline bool operator!=(const OrtDevice& left, const OrtDevice& other) {
73
- return !(left == other);
74
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/ortmemoryinfo.h DELETED
@@ -1,87 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <string_view>
7
-
8
- #include "core/common/hash_combine.h"
9
-
10
- struct OrtMemoryInfo {
11
- OrtMemoryInfo() = default; // to allow default construction of Tensor
12
-
13
- // use string for name, so we could have customized allocator in execution provider.
14
- const char* name = nullptr;
15
- int id = -1;
16
- OrtMemType mem_type = OrtMemTypeDefault;
17
- OrtAllocatorType alloc_type = OrtInvalidAllocator;
18
- OrtDevice device;
19
-
20
- constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0,
21
- OrtMemType mem_type_ = OrtMemTypeDefault)
22
- #if ((defined(__GNUC__) && __GNUC__ > 4) || defined(__clang__))
23
- // this causes a spurious error in CentOS gcc 4.8 build so disable if GCC version < 5
24
- __attribute__((nonnull))
25
- #endif
26
- : name(name_),
27
- id(id_),
28
- mem_type(mem_type_),
29
- alloc_type(type_),
30
- device(device_) {
31
- }
32
-
33
- // To make OrtMemoryInfo become a valid key in std map
34
- bool operator<(const OrtMemoryInfo& other) const {
35
- if (alloc_type != other.alloc_type)
36
- return alloc_type < other.alloc_type;
37
- if (mem_type != other.mem_type)
38
- return mem_type < other.mem_type;
39
- if (id != other.id)
40
- return id < other.id;
41
-
42
- return strcmp(name, other.name) < 0;
43
- }
44
-
45
- // This is to make OrtMemoryInfo a valid key in hash tables
46
- // we ignore device id
47
- size_t Hash() const {
48
- auto h = std::hash<int>()(alloc_type);
49
- onnxruntime::HashCombine(mem_type, h);
50
- onnxruntime::HashCombine(id, h);
51
- onnxruntime::HashCombine<std::string_view>(name, h);
52
- return h;
53
- }
54
-
55
- std::string ToString() const {
56
- std::ostringstream ostr;
57
- ostr << "OrtMemoryInfo:["
58
- << "name:" << name
59
- << " id:" << id
60
- << " OrtMemType:" << mem_type
61
- << " OrtAllocatorType:" << alloc_type
62
- << " " << device.ToString()
63
- << "]";
64
- return ostr.str();
65
- }
66
- };
67
-
68
- // Required by hash tables
69
- inline bool operator==(const OrtMemoryInfo& left, const OrtMemoryInfo& other) {
70
- return left.mem_type == other.mem_type &&
71
- left.alloc_type == other.alloc_type &&
72
- left.id == other.id &&
73
- strcmp(left.name, other.name) == 0;
74
- }
75
-
76
- inline bool operator!=(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) { return !(lhs == rhs); }
77
-
78
- std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info);
79
-
80
- namespace std {
81
- template<>
82
- struct hash<OrtMemoryInfo> {
83
- size_t operator()(const OrtMemoryInfo& i) const {
84
- return i.Hash();
85
- }
86
- };
87
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/provider_options.h DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <string>
7
- #include <unordered_map>
8
- #include <vector>
9
-
10
- namespace onnxruntime {
11
-
12
- // data types for execution provider options
13
-
14
- using ProviderOptions = std::unordered_map<std::string, std::string>;
15
- using ProviderOptionsVector = std::vector<ProviderOptions>;
16
- using ProviderOptionsMap = std::unordered_map<std::string, ProviderOptions>;
17
-
18
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/provider_options_utils.h DELETED
@@ -1,164 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <algorithm>
7
- #include <functional>
8
- #include <type_traits>
9
- #include <unordered_map>
10
- #include <vector>
11
-
12
- #include "core/common/common.h"
13
- #include "core/common/parse_string.h"
14
- #include "core/framework/provider_options.h"
15
-
16
- namespace onnxruntime {
17
-
18
- template <typename TEnum>
19
- using EnumNameMapping = std::vector<std::pair<TEnum, std::string>>;
20
-
21
- /**
22
- * Given a mapping and an enumeration value, gets the corresponding name.
23
- */
24
- template <typename TEnum>
25
- Status EnumToName(const EnumNameMapping<TEnum>& mapping, TEnum value, std::string& name) {
26
- const auto it = std::find_if(
27
- mapping.begin(), mapping.end(),
28
- [&value](const std::pair<TEnum, std::string>& entry) {
29
- return entry.first == value;
30
- });
31
- ORT_RETURN_IF(
32
- it == mapping.end(),
33
- "Failed to map enum value to name: ", static_cast<typename std::underlying_type<TEnum>::type>(value));
34
- name = it->second;
35
- return Status::OK();
36
- }
37
-
38
- template <typename TEnum>
39
- std::string EnumToName(const EnumNameMapping<TEnum>& mapping, TEnum value) {
40
- std::string name;
41
- ORT_THROW_IF_ERROR(EnumToName(mapping, value, name));
42
- return name;
43
- }
44
-
45
- /**
46
- * Given a mapping and a name, gets the corresponding enumeration value.
47
- */
48
- template <typename TEnum>
49
- Status NameToEnum(
50
- const EnumNameMapping<TEnum>& mapping, const std::string& name, TEnum& value) {
51
- const auto it = std::find_if(
52
- mapping.begin(), mapping.end(),
53
- [&name](const std::pair<TEnum, std::string>& entry) {
54
- return entry.second == name;
55
- });
56
- ORT_RETURN_IF(
57
- it == mapping.end(),
58
- "Failed to map enum name to value: ", name);
59
- value = it->first;
60
- return Status::OK();
61
- }
62
-
63
- template <typename TEnum>
64
- TEnum NameToEnum(const EnumNameMapping<TEnum>& mapping, const std::string& name) {
65
- TEnum value;
66
- ORT_THROW_IF_ERROR(NameToEnum(mapping, name, value));
67
- return value;
68
- }
69
-
70
- class ProviderOptionsParser {
71
- public:
72
- /**
73
- * Adds a parser for a particular provider option value.
74
- *
75
- * @param name The provider option name.
76
- * @param value_parser An object that parses the option value.
77
- * It should be callable with the following signature and return
78
- * whether the parsing was successful:
79
- * Status value_parser(const std::string&)
80
- *
81
- * @return The current ProviderOptionsParser instance.
82
- */
83
- template <typename ValueParserType>
84
- ProviderOptionsParser& AddValueParser(
85
- const std::string& name, ValueParserType value_parser) {
86
- ORT_ENFORCE(
87
- value_parsers_.emplace(name, ValueParser{value_parser}).second,
88
- "Provider option \"", name, "\" already has a value parser.");
89
- return *this;
90
- }
91
-
92
- /**
93
- * Adds a parser for a particular provider option value which converts a
94
- * value to the right type and assigns it to the given reference.
95
- *
96
- * IMPORTANT: This function stores a reference to the destination variable.
97
- * The caller must ensure that the reference is valid when Parse() is called!
98
- *
99
- * @param name The provider option name.
100
- * @param dest The destination variable reference.
101
- *
102
- * @return The current ProviderOptionsParser instance.
103
- */
104
- template <typename ValueType>
105
- ProviderOptionsParser& AddAssignmentToReference(
106
- const std::string& name, ValueType& dest) {
107
- return AddValueParser(
108
- name,
109
- [&dest](const std::string& value_str) -> Status {
110
- return ParseStringWithClassicLocale(value_str, dest);
111
- });
112
- }
113
-
114
- /**
115
- * Adds a parser for a particular provider option value which maps an
116
- * enumeration name to a value and assigns it to the given reference.
117
- *
118
- * IMPORTANT: This function stores references to the mapping and destination
119
- * variables. The caller must ensure that the references are valid when
120
- * Parse() is called!
121
- *
122
- * @param name The provider option name.
123
- * @param mapping The enumeration value to name mapping.
124
- * @param dest The destination variable reference.
125
- *
126
- * @return The current ProviderOptionsParser instance.
127
- */
128
- template <typename EnumType>
129
- ProviderOptionsParser& AddAssignmentToEnumReference(
130
- const std::string& name, const EnumNameMapping<EnumType>& mapping, EnumType& dest) {
131
- return AddValueParser(
132
- name,
133
- [&mapping, &dest](const std::string& value_str) -> Status {
134
- return NameToEnum(mapping, value_str, dest);
135
- });
136
- }
137
-
138
- /**
139
- * Parses the given provider options.
140
- */
141
- Status Parse(const ProviderOptions& options) const {
142
- for (const auto& option : options) {
143
- const auto& name = option.first;
144
- const auto& value_str = option.second;
145
- const auto value_parser_it = value_parsers_.find(name);
146
- ORT_RETURN_IF(
147
- value_parser_it == value_parsers_.end(),
148
- "Unknown provider option: \"", name, "\".");
149
-
150
- const auto parse_status = value_parser_it->second(value_str);
151
- ORT_RETURN_IF_NOT(
152
- parse_status.IsOK(),
153
- "Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage());
154
- }
155
-
156
- return Status::OK();
157
- }
158
-
159
- private:
160
- using ValueParser = std::function<Status(const std::string&)>;
161
- std::unordered_map<std::string, ValueParser> value_parsers_;
162
- };
163
-
164
- } // namespace onnxruntime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/provider_shutdown.h DELETED
@@ -1,8 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- namespace onnxruntime {
7
- void UnloadSharedProviders();
8
- }
 
 
 
 
 
 
 
 
 
arm64-v8a/include/onnxruntime/core/framework/run_options.h DELETED
@@ -1,49 +0,0 @@
1
- // Copyright (c) Microsoft Corporation. All rights reserved.
2
- // Licensed under the MIT License.
3
-
4
- #pragma once
5
-
6
- #include <string>
7
- #include <atomic>
8
- #include "core/session/onnxruntime_c_api.h"
9
- #include "core/framework/config_options.h"
10
-
11
- /**
12
- * Configuration information for a Run call.
13
- */
14
- struct OrtRunOptions {
15
- /// Log severity. See https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/common/logging/severity.h
16
- /// Default = -1 (use the log severity from the InferenceSession that the Run is for).
17
- int run_log_severity_level = -1;
18
- int run_log_verbosity_level = 0; ///< VLOG level if debug build and run_log_severity_level is 0 (VERBOSE).
19
- std::string run_tag; ///< A tag for the Run() calls using this.
20
-
21
- // Set to 'true' to ensure the termination of all the outstanding Run() calls
22
- // that use this OrtRunOptions instance. Some of the outstanding Run() calls may
23
- // be forced to terminate with an error status.
24
- bool terminate = false;
25
-
26
- // Set to 'true' to run only the nodes from feeds to required fetches.
27
- // So it is possible that only some of the nodes are executed.
28
- bool only_execute_path_to_fetches = false;
29
-
30
- #ifdef ENABLE_TRAINING
31
- // Used by onnxruntime::training::TrainingSession. This class is now deprecated.
32
- // Delete training_mode when TrainingSession is deleted.
33
- // Set to 'true' to run in training mode.
34
- bool training_mode = true;
35
- #endif
36
-
37
- // Stores the configurations for this run
38
- // To add an configuration to this specific run, call OrtApis::AddRunConfigEntry
39
- // The configuration keys and value formats are defined in
40
- // /include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
41
- onnxruntime::ConfigOptions config_options;
42
-
43
- OrtRunOptions() = default;
44
- ~OrtRunOptions() = default;
45
- };
46
-
47
- namespace onnxruntime {
48
- using RunOptions = OrtRunOptions;
49
- } // namespace onnxruntime