Commit ·
b974d63
1
Parent(s): 0a9fe68
remove extra files
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- arm64-v8a/include/onnxruntime/core/common/basic_types.h +0 -19
- arm64-v8a/include/onnxruntime/core/common/code_location.h +0 -58
- arm64-v8a/include/onnxruntime/core/common/common.h +0 -287
- arm64-v8a/include/onnxruntime/core/common/const_pointer_container.h +0 -85
- arm64-v8a/include/onnxruntime/core/common/denormal.h +0 -12
- arm64-v8a/include/onnxruntime/core/common/eigen_common_wrapper.h +0 -62
- arm64-v8a/include/onnxruntime/core/common/exceptions.h +0 -71
- arm64-v8a/include/onnxruntime/core/common/gpu_profiler_common.h +0 -472
- arm64-v8a/include/onnxruntime/core/common/gsl.h +0 -6
- arm64-v8a/include/onnxruntime/core/common/hash_combine.h +0 -21
- arm64-v8a/include/onnxruntime/core/common/inlined_containers.h +0 -175
- arm64-v8a/include/onnxruntime/core/common/inlined_containers_fwd.h +0 -147
- arm64-v8a/include/onnxruntime/core/common/logging/capture.h +0 -115
- arm64-v8a/include/onnxruntime/core/common/logging/isink.h +0 -41
- arm64-v8a/include/onnxruntime/core/common/logging/logging.h +0 -337
- arm64-v8a/include/onnxruntime/core/common/logging/macros.h +0 -278
- arm64-v8a/include/onnxruntime/core/common/logging/severity.h +0 -22
- arm64-v8a/include/onnxruntime/core/common/make_string.h +0 -126
- arm64-v8a/include/onnxruntime/core/common/narrow.h +0 -77
- arm64-v8a/include/onnxruntime/core/common/optional.h +0 -23
- arm64-v8a/include/onnxruntime/core/common/parse_string.h +0 -85
- arm64-v8a/include/onnxruntime/core/common/profiler_common.h +0 -93
- arm64-v8a/include/onnxruntime/core/common/span_utils.h +0 -88
- arm64-v8a/include/onnxruntime/core/common/spin_pause.h +0 -28
- arm64-v8a/include/onnxruntime/core/common/status.h +0 -195
- arm64-v8a/include/onnxruntime/core/common/string_helper.h +0 -11
- arm64-v8a/include/onnxruntime/core/framework/alloc_kind.h +0 -36
- arm64-v8a/include/onnxruntime/core/framework/allocator.h +0 -194
- arm64-v8a/include/onnxruntime/core/framework/buffer_deleter.h +0 -36
- arm64-v8a/include/onnxruntime/core/framework/customregistry.h +0 -60
- arm64-v8a/include/onnxruntime/core/framework/data_types.h +0 -1062
- arm64-v8a/include/onnxruntime/core/framework/data_types_internal.h +0 -569
- arm64-v8a/include/onnxruntime/core/framework/endian.h +0 -27
- arm64-v8a/include/onnxruntime/core/framework/execution_provider.h +0 -340
- arm64-v8a/include/onnxruntime/core/framework/float16.h +0 -159
- arm64-v8a/include/onnxruntime/core/framework/framework_common.h +0 -22
- arm64-v8a/include/onnxruntime/core/framework/func_api.h +0 -27
- arm64-v8a/include/onnxruntime/core/framework/kernel_def_builder.h +0 -353
- arm64-v8a/include/onnxruntime/core/framework/kernel_registry.h +0 -91
- arm64-v8a/include/onnxruntime/core/framework/op_kernel.h +0 -387
- arm64-v8a/include/onnxruntime/core/framework/op_kernel_context.h +0 -237
- arm64-v8a/include/onnxruntime/core/framework/op_kernel_info.h +0 -63
- arm64-v8a/include/onnxruntime/core/framework/op_node_proto_helper.h +0 -167
- arm64-v8a/include/onnxruntime/core/framework/ort_value.h +0 -123
- arm64-v8a/include/onnxruntime/core/framework/ortdevice.h +0 -74
- arm64-v8a/include/onnxruntime/core/framework/ortmemoryinfo.h +0 -87
- arm64-v8a/include/onnxruntime/core/framework/provider_options.h +0 -18
- arm64-v8a/include/onnxruntime/core/framework/provider_options_utils.h +0 -164
- arm64-v8a/include/onnxruntime/core/framework/provider_shutdown.h +0 -8
- arm64-v8a/include/onnxruntime/core/framework/run_options.h +0 -49
arm64-v8a/include/onnxruntime/core/common/basic_types.h
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <cstdint>
|
| 7 |
-
|
| 8 |
-
namespace onnxruntime {
|
| 9 |
-
|
| 10 |
-
/** A computed hash value. */
|
| 11 |
-
using HashValue = uint64_t;
|
| 12 |
-
|
| 13 |
-
/** The type of an argument (input or output).*/
|
| 14 |
-
enum class ArgType : uint8_t {
|
| 15 |
-
kInput,
|
| 16 |
-
kOutput,
|
| 17 |
-
};
|
| 18 |
-
|
| 19 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/code_location.h
DELETED
|
@@ -1,58 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <sstream>
|
| 7 |
-
#include <string>
|
| 8 |
-
#include <vector>
|
| 9 |
-
|
| 10 |
-
namespace onnxruntime {
|
| 11 |
-
/**
|
| 12 |
-
CodeLocation captures information on where in the source code a message came from.
|
| 13 |
-
*/
|
| 14 |
-
struct CodeLocation {
|
| 15 |
-
/**
|
| 16 |
-
@param file_path Usually the value of __FILE__
|
| 17 |
-
@param line Usually the value of __LINE__
|
| 18 |
-
@param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__
|
| 19 |
-
*/
|
| 20 |
-
CodeLocation(const char* file_path, const int line, const char* func)
|
| 21 |
-
: file_and_path{file_path}, line_num{line}, function{func} {
|
| 22 |
-
}
|
| 23 |
-
|
| 24 |
-
/**
|
| 25 |
-
@param file_path Usually the value of __FILE__
|
| 26 |
-
@param line Usually the value of __LINE__
|
| 27 |
-
@param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__
|
| 28 |
-
@param stacktrace Stacktrace from source of message.
|
| 29 |
-
*/
|
| 30 |
-
CodeLocation(const char* file_path, const int line, const char* func, const std::vector<std::string>& stacktrace)
|
| 31 |
-
: file_and_path{file_path}, line_num{line}, function{func}, stacktrace(stacktrace) {
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
std::string FileNoPath() const {
|
| 35 |
-
// assuming we always have work to do, so not trying to avoid creating a new string if
|
| 36 |
-
// no path was removed.
|
| 37 |
-
return file_and_path.substr(file_and_path.find_last_of("/\\") + 1);
|
| 38 |
-
}
|
| 39 |
-
|
| 40 |
-
enum Format {
|
| 41 |
-
kFilename,
|
| 42 |
-
kFilenameAndPath
|
| 43 |
-
};
|
| 44 |
-
|
| 45 |
-
std::string ToString(Format format = Format::kFilename) const {
|
| 46 |
-
std::ostringstream out;
|
| 47 |
-
out << (format == Format::kFilename ? FileNoPath() : file_and_path) << ":" << line_num << " " << function;
|
| 48 |
-
return out.str();
|
| 49 |
-
}
|
| 50 |
-
//utf-8. Because on Windows we compile our code with "/utf-8". And we assume the other platforms only use utf-8.
|
| 51 |
-
const std::string file_and_path;
|
| 52 |
-
const int line_num;
|
| 53 |
-
//utf-8
|
| 54 |
-
const std::string function;
|
| 55 |
-
const std::vector<std::string> stacktrace;
|
| 56 |
-
};
|
| 57 |
-
|
| 58 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/common.h
DELETED
|
@@ -1,287 +0,0 @@
|
|
| 1 |
-
/**
|
| 2 |
-
* Copyright (c) 2016-present, Facebook, Inc.
|
| 3 |
-
*
|
| 4 |
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
* you may not use this file except in compliance with the License.
|
| 6 |
-
* You may obtain a copy of the License at
|
| 7 |
-
*
|
| 8 |
-
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
*
|
| 10 |
-
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
* See the License for the specific language governing permissions and
|
| 14 |
-
* limitations under the License.
|
| 15 |
-
*/
|
| 16 |
-
// Portions Copyright (c) Microsoft Corporation
|
| 17 |
-
|
| 18 |
-
#pragma once
|
| 19 |
-
|
| 20 |
-
#include <climits>
|
| 21 |
-
#include <cstring>
|
| 22 |
-
#include <algorithm>
|
| 23 |
-
#include <chrono>
|
| 24 |
-
#include <functional>
|
| 25 |
-
#include <memory>
|
| 26 |
-
#include <numeric>
|
| 27 |
-
#include <set>
|
| 28 |
-
#include <sstream>
|
| 29 |
-
#include <string>
|
| 30 |
-
#include <type_traits>
|
| 31 |
-
#include <unordered_map>
|
| 32 |
-
#include <utility>
|
| 33 |
-
#include <vector>
|
| 34 |
-
|
| 35 |
-
#include "core/common/code_location.h"
|
| 36 |
-
#include "core/common/exceptions.h"
|
| 37 |
-
#include "core/common/make_string.h"
|
| 38 |
-
#include "core/common/status.h"
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
namespace onnxruntime {
|
| 42 |
-
|
| 43 |
-
using TimePoint = std::chrono::high_resolution_clock::time_point;
|
| 44 |
-
|
| 45 |
-
#ifdef _WIN32
|
| 46 |
-
#define ORT_UNUSED_PARAMETER(x) (x)
|
| 47 |
-
#else
|
| 48 |
-
#define ORT_UNUSED_PARAMETER(x) (void)(x)
|
| 49 |
-
#endif
|
| 50 |
-
|
| 51 |
-
#ifndef ORT_HAVE_ATTRIBUTE
|
| 52 |
-
#ifdef __has_attribute
|
| 53 |
-
#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x)
|
| 54 |
-
#else
|
| 55 |
-
#define ORT_HAVE_ATTRIBUTE(x) 0
|
| 56 |
-
#endif
|
| 57 |
-
#endif
|
| 58 |
-
|
| 59 |
-
// ORT_ATTRIBUTE_UNUSED
|
| 60 |
-
//
|
| 61 |
-
// Prevents the compiler from complaining about or optimizing away variables
|
| 62 |
-
// that appear unused on Linux
|
| 63 |
-
#if ORT_HAVE_ATTRIBUTE(unused) || (defined(__GNUC__) && !defined(__clang__))
|
| 64 |
-
#undef ORT_ATTRIBUTE_UNUSED
|
| 65 |
-
#define ORT_ATTRIBUTE_UNUSED __attribute__((__unused__))
|
| 66 |
-
#else
|
| 67 |
-
#define ORT_ATTRIBUTE_UNUSED
|
| 68 |
-
#endif
|
| 69 |
-
|
| 70 |
-
#ifdef ORT_NO_EXCEPTIONS
|
| 71 |
-
// Print the given final message, the message must be a null terminated char*
|
| 72 |
-
// ORT will abort after printing the message.
|
| 73 |
-
// For Android, will print to Android system log
|
| 74 |
-
// For other platforms, will print to stderr
|
| 75 |
-
void PrintFinalMessage(const char* msg);
|
| 76 |
-
#endif
|
| 77 |
-
|
| 78 |
-
// macro to explicitly ignore the return value from a function call so Code Analysis doesn't complain
|
| 79 |
-
#define ORT_IGNORE_RETURN_VALUE(fn) \
|
| 80 |
-
static_cast<void>(fn)
|
| 81 |
-
|
| 82 |
-
std::vector<std::string> GetStackTrace();
|
| 83 |
-
// this is a helper function that gets defined by platform/Telemetry
|
| 84 |
-
void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
|
| 85 |
-
const char* function, uint32_t line);
|
| 86 |
-
|
| 87 |
-
// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER
|
| 88 |
-
// so we only define it as one for MSVC
|
| 89 |
-
#if (_MSC_VER && !defined(__PRETTY_FUNCTION__))
|
| 90 |
-
#define __PRETTY_FUNCTION__ __FUNCTION__
|
| 91 |
-
#endif
|
| 92 |
-
|
| 93 |
-
// Capture where a message is coming from. Use __FUNCTION__ rather than the much longer __PRETTY_FUNCTION__
|
| 94 |
-
#define ORT_WHERE ::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast<const char*>(__FUNCTION__))
|
| 95 |
-
|
| 96 |
-
#define ORT_WHERE_WITH_STACK \
|
| 97 |
-
::onnxruntime::CodeLocation(__FILE__, __LINE__, static_cast<const char*>(__PRETTY_FUNCTION__), ::onnxruntime::GetStackTrace())
|
| 98 |
-
|
| 99 |
-
#ifdef ORT_NO_EXCEPTIONS
|
| 100 |
-
|
| 101 |
-
#define ORT_TRY if (true)
|
| 102 |
-
#define ORT_CATCH(x) else if (false)
|
| 103 |
-
#define ORT_RETHROW
|
| 104 |
-
|
| 105 |
-
// In order to ignore the catch statement when a specific exception (not ... ) is caught and referred
|
| 106 |
-
// in the body of the catch statements, it is necessary to wrap the body of the catch statement into
|
| 107 |
-
// a lambda function. otherwise the exception referred will be undefined and cause build break
|
| 108 |
-
#define ORT_HANDLE_EXCEPTION(func)
|
| 109 |
-
|
| 110 |
-
// Throw an exception with optional message.
|
| 111 |
-
// NOTE: The arguments get streamed into a string via ostringstream::operator<<
|
| 112 |
-
// DO NOT use a printf format string, as that will not work as you expect.
|
| 113 |
-
#define ORT_THROW(...) \
|
| 114 |
-
do { \
|
| 115 |
-
::onnxruntime::PrintFinalMessage( \
|
| 116 |
-
::onnxruntime::OnnxRuntimeException( \
|
| 117 |
-
ORT_WHERE_WITH_STACK, ::onnxruntime::MakeString(__VA_ARGS__)) \
|
| 118 |
-
.what()); \
|
| 119 |
-
abort(); \
|
| 120 |
-
} while (false)
|
| 121 |
-
|
| 122 |
-
// Just in order to mark things as not implemented. Do not use in final code.
|
| 123 |
-
#define ORT_NOT_IMPLEMENTED(...) \
|
| 124 |
-
do { \
|
| 125 |
-
::onnxruntime::PrintFinalMessage( \
|
| 126 |
-
::onnxruntime::NotImplementedException(::onnxruntime::MakeString(__VA_ARGS__)) \
|
| 127 |
-
.what()); \
|
| 128 |
-
abort(); \
|
| 129 |
-
} while (false)
|
| 130 |
-
|
| 131 |
-
// Check condition.
|
| 132 |
-
// NOTE: The arguments get streamed into a string via ostringstream::operator<<
|
| 133 |
-
// DO NOT use a printf format string, as that will not work as you expect.
|
| 134 |
-
#define ORT_ENFORCE(condition, ...) \
|
| 135 |
-
do { \
|
| 136 |
-
if (!(condition)) { \
|
| 137 |
-
::onnxruntime::PrintFinalMessage( \
|
| 138 |
-
::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, #condition, \
|
| 139 |
-
::onnxruntime::MakeString(__VA_ARGS__)) \
|
| 140 |
-
.what()); \
|
| 141 |
-
abort(); \
|
| 142 |
-
} \
|
| 143 |
-
} while (false)
|
| 144 |
-
|
| 145 |
-
#define ORT_THROW_EX(ex, ...) \
|
| 146 |
-
do { \
|
| 147 |
-
::onnxruntime::PrintFinalMessage( \
|
| 148 |
-
::onnxruntime::MakeString(#ex, "(", ::onnxruntime::MakeString(__VA_ARGS__), ")").c_str()); \
|
| 149 |
-
abort(); \
|
| 150 |
-
} while (false)
|
| 151 |
-
|
| 152 |
-
#else
|
| 153 |
-
|
| 154 |
-
#define ORT_TRY try
|
| 155 |
-
#define ORT_CATCH(x) catch (x)
|
| 156 |
-
#define ORT_RETHROW throw;
|
| 157 |
-
|
| 158 |
-
#define ORT_HANDLE_EXCEPTION(func) func()
|
| 159 |
-
|
| 160 |
-
// Throw an exception with optional message.
|
| 161 |
-
// NOTE: The arguments get streamed into a string via ostringstream::operator<<
|
| 162 |
-
// DO NOT use a printf format string, as that will not work as you expect.
|
| 163 |
-
#define ORT_THROW(...) \
|
| 164 |
-
throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, ::onnxruntime::MakeString(__VA_ARGS__))
|
| 165 |
-
|
| 166 |
-
// Just in order to mark things as not implemented. Do not use in final code.
|
| 167 |
-
#define ORT_NOT_IMPLEMENTED(...) \
|
| 168 |
-
throw ::onnxruntime::NotImplementedException(::onnxruntime::MakeString(__VA_ARGS__))
|
| 169 |
-
|
| 170 |
-
// Check condition.
|
| 171 |
-
// NOTE: The arguments get streamed into a string via ostringstream::operator<<
|
| 172 |
-
// DO NOT use a printf format string, as that will not work as you expect.
|
| 173 |
-
#define ORT_ENFORCE(condition, ...) \
|
| 174 |
-
do { \
|
| 175 |
-
if (!(condition)) { \
|
| 176 |
-
throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, #condition, \
|
| 177 |
-
::onnxruntime::MakeString(__VA_ARGS__)); \
|
| 178 |
-
} \
|
| 179 |
-
} while (false)
|
| 180 |
-
|
| 181 |
-
#define ORT_THROW_EX(ex, ...) \
|
| 182 |
-
throw ex(__VA_ARGS__)
|
| 183 |
-
|
| 184 |
-
#endif
|
| 185 |
-
|
| 186 |
-
#define ORT_MAKE_STATUS(category, code, ...) \
|
| 187 |
-
::onnxruntime::common::Status(::onnxruntime::common::category, \
|
| 188 |
-
::onnxruntime::common::code, \
|
| 189 |
-
::onnxruntime::MakeString(__VA_ARGS__))
|
| 190 |
-
|
| 191 |
-
// Check condition. if met, return status.
|
| 192 |
-
#define ORT_RETURN_IF(condition, ...) \
|
| 193 |
-
do { \
|
| 194 |
-
if (condition) { \
|
| 195 |
-
return ::onnxruntime::common::Status(::onnxruntime::common::ONNXRUNTIME, \
|
| 196 |
-
::onnxruntime::common::FAIL, \
|
| 197 |
-
::onnxruntime::MakeString(ORT_WHERE.ToString(), " ", __VA_ARGS__)); \
|
| 198 |
-
} \
|
| 199 |
-
} while (false)
|
| 200 |
-
|
| 201 |
-
// Check condition. if not met, return status.
|
| 202 |
-
#define ORT_RETURN_IF_NOT(condition, ...) \
|
| 203 |
-
ORT_RETURN_IF(!(condition), __VA_ARGS__)
|
| 204 |
-
|
| 205 |
-
// Macros to disable the copy and/or move ctor and assignment methods
|
| 206 |
-
// These are usually placed in the private: declarations for a class.
|
| 207 |
-
|
| 208 |
-
#define ORT_DISALLOW_COPY(TypeName) TypeName(const TypeName&) = delete
|
| 209 |
-
|
| 210 |
-
#define ORT_DISALLOW_ASSIGNMENT(TypeName) TypeName& operator=(const TypeName&) = delete
|
| 211 |
-
|
| 212 |
-
#define ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName) \
|
| 213 |
-
ORT_DISALLOW_COPY(TypeName); \
|
| 214 |
-
ORT_DISALLOW_ASSIGNMENT(TypeName)
|
| 215 |
-
|
| 216 |
-
#define ORT_DISALLOW_MOVE(TypeName) \
|
| 217 |
-
TypeName(TypeName&&) = delete; \
|
| 218 |
-
TypeName& operator=(TypeName&&) = delete
|
| 219 |
-
|
| 220 |
-
#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName) \
|
| 221 |
-
ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName); \
|
| 222 |
-
ORT_DISALLOW_MOVE(TypeName)
|
| 223 |
-
|
| 224 |
-
#define ORT_RETURN_IF_ERROR_SESSIONID(expr, session_id) \
|
| 225 |
-
do { \
|
| 226 |
-
auto _status = (expr); \
|
| 227 |
-
if ((!_status.IsOK())) { \
|
| 228 |
-
::onnxruntime::LogRuntimeError(session_id, _status, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__); \
|
| 229 |
-
return _status; \
|
| 230 |
-
} \
|
| 231 |
-
} while (0)
|
| 232 |
-
|
| 233 |
-
#define ORT_RETURN_IF_ERROR_SESSIONID_(expr) ORT_RETURN_IF_ERROR_SESSIONID(expr, session_id_)
|
| 234 |
-
#define ORT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR_SESSIONID(expr, 0)
|
| 235 |
-
|
| 236 |
-
#define ORT_THROW_IF_ERROR(expr) \
|
| 237 |
-
do { \
|
| 238 |
-
auto _status = (expr); \
|
| 239 |
-
if ((!_status.IsOK())) { \
|
| 240 |
-
::onnxruntime::LogRuntimeError(0, _status, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__); \
|
| 241 |
-
ORT_THROW(_status); \
|
| 242 |
-
} \
|
| 243 |
-
} while (0)
|
| 244 |
-
|
| 245 |
-
// use this macro when cannot early return
|
| 246 |
-
#define ORT_CHECK_AND_SET_RETVAL(expr) \
|
| 247 |
-
do { \
|
| 248 |
-
if (retval.IsOK()) { \
|
| 249 |
-
retval = (expr); \
|
| 250 |
-
} \
|
| 251 |
-
} while (0)
|
| 252 |
-
|
| 253 |
-
inline long long TimeDiffMicroSeconds(TimePoint start_time) {
|
| 254 |
-
auto end_time = std::chrono::high_resolution_clock::now();
|
| 255 |
-
return std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
|
| 256 |
-
}
|
| 257 |
-
|
| 258 |
-
inline long long TimeDiffMicroSeconds(TimePoint start_time, TimePoint end_time) {
|
| 259 |
-
return std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
|
| 260 |
-
}
|
| 261 |
-
|
| 262 |
-
struct null_type {};
|
| 263 |
-
inline std::string ToUTF8String(const std::string& s) { return s; }
|
| 264 |
-
#ifdef _WIN32
|
| 265 |
-
/**
|
| 266 |
-
* Convert a wide character string to a UTF-8 string
|
| 267 |
-
*/
|
| 268 |
-
std::string ToUTF8String(const std::wstring& s);
|
| 269 |
-
|
| 270 |
-
std::wstring ToWideString(const std::string& s);
|
| 271 |
-
inline std::wstring ToWideString(const std::wstring& s) { return s; }
|
| 272 |
-
#else
|
| 273 |
-
inline std::string ToWideString(const std::string& s) { return s; }
|
| 274 |
-
#endif
|
| 275 |
-
|
| 276 |
-
constexpr size_t kMaxStrLen = 2048;
|
| 277 |
-
|
| 278 |
-
// Returns whether `key` is in `container`.
|
| 279 |
-
// Like C++20's map/set contains() member function.
|
| 280 |
-
template <typename Key, typename... OtherContainerArgs,
|
| 281 |
-
template <typename...> typename AssociativeContainer,
|
| 282 |
-
typename LookupKey>
|
| 283 |
-
inline bool Contains(const AssociativeContainer<Key, OtherContainerArgs...>& container, LookupKey&& key) {
|
| 284 |
-
return container.find(std::forward<LookupKey>(key)) != container.end();
|
| 285 |
-
}
|
| 286 |
-
|
| 287 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/const_pointer_container.h
DELETED
|
@@ -1,85 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <type_traits>
|
| 7 |
-
|
| 8 |
-
namespace onnxruntime {
|
| 9 |
-
/**
|
| 10 |
-
Container has T* entries. e.g. std::vector<T*>, and this class provides const access to those
|
| 11 |
-
via iterators and direct access, as the standard behavior only makes the pointer constant,
|
| 12 |
-
and not what is pointed to, i.e. you get a const pointer to T, not a pointer to const T, without this wrapper.
|
| 13 |
-
See https://stackoverflow.com/questions/8017036/understanding-const-iterator-with-pointers
|
| 14 |
-
*/
|
| 15 |
-
template <typename Container>
|
| 16 |
-
class ConstPointerContainer {
|
| 17 |
-
public:
|
| 18 |
-
using T = typename std::remove_pointer<typename Container::value_type>::type;
|
| 19 |
-
|
| 20 |
-
class ConstIterator {
|
| 21 |
-
public:
|
| 22 |
-
using const_iterator = typename Container::const_iterator;
|
| 23 |
-
using iterator_category = std::input_iterator_tag;
|
| 24 |
-
using value_type = T*;
|
| 25 |
-
using difference_type = std::ptrdiff_t;
|
| 26 |
-
using pointer = T**;
|
| 27 |
-
using reference = T*&;
|
| 28 |
-
|
| 29 |
-
/** Construct iterator for container that will return const T* entries.*/
|
| 30 |
-
explicit ConstIterator(const_iterator position) noexcept : current_{position}, item_{nullptr} {}
|
| 31 |
-
ConstIterator(const ConstIterator& other) = default;
|
| 32 |
-
ConstIterator& operator=(const ConstIterator& other) = default;
|
| 33 |
-
|
| 34 |
-
bool operator==(const ConstIterator& other) const noexcept { return current_ == other.current_; }
|
| 35 |
-
bool operator!=(const ConstIterator& other) const noexcept { return current_ != other.current_; }
|
| 36 |
-
|
| 37 |
-
ConstIterator& operator++() {
|
| 38 |
-
++current_;
|
| 39 |
-
return *this;
|
| 40 |
-
}
|
| 41 |
-
|
| 42 |
-
ConstIterator operator++(int) {
|
| 43 |
-
ConstIterator tmp{*this};
|
| 44 |
-
++(*this);
|
| 45 |
-
return tmp;
|
| 46 |
-
}
|
| 47 |
-
|
| 48 |
-
const T*& operator*() const {
|
| 49 |
-
item_ = *current_;
|
| 50 |
-
return item_;
|
| 51 |
-
}
|
| 52 |
-
|
| 53 |
-
const T** operator->() const { return &(operator*()); };
|
| 54 |
-
|
| 55 |
-
private:
|
| 56 |
-
const_iterator current_;
|
| 57 |
-
mutable const T* item_;
|
| 58 |
-
};
|
| 59 |
-
|
| 60 |
-
/**
|
| 61 |
-
Construct wrapper class that will provide const access to the pointers in a container of non-const pointers.
|
| 62 |
-
@param data Container with non-const pointers. e.g. std::vector<T*>
|
| 63 |
-
*/
|
| 64 |
-
explicit ConstPointerContainer(const Container& data) noexcept : data_(data) {}
|
| 65 |
-
|
| 66 |
-
size_t size() const noexcept { return data_.size(); }
|
| 67 |
-
bool empty() const noexcept { return data_.empty(); }
|
| 68 |
-
|
| 69 |
-
ConstIterator cbegin() const noexcept { return ConstIterator(data_.cbegin()); }
|
| 70 |
-
ConstIterator cend() const noexcept { return ConstIterator(data_.cend()); }
|
| 71 |
-
|
| 72 |
-
ConstIterator begin() const noexcept { return ConstIterator(data_.cbegin()); }
|
| 73 |
-
ConstIterator end() const noexcept { return ConstIterator(data_.cend()); }
|
| 74 |
-
|
| 75 |
-
const T* operator[](size_t index) const { return data_[index]; }
|
| 76 |
-
|
| 77 |
-
const T* at(size_t index) const {
|
| 78 |
-
ORT_ENFORCE(index < data_.size());
|
| 79 |
-
return data_[index];
|
| 80 |
-
}
|
| 81 |
-
|
| 82 |
-
private:
|
| 83 |
-
const Container& data_;
|
| 84 |
-
};
|
| 85 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/denormal.h
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
namespace onnxruntime {
|
| 7 |
-
|
| 8 |
-
// Set or unset flush-to-zero and denormal-as-zero if SSE3 instructions are supported.
|
| 9 |
-
// Return true if SSE3 instruction is supported, otherwise return false.
|
| 10 |
-
bool SetDenormalAsZero(bool on);
|
| 11 |
-
|
| 12 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/eigen_common_wrapper.h
DELETED
|
@@ -1,62 +0,0 @@
|
|
| 1 |
-
//-----------------------------------------------------------------------------
|
| 2 |
-
//
|
| 3 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 4 |
-
// Licensed under the MIT License.
|
| 5 |
-
//
|
| 6 |
-
//-----------------------------------------------------------------------------
|
| 7 |
-
#pragma once
|
| 8 |
-
#include "onnxruntime_config.h"
|
| 9 |
-
// build/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h:162:71:
|
| 10 |
-
// error: ignoring attributes on template argument "Eigen::PacketType<const float, Eigen::DefaultDevice>::type {aka __vector(4) float}" [-Werror=ignored-attributes]
|
| 11 |
-
#if defined(__GNUC__)
|
| 12 |
-
#pragma GCC diagnostic push
|
| 13 |
-
#if __GNUC__ >= 6
|
| 14 |
-
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
| 15 |
-
#endif
|
| 16 |
-
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
| 17 |
-
#pragma GCC diagnostic ignored "-Wunused-result"
|
| 18 |
-
#ifdef HAS_DEPRECATED_COPY
|
| 19 |
-
#pragma GCC diagnostic ignored "-Wdeprecated-copy"
|
| 20 |
-
#endif
|
| 21 |
-
// cmake/external/eigen/unsupported/Eigen/CXX11/../../../Eigen/src/Core/arch/NEON/PacketMath.h:1633:9:
|
| 22 |
-
// error: ‘void* memcpy(void*, const void*, size_t)’ copying an object of non-trivial type ‘Eigen::internal::Packet4c’
|
| 23 |
-
// {aka ‘struct Eigen::internal::eigen_packet_wrapper<int, 2>’} from an array of ‘const int8_t’
|
| 24 |
-
// {aka ‘const signed char’} [-Werror=class-memaccess]
|
| 25 |
-
#ifdef HAS_CLASS_MEMACCESS
|
| 26 |
-
#pragma GCC diagnostic ignored "-Wclass-memaccess"
|
| 27 |
-
#endif
|
| 28 |
-
|
| 29 |
-
// cmake/external/eigen\Eigen/src/Core/util/Meta.h:454:25:
|
| 30 |
-
// error: 'result_of<Eigen::internal::scalar_product_op<unsigned long long> (const unsigned long long &, const unsigned long long &)>'
|
| 31 |
-
// is deprecated [-Werror,-Wdeprecated-declarations]
|
| 32 |
-
// typedef typename std::result_of<T>::type type1;
|
| 33 |
-
#ifdef HAS_DEPRECATED_DECLARATIONS
|
| 34 |
-
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
| 35 |
-
#endif
|
| 36 |
-
|
| 37 |
-
// cmake/external/eigen\Eigen/CXX11/src/Tensor/TensorTrace.h:130:9:
|
| 38 |
-
// error: variable 'num_distinct_reduce_dims' set but not used [-Werror,-Wunused-but-set-variable]
|
| 39 |
-
// int num_distinct_reduce_dims = 0;
|
| 40 |
-
#ifdef HAS_UNUSED_BUT_SET_VARIABLE
|
| 41 |
-
#pragma GCC diagnostic ignored "-Wunused-but-set-variable"
|
| 42 |
-
#endif
|
| 43 |
-
|
| 44 |
-
#elif defined(_MSC_VER)
|
| 45 |
-
// build\windows\debug\external\eigen3\unsupported\eigen\cxx11\src/Tensor/Tensor.h(76):
|
| 46 |
-
// warning C4554: '&': check operator precedence for possible error; use parentheses to clarify precedence
|
| 47 |
-
|
| 48 |
-
// unsupported\eigen\cxx11\src\Tensor\TensorUInt128.h(150,0): Warning C4245: 'initializing': conversion from '__int64'
|
| 49 |
-
// to 'uint64_t', signed/unsigned mismatch
|
| 50 |
-
#pragma warning(push)
|
| 51 |
-
#pragma warning(disable : 4554)
|
| 52 |
-
#pragma warning(disable : 4245)
|
| 53 |
-
#pragma warning(disable : 4127)
|
| 54 |
-
#endif
|
| 55 |
-
|
| 56 |
-
#include "unsupported/Eigen/CXX11/Tensor"
|
| 57 |
-
|
| 58 |
-
#if defined(__GNUC__)
|
| 59 |
-
#pragma GCC diagnostic pop
|
| 60 |
-
#elif defined(_MSC_VER)
|
| 61 |
-
#pragma warning(pop)
|
| 62 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/exceptions.h
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <algorithm>
|
| 7 |
-
#include <exception>
|
| 8 |
-
#include <iterator>
|
| 9 |
-
#include <stdexcept>
|
| 10 |
-
#include <string>
|
| 11 |
-
#include <vector>
|
| 12 |
-
|
| 13 |
-
#include "core/common/common.h"
|
| 14 |
-
#include "core/common/code_location.h"
|
| 15 |
-
|
| 16 |
-
namespace onnxruntime {
|
| 17 |
-
|
| 18 |
-
class NotImplementedException : public std::logic_error {
|
| 19 |
-
public:
|
| 20 |
-
explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){};
|
| 21 |
-
explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){};
|
| 22 |
-
};
|
| 23 |
-
|
| 24 |
-
class TypeMismatchException : public std::logic_error {
|
| 25 |
-
public:
|
| 26 |
-
TypeMismatchException() noexcept : logic_error("Type mismatch"){};
|
| 27 |
-
};
|
| 28 |
-
|
| 29 |
-
class OnnxRuntimeException : public std::exception {
|
| 30 |
-
public:
|
| 31 |
-
OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept
|
| 32 |
-
: OnnxRuntimeException(location, nullptr, msg) {
|
| 33 |
-
}
|
| 34 |
-
|
| 35 |
-
/**
|
| 36 |
-
Create a new exception that captures the location it was thrown from.
|
| 37 |
-
@param location Location in the source code the exception is being thrown from
|
| 38 |
-
@param failed_condition Optional string containing the condition that failed.
|
| 39 |
-
e.g. "tensor.Size() == input.Size()". May be nullptr.
|
| 40 |
-
@param msg Message containing additional information about the exception cause.
|
| 41 |
-
*/
|
| 42 |
-
OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg)
|
| 43 |
-
: location_{location} {
|
| 44 |
-
std::ostringstream ss;
|
| 45 |
-
|
| 46 |
-
ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous
|
| 47 |
-
if (failed_condition != nullptr) {
|
| 48 |
-
ss << " " << failed_condition << " was false.";
|
| 49 |
-
}
|
| 50 |
-
|
| 51 |
-
ss << " " << msg << "\n";
|
| 52 |
-
if (!location.stacktrace.empty()) {
|
| 53 |
-
ss << "Stacktrace:\n";
|
| 54 |
-
// skip the first entry in the stacktrace as we have that information from location.ToString()
|
| 55 |
-
std::copy(std::next(location.stacktrace.begin()), location.stacktrace.end(), std::ostream_iterator<std::string>(ss, "\n"));
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
what_ = ss.str();
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
const char* what() const noexcept override {
|
| 62 |
-
return what_.c_str();
|
| 63 |
-
}
|
| 64 |
-
|
| 65 |
-
private:
|
| 66 |
-
const CodeLocation location_;
|
| 67 |
-
const std::vector<std::string> stacktrace_;
|
| 68 |
-
std::string what_;
|
| 69 |
-
};
|
| 70 |
-
|
| 71 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/gpu_profiler_common.h
DELETED
|
@@ -1,472 +0,0 @@
|
|
| 1 |
-
#pragma once
|
| 2 |
-
|
| 3 |
-
#include "core/common/profiler_common.h"
|
| 4 |
-
#include "core/common/inlined_containers.h"
|
| 5 |
-
|
| 6 |
-
#include <map>
|
| 7 |
-
#include <memory>
|
| 8 |
-
#include <mutex>
|
| 9 |
-
#include <sstream>
|
| 10 |
-
#include <string>
|
| 11 |
-
#include <vector>
|
| 12 |
-
#include <utility>
|
| 13 |
-
|
| 14 |
-
namespace onnxruntime {
|
| 15 |
-
namespace profiling {
|
| 16 |
-
|
| 17 |
-
// The classes in this header are implemented as template/inline classes
|
| 18 |
-
// to avoid having to export symbols from the main onnxruntime shared library
|
| 19 |
-
// to ExecutionProvider (EP) shared libraries.
|
| 20 |
-
// More context: The main onnxruntime shared library is optimized for size
|
| 21 |
-
// using --gc-sections during link time to ensure that any unreferenced code
|
| 22 |
-
// is not retained. This poses a problem in using a design pattern where the
|
| 23 |
-
// (abstract) base class is implemented in the main onnxruntime shared library,
|
| 24 |
-
// but (concrete) subclasses are implemented in EP shared libraries. Now, because
|
| 25 |
-
// EP shared libraries are loaded at runtime (as of 11/2022), there will be no
|
| 26 |
-
// references to the base class symbols when the main onnxruntime shared library
|
| 27 |
-
// is compiled. Thus, the base class symbols will not be included in the
|
| 28 |
-
// main onnxruntime shared library. This manifests in being unable to load
|
| 29 |
-
// EP shared libs (because the base class symbols referenced by derived
|
| 30 |
-
// classes are missing).
|
| 31 |
-
// We solve this by implementing base classes that are common to all GPU profilers
|
| 32 |
-
// inline in this header.
|
| 33 |
-
|
| 34 |
-
// Owning, copyable byte buffer: a heap-allocated char array plus its size.
class ProfilerActivityBuffer {
 public:
  // Creates an empty buffer (null data, size 0).
  ProfilerActivityBuffer() noexcept
      : data_(nullptr), size_(0) {}

  // Allocates `size` bytes and copies them from `data`.
  // The copy is skipped when there is nothing to copy, so memcpy is never
  // handed a null source pointer (e.g. when copying an empty buffer).
  ProfilerActivityBuffer(const char* data, size_t size) noexcept
      : data_(std::make_unique<char[]>(size)), size_(size) {
    if (data != nullptr && size_ > 0) {
      memcpy(data_.get(), data, size_);
    }
  }

  // Deep copy.
  ProfilerActivityBuffer(const ProfilerActivityBuffer& other) noexcept
      : ProfilerActivityBuffer(other.GetData(), other.GetSize()) {}

  // Takes over other's storage; `other` is left empty.
  ProfilerActivityBuffer(ProfilerActivityBuffer&& other) noexcept
      : ProfilerActivityBuffer() {
    std::swap(data_, other.data_);
    std::swap(size_, other.size_);
  }

  // Deep-copy assignment. Builds the copy first and then swaps it in, so the
  // storage previously owned by *this is released when `copy` goes out of
  // scope. (The earlier placement-new approach overwrote data_ without
  // running its destructor and leaked the old allocation.)
  ProfilerActivityBuffer& operator=(const ProfilerActivityBuffer& other) noexcept {
    if (&other == this) {
      return *this;
    }

    ProfilerActivityBuffer copy{other};
    std::swap(data_, copy.data_);
    std::swap(size_, copy.size_);
    return *this;
  }

  // Move assignment. unique_ptr assignment frees the storage previously owned
  // by *this; `other` is left empty (null data, size 0).
  ProfilerActivityBuffer& operator=(ProfilerActivityBuffer&& other) noexcept {
    if (&other == this) {
      return *this;
    }

    data_ = std::move(other.data_);
    size_ = other.size_;
    other.size_ = 0;
    return *this;
  }

  // Wraps an already allocated array without copying it.
  static ProfilerActivityBuffer CreateFromPreallocatedBuffer(std::unique_ptr<char[]>&& buffer_ptr, size_t size) {
    ProfilerActivityBuffer res{};
    res.data_ = std::move(buffer_ptr);
    res.size_ = size;
    return res;
  }

  // accessors
  char* GetData() { return data_.get(); }
  const char* GetData() const { return data_.get(); }
  size_t GetSize() const { return size_; }

 private:
  std::unique_ptr<char[]> data_;
  size_t size_;
}; /* end class ProfilerActivityBuffer */
|
| 87 |
-
|
| 88 |
-
template <typename TDerived>
|
| 89 |
-
class GPUTracerManager {
|
| 90 |
-
public:
|
| 91 |
-
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GPUTracerManager);
|
| 92 |
-
virtual ~GPUTracerManager() {}
|
| 93 |
-
|
| 94 |
-
uint64_t RegisterClient() {
|
| 95 |
-
std::lock_guard<std::mutex> lock(manager_instance_mutex_);
|
| 96 |
-
auto res = next_client_id_++;
|
| 97 |
-
per_client_events_by_ext_correlation_.insert({res, {}});
|
| 98 |
-
++num_active_clients_;
|
| 99 |
-
return res;
|
| 100 |
-
}
|
| 101 |
-
|
| 102 |
-
void DeregisterClient(uint64_t client_handle) {
|
| 103 |
-
std::lock_guard<std::mutex> lock(manager_instance_mutex_);
|
| 104 |
-
auto it = per_client_events_by_ext_correlation_.find(client_handle);
|
| 105 |
-
if (it == per_client_events_by_ext_correlation_.end()) {
|
| 106 |
-
return;
|
| 107 |
-
}
|
| 108 |
-
per_client_events_by_ext_correlation_.erase(it);
|
| 109 |
-
--num_active_clients_;
|
| 110 |
-
if (num_active_clients_ == 0 && tracing_enabled_) {
|
| 111 |
-
StopLogging();
|
| 112 |
-
}
|
| 113 |
-
}
|
| 114 |
-
|
| 115 |
-
void StartLogging() {
|
| 116 |
-
std::lock_guard<std::mutex> lock(manager_instance_mutex_);
|
| 117 |
-
if (tracing_enabled_) {
|
| 118 |
-
return;
|
| 119 |
-
}
|
| 120 |
-
|
| 121 |
-
auto this_as_derived = static_cast<TDerived*>(this);
|
| 122 |
-
tracing_enabled_ = this_as_derived->OnStartLogging();
|
| 123 |
-
}
|
| 124 |
-
|
| 125 |
-
void Consume(uint64_t client_handle, const TimePoint& start_time, std::map<uint64_t, Events>& events) {
|
| 126 |
-
auto this_as_derived = static_cast<TDerived*>(this);
|
| 127 |
-
events.clear();
|
| 128 |
-
{
|
| 129 |
-
// Flush any pending activity records before starting
|
| 130 |
-
// to process the accumulated activity records.
|
| 131 |
-
std::lock_guard<std::mutex> lock_manager(manager_instance_mutex_);
|
| 132 |
-
if (!tracing_enabled_) {
|
| 133 |
-
return;
|
| 134 |
-
}
|
| 135 |
-
|
| 136 |
-
this_as_derived->FlushActivities();
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
-
std::vector<ProfilerActivityBuffer> activity_buffers;
|
| 140 |
-
{
|
| 141 |
-
std::lock_guard<std::mutex> lock(unprocessed_activity_buffers_mutex_);
|
| 142 |
-
std::swap(unprocessed_activity_buffers_, activity_buffers);
|
| 143 |
-
unprocessed_activity_buffers_.clear();
|
| 144 |
-
}
|
| 145 |
-
|
| 146 |
-
{
|
| 147 |
-
// Ensure that at most one thread is working through the activity buffers at any time.
|
| 148 |
-
std::lock_guard<std::mutex> lock_two(activity_buffer_processor_mutex_);
|
| 149 |
-
this_as_derived->ProcessActivityBuffers(activity_buffers, start_time);
|
| 150 |
-
auto it = per_client_events_by_ext_correlation_.find(client_handle);
|
| 151 |
-
if (it == per_client_events_by_ext_correlation_.end()) {
|
| 152 |
-
return;
|
| 153 |
-
}
|
| 154 |
-
std::swap(events, it->second);
|
| 155 |
-
}
|
| 156 |
-
}
|
| 157 |
-
|
| 158 |
-
void PushCorrelation(uint64_t client_handle,
|
| 159 |
-
uint64_t external_correlation_id,
|
| 160 |
-
TimePoint profiling_start_time) {
|
| 161 |
-
auto this_as_derived = static_cast<TDerived*>(this);
|
| 162 |
-
std::lock_guard<std::mutex> lock(manager_instance_mutex_);
|
| 163 |
-
if (!tracing_enabled_) {
|
| 164 |
-
return;
|
| 165 |
-
}
|
| 166 |
-
|
| 167 |
-
auto it = per_client_events_by_ext_correlation_.find(client_handle);
|
| 168 |
-
if (it == per_client_events_by_ext_correlation_.end()) {
|
| 169 |
-
// not a registered client, do nothing
|
| 170 |
-
return;
|
| 171 |
-
}
|
| 172 |
-
|
| 173 |
-
// external_correlation_id is simply the timestamp of this event,
|
| 174 |
-
// relative to profiling_start_time. i.e., it was computed as:
|
| 175 |
-
// external_correlation_id =
|
| 176 |
-
// std::chrono::duration_cast<std::chrono::microseconds>(event_start_time - profiling_start_time).count()
|
| 177 |
-
//
|
| 178 |
-
// Because of the relative nature of the external_correlation_id, the same
|
| 179 |
-
// external_correlation_id can be reused across different clients, which then makes it
|
| 180 |
-
// impossible to recover the client from the external_correlation_id, which in turn
|
| 181 |
-
// makes it impossible to map events (which are tagged with external_correlation_id) to clients.
|
| 182 |
-
//
|
| 183 |
-
// To address these difficulties, we construct a new correlation_id (let's call it unique_cid)
|
| 184 |
-
// as follows:
|
| 185 |
-
// unique_cid =
|
| 186 |
-
// external_correlation_id +
|
| 187 |
-
// std::chrono::duration_cast<std::chrono::microseconds>(profiling_start_time.time_since_epoch()).count()
|
| 188 |
-
// now, unique_cid is monotonically increasing with time, so it can be used to reliably map events to clients.
|
| 189 |
-
//
|
| 190 |
-
// Of course, clients expect lists of events to be returned (on a call to Consume()), that are
|
| 191 |
-
// still keyed on the external_correlation_id that they've specified here, so we need to remember the
|
| 192 |
-
// offset to be subtracted
|
| 193 |
-
uint64_t offset = std::chrono::duration_cast<std::chrono::microseconds>(profiling_start_time.time_since_epoch()).count();
|
| 194 |
-
auto unique_cid = external_correlation_id + offset;
|
| 195 |
-
unique_correlation_id_to_client_offset_[unique_cid] = std::make_pair(client_handle, offset);
|
| 196 |
-
this_as_derived->PushUniqueCorrelation(unique_cid);
|
| 197 |
-
}
|
| 198 |
-
|
| 199 |
-
void PopCorrelation(uint64_t& popped_external_correlation_id) {
|
| 200 |
-
auto this_as_derived = static_cast<TDerived*>(this);
|
| 201 |
-
std::lock_guard<std::mutex> lock(manager_instance_mutex_);
|
| 202 |
-
if (!tracing_enabled_) {
|
| 203 |
-
return;
|
| 204 |
-
}
|
| 205 |
-
uint64_t unique_cid;
|
| 206 |
-
this_as_derived->PopUniqueCorrelation(unique_cid);
|
| 207 |
-
// lookup the offset and subtract it before returning popped_external_correlation_id to the client
|
| 208 |
-
auto client_it = unique_correlation_id_to_client_offset_.find(unique_cid);
|
| 209 |
-
if (client_it == unique_correlation_id_to_client_offset_.end()) {
|
| 210 |
-
popped_external_correlation_id = 0;
|
| 211 |
-
return;
|
| 212 |
-
}
|
| 213 |
-
popped_external_correlation_id = unique_cid - client_it->second.second;
|
| 214 |
-
}
|
| 215 |
-
|
| 216 |
-
void PopCorrelation() {
|
| 217 |
-
uint64_t unused;
|
| 218 |
-
PopCorrelation(unused);
|
| 219 |
-
}
|
| 220 |
-
|
| 221 |
-
protected:
|
| 222 |
-
GPUTracerManager() {
|
| 223 |
-
auto this_as_derived = static_cast<TDerived*>(this);
|
| 224 |
-
uint64_t gpu_ts1, gpu_ts2, cpu_ts;
|
| 225 |
-
|
| 226 |
-
// Get the CPU and GPU timestamps to warm up
|
| 227 |
-
gpu_ts1 = this_as_derived->GetGPUTimestampInNanoseconds();
|
| 228 |
-
cpu_ts = this->GetCPUTimestampInNanoseconds();
|
| 229 |
-
|
| 230 |
-
// Estimate the skew/offset between the CPU and GPU timestamps.
|
| 231 |
-
gpu_ts1 = this_as_derived->GetGPUTimestampInNanoseconds();
|
| 232 |
-
cpu_ts = this->GetCPUTimestampInNanoseconds();
|
| 233 |
-
gpu_ts2 = this_as_derived->GetGPUTimestampInNanoseconds();
|
| 234 |
-
|
| 235 |
-
auto gpu_ts = (gpu_ts1 + gpu_ts2) / 2;
|
| 236 |
-
offset_to_add_to_gpu_timestamps_ = cpu_ts - gpu_ts;
|
| 237 |
-
}
|
| 238 |
-
|
| 239 |
-
#if 0
|
| 240 |
-
// Functional API to be implemented by subclasses
|
| 241 |
-
// Included here only for documentation purposes
|
| 242 |
-
protected:
|
| 243 |
-
bool OnStartLogging();
|
| 244 |
-
void OnStopLogging();
|
| 245 |
-
void ProcessActivityBuffers(const std::vector<ProfilerActivityBuffer>& buffers,
|
| 246 |
-
const TimePoint& start_time);
|
| 247 |
-
bool PushUniqueCorrelation(uint64_t unique_cid);
|
| 248 |
-
void PopUniqueCorrelation(uint64_t& popped_unique_cid);
|
| 249 |
-
void FlushActivities();
|
| 250 |
-
uint64_t GetGPUTimestampInNanoseconds();
|
| 251 |
-
#endif
|
| 252 |
-
|
| 253 |
-
void EnqueueActivityBuffer(ProfilerActivityBuffer&& buffer) {
|
| 254 |
-
std::lock_guard<std::mutex> lock(unprocessed_activity_buffers_mutex_);
|
| 255 |
-
unprocessed_activity_buffers_.emplace_back(std::move(buffer));
|
| 256 |
-
}
|
| 257 |
-
|
| 258 |
-
// To be called by subclasses only from ProcessActivityBuffers
|
| 259 |
-
void MapEventToClient(uint64_t tracer_correlation_id, EventRecord&& event) {
|
| 260 |
-
auto it = tracer_correlation_to_unique_correlation_.find(tracer_correlation_id);
|
| 261 |
-
if (it == tracer_correlation_to_unique_correlation_.end()) {
|
| 262 |
-
// We're yet to receive a mapping to unique_correlation_id for this tracer_correlation_id
|
| 263 |
-
DeferEventMapping(std::move(event), tracer_correlation_id);
|
| 264 |
-
return;
|
| 265 |
-
}
|
| 266 |
-
auto unique_correlation_id = it->second;
|
| 267 |
-
auto p_event_list = GetEventListForUniqueCorrelationId(unique_correlation_id);
|
| 268 |
-
if (p_event_list != nullptr) {
|
| 269 |
-
p_event_list->emplace_back(std::move(event));
|
| 270 |
-
}
|
| 271 |
-
}
|
| 272 |
-
|
| 273 |
-
// To be called by subclasses only from ProcessActivityBuffers
|
| 274 |
-
void NotifyNewCorrelation(uint64_t tracer_correlation_id, uint64_t unique_correlation_id) {
|
| 275 |
-
tracer_correlation_to_unique_correlation_[tracer_correlation_id] = unique_correlation_id;
|
| 276 |
-
auto pending_it = events_pending_client_mapping_.find(tracer_correlation_id);
|
| 277 |
-
if (pending_it == events_pending_client_mapping_.end()) {
|
| 278 |
-
return;
|
| 279 |
-
}
|
| 280 |
-
// Map the pending events to the right client
|
| 281 |
-
MapEventsToClient(unique_correlation_id, std::move(pending_it->second));
|
| 282 |
-
events_pending_client_mapping_.erase(pending_it);
|
| 283 |
-
}
|
| 284 |
-
|
| 285 |
-
uint64_t NormalizeGPUTimestampToCPUEpoch(uint64_t gpu_timestamp_in_nanoseconds) {
|
| 286 |
-
return gpu_timestamp_in_nanoseconds + this->offset_to_add_to_gpu_timestamps_;
|
| 287 |
-
}
|
| 288 |
-
|
| 289 |
-
private:
|
| 290 |
-
// Requires: manager_instance_mutex_ should be held
|
| 291 |
-
void StopLogging() {
|
| 292 |
-
auto this_as_derived = static_cast<TDerived*>(this);
|
| 293 |
-
if (!tracing_enabled_) {
|
| 294 |
-
return;
|
| 295 |
-
}
|
| 296 |
-
this_as_derived->OnStopLogging();
|
| 297 |
-
tracing_enabled_ = false;
|
| 298 |
-
Clear();
|
| 299 |
-
}
|
| 300 |
-
|
| 301 |
-
// Requires: manager_instance_mutex_ should be held
|
| 302 |
-
void Clear() {
|
| 303 |
-
unprocessed_activity_buffers_.clear();
|
| 304 |
-
unique_correlation_id_to_client_offset_.clear();
|
| 305 |
-
per_client_events_by_ext_correlation_.clear();
|
| 306 |
-
tracer_correlation_to_unique_correlation_.clear();
|
| 307 |
-
events_pending_client_mapping_.clear();
|
| 308 |
-
}
|
| 309 |
-
|
| 310 |
-
Events* GetEventListForUniqueCorrelationId(uint64_t unique_correlation_id) {
|
| 311 |
-
auto client_it = unique_correlation_id_to_client_offset_.find(unique_correlation_id);
|
| 312 |
-
if (client_it == unique_correlation_id_to_client_offset_.end()) {
|
| 313 |
-
return nullptr;
|
| 314 |
-
}
|
| 315 |
-
|
| 316 |
-
// See the comments on the GetUniqueCorrelationId method for an explanation of
|
| 317 |
-
// of this offset computation and why it's required.
|
| 318 |
-
auto const& client_handle_offset = client_it->second;
|
| 319 |
-
auto external_correlation = unique_correlation_id - client_handle_offset.second;
|
| 320 |
-
auto& event_list = per_client_events_by_ext_correlation_[client_handle_offset.first][external_correlation];
|
| 321 |
-
return &event_list;
|
| 322 |
-
}
|
| 323 |
-
|
| 324 |
-
void MapEventsToClient(uint64_t unique_correlation_id, std::vector<EventRecord>&& events) {
|
| 325 |
-
auto p_event_list = GetEventListForUniqueCorrelationId(unique_correlation_id);
|
| 326 |
-
if (p_event_list != nullptr) {
|
| 327 |
-
p_event_list->insert(p_event_list->end(),
|
| 328 |
-
std::make_move_iterator(events.begin()),
|
| 329 |
-
std::make_move_iterator(events.end()));
|
| 330 |
-
}
|
| 331 |
-
}
|
| 332 |
-
|
| 333 |
-
void DeferEventMapping(EventRecord&& event, uint64_t tracer_correlation_id) {
|
| 334 |
-
events_pending_client_mapping_[tracer_correlation_id].emplace_back(std::move(event));
|
| 335 |
-
}
|
| 336 |
-
|
| 337 |
-
uint64_t GetCPUTimestampInNanoseconds() {
|
| 338 |
-
return std::chrono::duration_cast<std::chrono::nanoseconds>(
|
| 339 |
-
std::chrono::high_resolution_clock::now().time_since_epoch())
|
| 340 |
-
.count();
|
| 341 |
-
}
|
| 342 |
-
|
| 343 |
-
std::mutex manager_instance_mutex_;
|
| 344 |
-
uint64_t next_client_id_ = 1;
|
| 345 |
-
uint64_t num_active_clients_ = 0;
|
| 346 |
-
bool tracing_enabled_ = false;
|
| 347 |
-
std::mutex unprocessed_activity_buffers_mutex_;
|
| 348 |
-
std::mutex activity_buffer_processor_mutex_;
|
| 349 |
-
|
| 350 |
-
// Unprocessed activity buffers
|
| 351 |
-
std::vector<ProfilerActivityBuffer> unprocessed_activity_buffers_;
|
| 352 |
-
|
| 353 |
-
// Keyed on unique_correlation_id -> (client_id/client_handle, offset)
|
| 354 |
-
// unique_correlation_id - offset == external_correlation_id
|
| 355 |
-
InlinedHashMap<uint64_t, std::pair<uint64_t, uint64_t>> unique_correlation_id_to_client_offset_;
|
| 356 |
-
|
| 357 |
-
// Keyed on tracer_correlation_id -> unique_correlation_id
|
| 358 |
-
InlinedHashMap<uint64_t, uint64_t> tracer_correlation_to_unique_correlation_;
|
| 359 |
-
|
| 360 |
-
// client_id/client_handle -> external_correlation_id -> events
|
| 361 |
-
InlinedHashMap<uint64_t, std::map<uint64_t, Events>> per_client_events_by_ext_correlation_;
|
| 362 |
-
|
| 363 |
-
// Keyed on tracer correlation_id, keeps track of activity records
|
| 364 |
-
// for which we haven't established the external_correlation_id yet.
|
| 365 |
-
InlinedHashMap<uint64_t, std::vector<EventRecord>> events_pending_client_mapping_;
|
| 366 |
-
|
| 367 |
-
// An offset to add to (the possibly skewed) GPU timestamps
|
| 368 |
-
// to normalize GPU timestamps with CPU timestamps
|
| 369 |
-
int64_t offset_to_add_to_gpu_timestamps_;
|
| 370 |
-
}; /* class GPUTracerManager */
|
| 371 |
-
|
| 372 |
-
// Base class for a GPU profiler
|
| 373 |
-
template <typename TManager>
|
| 374 |
-
class GPUProfilerBase : public EpProfiler {
|
| 375 |
-
protected:
|
| 376 |
-
GPUProfilerBase() = default;
|
| 377 |
-
virtual ~GPUProfilerBase() {}
|
| 378 |
-
|
| 379 |
-
void MergeEvents(std::map<uint64_t, Events>& events_to_merge, Events& events) {
|
| 380 |
-
Events merged_events;
|
| 381 |
-
|
| 382 |
-
auto event_iter = std::make_move_iterator(events.begin());
|
| 383 |
-
auto event_end = std::make_move_iterator(events.end());
|
| 384 |
-
for (auto& map_iter : events_to_merge) {
|
| 385 |
-
if (map_iter.second.empty()) {
|
| 386 |
-
continue;
|
| 387 |
-
}
|
| 388 |
-
|
| 389 |
-
auto ts = static_cast<long long>(map_iter.first);
|
| 390 |
-
|
| 391 |
-
// find the last occurrence of a matching timestamp,
|
| 392 |
-
// if one exists
|
| 393 |
-
while (event_iter != event_end &&
|
| 394 |
-
(event_iter->ts < ts ||
|
| 395 |
-
(event_iter->ts == ts &&
|
| 396 |
-
(event_iter + 1) != event_end &&
|
| 397 |
-
(event_iter + 1)->ts == ts))) {
|
| 398 |
-
merged_events.emplace_back(*event_iter);
|
| 399 |
-
++event_iter;
|
| 400 |
-
}
|
| 401 |
-
|
| 402 |
-
bool copy_op_names = false;
|
| 403 |
-
std::string op_name;
|
| 404 |
-
std::string parent_name;
|
| 405 |
-
|
| 406 |
-
if (event_iter != event_end && event_iter->ts == ts) {
|
| 407 |
-
// We've located a parent event, copy the op_name and set
|
| 408 |
-
// this event's parent_name property to the name of the parent.
|
| 409 |
-
copy_op_names = true;
|
| 410 |
-
op_name = event_iter->args["op_name"];
|
| 411 |
-
parent_name = event_iter->name;
|
| 412 |
-
merged_events.emplace_back(*event_iter);
|
| 413 |
-
++event_iter;
|
| 414 |
-
}
|
| 415 |
-
|
| 416 |
-
for (auto& evt : map_iter.second) {
|
| 417 |
-
if (copy_op_names) {
|
| 418 |
-
// If we have found a matching parent event,
|
| 419 |
-
// then inherit some names from the parent.
|
| 420 |
-
evt.args["op_name"] = op_name;
|
| 421 |
-
evt.args["parent_name"] = parent_name;
|
| 422 |
-
}
|
| 423 |
-
}
|
| 424 |
-
|
| 425 |
-
merged_events.insert(merged_events.end(),
|
| 426 |
-
std::make_move_iterator(map_iter.second.begin()),
|
| 427 |
-
std::make_move_iterator(map_iter.second.end()));
|
| 428 |
-
}
|
| 429 |
-
|
| 430 |
-
// move any remaining events
|
| 431 |
-
merged_events.insert(merged_events.end(), event_iter, event_end);
|
| 432 |
-
std::swap(events, merged_events);
|
| 433 |
-
}
|
| 434 |
-
|
| 435 |
-
uint64_t client_handle_;
|
| 436 |
-
TimePoint profiling_start_time_;
|
| 437 |
-
|
| 438 |
-
public:
|
| 439 |
-
virtual bool StartProfiling(TimePoint profiling_start_time) override {
|
| 440 |
-
auto& manager = TManager::GetInstance();
|
| 441 |
-
manager.StartLogging();
|
| 442 |
-
profiling_start_time_ = profiling_start_time;
|
| 443 |
-
return true;
|
| 444 |
-
}
|
| 445 |
-
|
| 446 |
-
virtual void EndProfiling(TimePoint start_time, Events& events) override {
|
| 447 |
-
auto& manager = TManager::GetInstance();
|
| 448 |
-
std::map<uint64_t, Events> event_map;
|
| 449 |
-
manager.Consume(client_handle_, start_time, event_map);
|
| 450 |
-
MergeEvents(event_map, events);
|
| 451 |
-
}
|
| 452 |
-
|
| 453 |
-
virtual void Start(uint64_t id) override {
|
| 454 |
-
auto& manager = TManager::GetInstance();
|
| 455 |
-
manager.PushCorrelation(client_handle_, id, profiling_start_time_);
|
| 456 |
-
}
|
| 457 |
-
|
| 458 |
-
virtual void Stop(uint64_t) override {
|
| 459 |
-
auto& manager = TManager::GetInstance();
|
| 460 |
-
manager.PopCorrelation();
|
| 461 |
-
}
|
| 462 |
-
}; /* class GPUProfilerBase */
|
| 463 |
-
|
| 464 |
-
// Formats `ptr` by streaming it into a string stream that has been switched
// to hexadecimal base, and returns the resulting text.
static inline std::string PointerToHexString(const void* ptr) {
  std::ostringstream formatted;
  formatted << std::hex;
  formatted << ptr;
  return formatted.str();
}
|
| 470 |
-
|
| 471 |
-
} /* end namespace profiling */
|
| 472 |
-
} /* end namespace onnxruntime */
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/gsl.h
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include "gsl/gsl"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/hash_combine.h
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
namespace onnxruntime {
|
| 7 |
-
|
| 8 |
-
// Folds the hash value `h` into `seed`; `seed` is updated in place.
// TODO(edgchen1) find a better implementation? e.g., see a more recent version of boost::hash_combine()
inline void HashCombineWithHashValue(size_t h, size_t& seed) {
  const size_t mixed = h + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  seed ^= mixed;
}
|
| 13 |
-
|
| 14 |
-
// Combine hash value `seed` with the hash value of `value`, updating `seed` in place.
|
| 15 |
-
// The hash value computation is specified by the `Hash` template parameter.
|
| 16 |
-
template <typename T, typename Hash = std::hash<T>>
|
| 17 |
-
inline void HashCombine(const T& value, size_t& seed) {
|
| 18 |
-
HashCombineWithHashValue(Hash{}(value), seed);
|
| 19 |
-
}
|
| 20 |
-
|
| 21 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/inlined_containers.h
DELETED
|
@@ -1,175 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <cmath>
|
| 7 |
-
|
| 8 |
-
#include "core/common/inlined_containers_fwd.h"
|
| 9 |
-
|
| 10 |
-
#ifndef DISABLE_ABSEIL
|
| 11 |
-
|
| 12 |
-
#ifdef _MSC_VER
|
| 13 |
-
#pragma warning(push)
|
| 14 |
-
// C4127: conditional expression is constant
|
| 15 |
-
#pragma warning(disable : 4127)
|
| 16 |
-
// C4324: structure was padded due to alignment specifier
|
| 17 |
-
// Usage of alignas causes some internal padding in places.
|
| 18 |
-
#pragma warning(disable : 4324)
|
| 19 |
-
#endif // _MSC_VER
|
| 20 |
-
|
| 21 |
-
#include <absl/container/flat_hash_set.h>
|
| 22 |
-
#include <absl/container/flat_hash_map.h>
|
| 23 |
-
|
| 24 |
-
#include <absl/container/node_hash_set.h>
|
| 25 |
-
#include <absl/container/node_hash_map.h>
|
| 26 |
-
|
| 27 |
-
#ifdef _MSC_VER
|
| 28 |
-
#pragma warning(pop)
|
| 29 |
-
#endif // _MSC_VER
|
| 30 |
-
|
| 31 |
-
#else // DISABLE_ABSEIL
|
| 32 |
-
|
| 33 |
-
#include <unordered_set>
|
| 34 |
-
#include <unordered_map>
|
| 35 |
-
#include <set>
|
| 36 |
-
#include <map>
|
| 37 |
-
|
| 38 |
-
#endif // DISABLE_ABSEIL
|
| 39 |
-
|
| 40 |
-
namespace onnxruntime {
|
| 41 |
-
|
| 42 |
-
#ifndef DISABLE_ABSEIL
|
| 43 |
-
// InlinedHashSet and InlinedHashMap are preferred
|
| 44 |
-
// hash based containers. They store their values in the
|
| 45 |
-
// buckets array that is allocated in one shot. It eliminates
|
| 46 |
-
// per-node new/delete calls. Always call reserve() on any hash set/map
|
| 47 |
-
// when the number of items is known in advance.
|
| 48 |
-
// This does not allocate a dummy 'end' node on default construction.
|
| 49 |
-
template <typename T, typename Allocator>
|
| 50 |
-
class InlinedHashSet : public absl::flat_hash_set<T,
|
| 51 |
-
absl::container_internal::hash_default_hash<T>,
|
| 52 |
-
absl::container_internal::hash_default_eq<T>,
|
| 53 |
-
Allocator> {
|
| 54 |
-
using Base = absl::flat_hash_set<T,
|
| 55 |
-
absl::container_internal::hash_default_hash<T>,
|
| 56 |
-
absl::container_internal::hash_default_eq<T>,
|
| 57 |
-
Allocator>;
|
| 58 |
-
|
| 59 |
-
public:
|
| 60 |
-
using Base::Base;
|
| 61 |
-
};
|
| 62 |
-
|
| 63 |
-
template <typename Key, typename Value,
|
| 64 |
-
typename Allocator>
|
| 65 |
-
class InlinedHashMap : public absl::flat_hash_map<Key, Value,
|
| 66 |
-
absl::container_internal::hash_default_hash<Key>,
|
| 67 |
-
absl::container_internal::hash_default_eq<Key>,
|
| 68 |
-
Allocator> {
|
| 69 |
-
using Base = absl::flat_hash_map<Key, Value,
|
| 70 |
-
absl::container_internal::hash_default_hash<Key>,
|
| 71 |
-
absl::container_internal::hash_default_eq<Key>,
|
| 72 |
-
Allocator>;
|
| 73 |
-
|
| 74 |
-
public:
|
| 75 |
-
using Base::Base;
|
| 76 |
-
};
|
| 77 |
-
|
| 78 |
-
// Use this hash set/map where pointer stability is required, otherwise use
|
| 79 |
-
// InlinedHashSet and InlinedHashMap
|
| 80 |
-
// This does not allocate a dummy 'end' node on default construction.
|
| 81 |
-
// Use reserve() when the number of elements is known.
|
| 82 |
-
template <typename T, typename Allocator>
|
| 83 |
-
class NodeHashSet : public absl::node_hash_set<T,
|
| 84 |
-
absl::container_internal::hash_default_hash<T>,
|
| 85 |
-
absl::container_internal::hash_default_eq<T>,
|
| 86 |
-
Allocator> {
|
| 87 |
-
using Base = absl::node_hash_set<T,
|
| 88 |
-
absl::container_internal::hash_default_hash<T>,
|
| 89 |
-
absl::container_internal::hash_default_eq<T>,
|
| 90 |
-
Allocator>;
|
| 91 |
-
|
| 92 |
-
public:
|
| 93 |
-
using Base::Base;
|
| 94 |
-
};
|
| 95 |
-
|
| 96 |
-
template <typename Key, typename Value, typename Allocator>
|
| 97 |
-
class NodeHashMap : public absl::node_hash_map<Key, Value,
|
| 98 |
-
absl::container_internal::hash_default_hash<Key>,
|
| 99 |
-
absl::container_internal::hash_default_eq<Key>,
|
| 100 |
-
Allocator> {
|
| 101 |
-
using Base = absl::node_hash_map<Key, Value,
|
| 102 |
-
absl::container_internal::hash_default_hash<Key>,
|
| 103 |
-
absl::container_internal::hash_default_eq<Key>,
|
| 104 |
-
Allocator>;
|
| 105 |
-
|
| 106 |
-
public:
|
| 107 |
-
using Base::Base;
|
| 108 |
-
};
|
| 109 |
-
|
| 110 |
-
#else // DISABLE_ABSEIL
|
| 111 |
-
|
| 112 |
-
template <typename T, typename Allocator>
|
| 113 |
-
class InlinedHashSet : public std::unordered_set<T,
|
| 114 |
-
std::hash<T>,
|
| 115 |
-
std::equal_to<T>,
|
| 116 |
-
Allocator> {
|
| 117 |
-
using Base = std::unordered_set<T,
|
| 118 |
-
std::hash<T>,
|
| 119 |
-
std::equal_to<T>,
|
| 120 |
-
Allocator>;
|
| 121 |
-
|
| 122 |
-
public:
|
| 123 |
-
using Base::Base;
|
| 124 |
-
};
|
| 125 |
-
|
| 126 |
-
template <typename Key, typename Value,
|
| 127 |
-
typename Allocator>
|
| 128 |
-
class InlinedHashMap : public std::unordered_map<Key, Value,
|
| 129 |
-
std::hash<Key>,
|
| 130 |
-
std::equal_to<Key>,
|
| 131 |
-
Allocator> {
|
| 132 |
-
using Base = std::unordered_map<Key, Value,
|
| 133 |
-
std::hash<Key>,
|
| 134 |
-
std::equal_to<Key>,
|
| 135 |
-
Allocator>;
|
| 136 |
-
|
| 137 |
-
public:
|
| 138 |
-
using Base::Base;
|
| 139 |
-
};
|
| 140 |
-
|
| 141 |
-
// Use this hash set/map where pointer stability is required, otherwise use
|
| 142 |
-
// InlinedHashSet and InlinedHashMap
|
| 143 |
-
// This does not allocate a dummy 'end' node on default construction.
|
| 144 |
-
// Use reserve() when the number of elements is known.
|
| 145 |
-
template <typename T, typename Allocator>
|
| 146 |
-
class NodeHashSet : public std::unordered_set<T,
|
| 147 |
-
std::hash<T>,
|
| 148 |
-
std::equal_to<T>,
|
| 149 |
-
Allocator> {
|
| 150 |
-
using Base = std::unordered_set<T,
|
| 151 |
-
std::hash<T>,
|
| 152 |
-
std::equal_to<T>,
|
| 153 |
-
Allocator>;
|
| 154 |
-
|
| 155 |
-
public:
|
| 156 |
-
using Base::Base;
|
| 157 |
-
};
|
| 158 |
-
|
| 159 |
-
template <typename Key, typename Value, typename Allocator>
|
| 160 |
-
class NodeHashMap : public std::unordered_map<Key, Value,
|
| 161 |
-
std::hash<Key>,
|
| 162 |
-
std::equal_to<Key>,
|
| 163 |
-
Allocator> {
|
| 164 |
-
using Base = std::unordered_map<Key, Value,
|
| 165 |
-
std::hash<Key>,
|
| 166 |
-
std::equal_to<Key>,
|
| 167 |
-
Allocator>;
|
| 168 |
-
|
| 169 |
-
public:
|
| 170 |
-
using Base::Base;
|
| 171 |
-
};
|
| 172 |
-
|
| 173 |
-
#endif // DISABLE_ABSEIL
|
| 174 |
-
|
| 175 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/inlined_containers_fwd.h
DELETED
|
@@ -1,147 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <memory>
|
| 7 |
-
#include <utility>
|
| 8 |
-
|
| 9 |
-
#ifndef DISABLE_ABSEIL
|
| 10 |
-
#ifdef _MSC_VER
|
| 11 |
-
#pragma warning(push)
|
| 12 |
-
// C4127: conditional expression is constant
|
| 13 |
-
#pragma warning(disable : 4127)
|
| 14 |
-
// C4324: structure was padded due to alignment specifier
|
| 15 |
-
// Usage of alignas causes some internal padding in places.
|
| 16 |
-
#pragma warning(disable : 4324)
|
| 17 |
-
#else
|
| 18 |
-
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=102329#c2
|
| 19 |
-
#if !defined(__clang__) && defined(__GNUC__)
|
| 20 |
-
#pragma GCC diagnostic push
|
| 21 |
-
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
|
| 22 |
-
#endif
|
| 23 |
-
#endif // _MSC_VER
|
| 24 |
-
|
| 25 |
-
#include <absl/container/inlined_vector.h>
|
| 26 |
-
|
| 27 |
-
#ifdef _MSC_VER
|
| 28 |
-
#pragma warning(pop)
|
| 29 |
-
#else
|
| 30 |
-
#if !defined(__clang__) && defined(__GNUC__)
|
| 31 |
-
#pragma GCC diagnostic pop
|
| 32 |
-
#endif
|
| 33 |
-
#endif // _MSC_VER
|
| 34 |
-
|
| 35 |
-
#else
|
| 36 |
-
|
| 37 |
-
#include <vector>
|
| 38 |
-
|
| 39 |
-
#endif // DISABLE_ABSEIL
|
| 40 |
-
|
| 41 |
-
// Forward declarations for contexts where abseil can not be compiled and
|
| 42 |
-
// not really needed but we want to have it in the headers that are included
|
| 43 |
-
// e.g. CUDA 10 and .CU files
|
| 44 |
-
// InlinedVector seems to be fine with old CUDA
|
| 45 |
-
|
| 46 |
-
//===- llvm/ADT/SmallVector.h - 'Normally small' vectors --------*- C++ -*-===//
|
| 47 |
-
//
|
| 48 |
-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
| 49 |
-
// See https://llvm.org/LICENSE.txt for license information.
|
| 50 |
-
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
| 51 |
-
//
|
| 52 |
-
// This file contains code and comments derived from llvm/ADT/SmallVector.h
|
| 53 |
-
//
|
| 54 |
-
// Specifically CalculateInlinedVectorDefaultInlinedElements<T>() template is derived from
|
| 55 |
-
// CalculateSmallVectorDefaultInlinedElements<T>() and its comments.
|
| 56 |
-
|
| 57 |
-
namespace onnxruntime {
|
| 58 |
-
#ifndef DISABLE_ABSEIL
|
| 59 |
-
/// Inspired by LLVM SmallVector with ONNX Runtime adjustments for abseil.
|
| 60 |
-
///
|
| 61 |
-
/// Helper class for calculating the default number of inline elements for
|
| 62 |
-
/// `InlinedVector<T>`.
|
| 63 |
-
/// This produces the following on MSVC x64
|
| 64 |
-
/// int8_t -> 41
|
| 65 |
-
/// int16_t -> 21
|
| 66 |
-
/// int32_t -> 11
|
| 67 |
-
/// int64_t -> 6
|
| 68 |
-
/// std::string 40 -> 1
|
| 69 |
-
template <typename T>
|
| 70 |
-
struct CalculateInlinedVectorDefaultInlinedElements {
|
| 71 |
-
// Parameter controlling the default number of inlined elements
|
| 72 |
-
// for `InlinedVector<T>`.
|
| 73 |
-
//
|
| 74 |
-
// The default number of inlined elements ensures that
|
| 75 |
-
// 1. There is at least one inlined element.
|
| 76 |
-
// 2. `sizeof(InlinedVector<T>) <= kPreferredInlinedVectorSizeof` unless
|
| 77 |
-
// it contradicts 1.
|
| 78 |
-
static constexpr size_t kPreferredInlinedVectorSizeof = 64;
|
| 79 |
-
|
| 80 |
-
// static_assert that sizeof(T) is not "too big".
|
| 81 |
-
//
|
| 82 |
-
// Because the InlinedVector must have at least one inlined element, it is possible
|
| 83 |
-
// for an arbitrarily large inlined element to allocate an arbitrarily large
|
| 84 |
-
// amount of inline storage. So we want to call attention to these cases and
|
| 85 |
-
// make sure that users are making an intentional decision if they request a lot of inline storage.
|
| 86 |
-
//
|
| 87 |
-
// We want this assertion to trigger in pathological cases, but otherwise
|
| 88 |
-
// not be too easy to hit. To accomplish that, the cutoff is actually somewhat
|
| 89 |
-
// larger than kPreferredInlinedVectorSizeof (otherwise,
|
| 90 |
-
// `InlinedVector<InlinedVector<T>>` would be one easy way to trip it, and that
|
| 91 |
-
// pattern seems useful in practice).
|
| 92 |
-
//
|
| 93 |
-
// One wrinkle is that this assertion is in theory non-portable, since
|
| 94 |
-
// sizeof(absl::InlinedVector<T, 1>) is in general platform-dependent. However, we don't expect this
|
| 95 |
-
// to be much of an issue, because most LLVM development happens on 64-bit
|
| 96 |
-
// hosts, and therefore sizeof(T) is expected to *decrease* when compiled for
|
| 97 |
-
// 32-bit hosts, dodging the issue. The reverse situation, where development
|
| 98 |
-
// happens on a 32-bit host and then fails due to sizeof(T) *increasing* on a
|
| 99 |
-
// 64-bit host, is expected to be very rare.
|
| 100 |
-
static_assert(
|
| 101 |
-
sizeof(absl::InlinedVector<T, 1>) <= kPreferredInlinedVectorSizeof,
|
| 102 |
-
"You are trying to use a default number of inlined elements for "
|
| 103 |
-
"`InlinedVector<T>` but `sizeof(T)` is really big! Please use an "
|
| 104 |
-
"explicit number of inlined elements with `InlinedVector<T, N>` to make "
|
| 105 |
-
"sure you really want that much inline storage.");
|
| 106 |
-
|
| 107 |
-
// Discount the size of the header itself when calculating the maximum inline
|
| 108 |
-
// bytes.
|
| 109 |
-
static constexpr size_t PreferredInlineBytes =
|
| 110 |
-
kPreferredInlinedVectorSizeof - (sizeof(absl::InlinedVector<T, 1>) - sizeof(T));
|
| 111 |
-
static constexpr size_t NumElementsThatFit = PreferredInlineBytes / sizeof(T);
|
| 112 |
-
static constexpr size_t value =
|
| 113 |
-
NumElementsThatFit == 0 ? 1 : NumElementsThatFit;
|
| 114 |
-
};
|
| 115 |
-
|
| 116 |
-
// Use InlinedVector for small arrays that can fit on a stack with a default
|
| 117 |
-
// value pre-calculated.
|
| 118 |
-
// Use TensorShapeVector for shapes.
|
| 119 |
-
template <typename T,
|
| 120 |
-
size_t N = CalculateInlinedVectorDefaultInlinedElements<T>::value,
|
| 121 |
-
typename Allocator = std::allocator<T>>
|
| 122 |
-
using InlinedVector = absl::InlinedVector<T, N, Allocator>;
|
| 123 |
-
|
| 124 |
-
#else
|
| 125 |
-
|
| 126 |
-
template <typename T,
|
| 127 |
-
size_t N = 0,
|
| 128 |
-
typename Allocator = std::allocator<T>>
|
| 129 |
-
using InlinedVector = std::vector<T, Allocator>;
|
| 130 |
-
|
| 131 |
-
#endif // DISABLE_ABSEIL
|
| 132 |
-
|
| 133 |
-
template <typename T,
|
| 134 |
-
typename Allocator = std::allocator<T>>
|
| 135 |
-
class InlinedHashSet;
|
| 136 |
-
|
| 137 |
-
template <typename Key, typename Value,
|
| 138 |
-
typename Allocator = std::allocator<std::pair<const Key, Value>>>
|
| 139 |
-
class InlinedHashMap;
|
| 140 |
-
|
| 141 |
-
template <typename T, typename Allocator = std::allocator<T>>
|
| 142 |
-
class NodeHashSet;
|
| 143 |
-
|
| 144 |
-
template <typename Key, typename Value,
|
| 145 |
-
typename Allocator = std::allocator<std::pair<const Key, Value>>>
|
| 146 |
-
class NodeHashMap;
|
| 147 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/logging/capture.h
DELETED
|
@@ -1,115 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <cstdarg>
|
| 7 |
-
#include "core/common/gsl.h"
|
| 8 |
-
#include "core/common/common.h"
|
| 9 |
-
#include "core/common/code_location.h"
|
| 10 |
-
#include "core/common/logging/severity.h"
|
| 11 |
-
|
| 12 |
-
namespace onnxruntime {
|
| 13 |
-
namespace logging {
|
| 14 |
-
|
| 15 |
-
class Logger;
|
| 16 |
-
enum class DataType;
|
| 17 |
-
|
| 18 |
-
/**
|
| 19 |
-
Class to capture the details of a log message.
|
| 20 |
-
*/
|
| 21 |
-
class Capture {
|
| 22 |
-
public:
|
| 23 |
-
/**
|
| 24 |
-
Initializes a new instance of the Capture class.
|
| 25 |
-
@param logger The logger.
|
| 26 |
-
@param severity The severity.
|
| 27 |
-
@param category The category.
|
| 28 |
-
@param dataType Type of the data.
|
| 29 |
-
@param location The file location the log message is coming from.
|
| 30 |
-
*/
|
| 31 |
-
Capture(const Logger& logger, logging::Severity severity, const char* category,
|
| 32 |
-
logging::DataType dataType, const CodeLocation& location)
|
| 33 |
-
: logger_{&logger}, severity_{severity}, category_{category}, data_type_{dataType}, location_{location} {
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
-
/**
|
| 37 |
-
The stream that can capture the message via operator<<.
|
| 38 |
-
@returns Output stream.
|
| 39 |
-
*/
|
| 40 |
-
std::ostream& Stream() noexcept {
|
| 41 |
-
return stream_;
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
#ifdef _MSC_VER
|
| 45 |
-
// add SAL annotation for printf format string. requires Code Analysis to run to validate usage.
|
| 46 |
-
#define msvc_printf_check _Printf_format_string_
|
| 47 |
-
#define __attribute__(x) // Disable for MSVC. Supported by GCC and CLang.
|
| 48 |
-
#else
|
| 49 |
-
#define msvc_printf_check
|
| 50 |
-
#endif
|
| 51 |
-
|
| 52 |
-
/**
|
| 53 |
-
Captures a printf style log message.
|
| 54 |
-
@param format The printf format.
|
| 55 |
-
@param ... Arguments to the printf format if needed.
|
| 56 |
-
@remarks
|
| 57 |
-
A maximum of 2K of output will be captured currently.
|
| 58 |
-
Non-static method, so 'this' is the implicit first arg, and we use format(printf, 2, 3).
|
| 59 |
-
*/
|
| 60 |
-
void CapturePrintf(msvc_printf_check const char* format, ...) __attribute__((format(printf, 2, 3)));
|
| 61 |
-
|
| 62 |
-
/**
|
| 63 |
-
Process a printf style log message.
|
| 64 |
-
@param format The printf format.
|
| 65 |
-
@param ... Arguments to the printf format if needed.
|
| 66 |
-
@remarks
|
| 67 |
-
A maximum of 2K of output will be captured currently.
|
| 68 |
-
Note: As va_list is 'char *', we have to disambiguate this from CapturePrintf
|
| 69 |
-
so that something like "One string: %s", "the string" does not consider "the string"
|
| 70 |
-
to be the va_list.
|
| 71 |
-
*/
|
| 72 |
-
void ProcessPrintf(msvc_printf_check const char* format, va_list args);
|
| 73 |
-
|
| 74 |
-
logging::Severity Severity() const noexcept {
|
| 75 |
-
return severity_;
|
| 76 |
-
}
|
| 77 |
-
|
| 78 |
-
char SeverityPrefix() const noexcept {
|
| 79 |
-
// Carefully setup so severity_ is a valid index
|
| 80 |
-
GSL_SUPPRESS(bounds .2) {
|
| 81 |
-
return logging::SEVERITY_PREFIX[static_cast<int>(severity_)];
|
| 82 |
-
}
|
| 83 |
-
}
|
| 84 |
-
|
| 85 |
-
const char* Category() const noexcept {
|
| 86 |
-
return category_;
|
| 87 |
-
}
|
| 88 |
-
|
| 89 |
-
logging::DataType DataType() const noexcept {
|
| 90 |
-
return data_type_;
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
const CodeLocation& Location() const noexcept {
|
| 94 |
-
return location_;
|
| 95 |
-
}
|
| 96 |
-
|
| 97 |
-
std::string Message() const noexcept {
|
| 98 |
-
return stream_.str();
|
| 99 |
-
}
|
| 100 |
-
|
| 101 |
-
~Capture();
|
| 102 |
-
|
| 103 |
-
private:
|
| 104 |
-
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Capture);
|
| 105 |
-
|
| 106 |
-
const Logger* logger_;
|
| 107 |
-
const logging::Severity severity_;
|
| 108 |
-
const char* category_;
|
| 109 |
-
const logging::DataType data_type_;
|
| 110 |
-
const CodeLocation location_;
|
| 111 |
-
|
| 112 |
-
std::ostringstream stream_;
|
| 113 |
-
};
|
| 114 |
-
} // namespace logging
|
| 115 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/logging/isink.h
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <string>
|
| 7 |
-
|
| 8 |
-
#include "core/common/logging/logging.h"
|
| 9 |
-
|
| 10 |
-
namespace onnxruntime {
|
| 11 |
-
namespace logging {
|
| 12 |
-
class ISink {
|
| 13 |
-
public:
|
| 14 |
-
ISink() = default;
|
| 15 |
-
|
| 16 |
-
/**
|
| 17 |
-
Sends the message to the sink.
|
| 18 |
-
@param timestamp The timestamp.
|
| 19 |
-
@param logger_id The logger identifier.
|
| 20 |
-
@param message The captured message.
|
| 21 |
-
*/
|
| 22 |
-
void Send(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) {
|
| 23 |
-
SendImpl(timestamp, logger_id, message);
|
| 24 |
-
}
|
| 25 |
-
|
| 26 |
-
/**
|
| 27 |
-
Sends a Profiling Event Record to the sink.
|
| 28 |
-
@param Profiling Event Record
|
| 29 |
-
*/
|
| 30 |
-
virtual void SendProfileEvent(profiling::EventRecord&) const {};
|
| 31 |
-
|
| 32 |
-
virtual ~ISink() = default;
|
| 33 |
-
|
| 34 |
-
private:
|
| 35 |
-
// Make Code Analysis happy by disabling all for now. Enable as needed.
|
| 36 |
-
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ISink);
|
| 37 |
-
|
| 38 |
-
virtual void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) = 0;
|
| 39 |
-
};
|
| 40 |
-
} // namespace logging
|
| 41 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/logging/logging.h
DELETED
|
@@ -1,337 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <atomic>
|
| 7 |
-
#include <chrono>
|
| 8 |
-
#include <climits>
|
| 9 |
-
#include <map>
|
| 10 |
-
#include <memory>
|
| 11 |
-
#include <mutex>
|
| 12 |
-
#include <string>
|
| 13 |
-
|
| 14 |
-
#include "core/common/common.h"
|
| 15 |
-
#include "core/common/profiler_common.h"
|
| 16 |
-
#include "core/common/logging/capture.h"
|
| 17 |
-
#include "core/common/logging/severity.h"
|
| 18 |
-
|
| 19 |
-
#include "core/common/logging/macros.h"
|
| 20 |
-
|
| 21 |
-
/*
|
| 22 |
-
|
| 23 |
-
Logging overview and expected usage:
|
| 24 |
-
|
| 25 |
-
At program startup:
|
| 26 |
-
* Create one or more ISink instances. If multiple, combine using composite_sink.
|
| 27 |
-
* Create a LoggingManager instance with the sink/s and with instance_type set to InstanceType::Default
|
| 28 |
-
* Only one instance should be created in this way, and it should remain valid
|
| 29 |
-
until the program no longer needs to produce log output.
|
| 30 |
-
|
| 31 |
-
You can either use the static default Logger which LoggingManager will create when constructed
|
| 32 |
-
via LoggingManager::DefaultLogger(), or separate Logger instances each with different log ids
|
| 33 |
-
via LoggingManager::CreateLogger.
|
| 34 |
-
|
| 35 |
-
The log id is passed to the ISink instance with the sink determining how the log id is used
|
| 36 |
-
in the output.
|
| 37 |
-
|
| 38 |
-
LoggingManager
|
| 39 |
-
* creates the Logger instances used by the application
|
| 40 |
-
* provides a static default logger instance
|
| 41 |
-
* owns the log sink instance
|
| 42 |
-
* applies checks on severity and output of user data
|
| 43 |
-
|
| 44 |
-
The log macros create a Capture instance to capture the information to log.
|
| 45 |
-
If the severity and/or user filtering settings would prevent logging, no evaluation
|
| 46 |
-
of the log arguments will occur, so no performance cost beyond the severity and user
|
| 47 |
-
filtering check.
|
| 48 |
-
|
| 49 |
-
A sink can do further filtering as needed.
|
| 50 |
-
|
| 51 |
-
*/
|
| 52 |
-
|
| 53 |
-
namespace onnxruntime {
|
| 54 |
-
|
| 55 |
-
namespace logging {
|
| 56 |
-
|
| 57 |
-
using Timestamp = std::chrono::time_point<std::chrono::system_clock>;
|
| 58 |
-
|
| 59 |
-
#ifndef NDEBUG
|
| 60 |
-
ORT_ATTRIBUTE_UNUSED static bool vlog_enabled = true; // Set directly based on your needs.
|
| 61 |
-
#else
|
| 62 |
-
constexpr bool vlog_enabled = false; // no VLOG output
|
| 63 |
-
#endif
|
| 64 |
-
|
| 65 |
-
enum class DataType {
|
| 66 |
-
SYSTEM = 0, ///< System data.
|
| 67 |
-
USER = 1 ///< Contains potentially sensitive user data.
|
| 68 |
-
};
|
| 69 |
-
|
| 70 |
-
// Internal log categories.
|
| 71 |
-
// Logging interface takes const char* so arbitrary values can also be used.
|
| 72 |
-
struct Category {
|
| 73 |
-
static const char* onnxruntime; ///< General output
|
| 74 |
-
static const char* System; ///< Log output regarding interactions with the host system
|
| 75 |
-
// TODO: What other high level categories are meaningful? Model? Optimizer? Execution?
|
| 76 |
-
};
|
| 77 |
-
|
| 78 |
-
class ISink;
|
| 79 |
-
class Logger;
|
| 80 |
-
class Capture;
|
| 81 |
-
|
| 82 |
-
/// <summary>
|
| 83 |
-
/// The logging manager.
|
| 84 |
-
/// Owns the log sink and potentially provides a default Logger instance.
|
| 85 |
-
/// Provides filtering based on a minimum Severity level, and of messages with DataType::USER if enabled.
|
| 86 |
-
/// </summary>
|
| 87 |
-
class LoggingManager final {
|
| 88 |
-
public:
|
| 89 |
-
enum InstanceType {
|
| 90 |
-
Default, ///< Default instance of LoggingManager that should exist for the lifetime of the program
|
| 91 |
-
Temporal ///< Temporal instance. CreateLogger(...) should be used, however DefaultLogger() will NOT be provided via this instance.
|
| 92 |
-
};
|
| 93 |
-
|
| 94 |
-
/**
|
| 95 |
-
Initializes a new instance of the LoggingManager class.
|
| 96 |
-
@param sink The sink to write to. Use CompositeSink if you need to write to multiple places.
|
| 97 |
-
@param default_min_severity The default minimum severity. Messages with lower severity will be ignored unless
|
| 98 |
-
overridden in CreateLogger.
|
| 99 |
-
@param default_filter_user_data If set to true ignore messages with DataType::USER unless overridden in CreateLogger.
|
| 100 |
-
@param instance_type If InstanceType::Default, this is the default instance of the LoggingManager
|
| 101 |
-
and is expected to exist for the lifetime of the program.
|
| 102 |
-
It creates and owns the default logger that calls to the static DefaultLogger method return.
|
| 103 |
-
@param default_logger_id Logger Id to use for the default logger. nullptr/ignored if instance_type == Temporal.
|
| 104 |
-
@param default_max_vlog_level Default maximum level for VLOG messages to be created unless overridden in CreateLogger.
|
| 105 |
-
Requires a severity of kVERBOSE for VLOG messages to be logged.
|
| 106 |
-
*/
|
| 107 |
-
LoggingManager(std::unique_ptr<ISink> sink, Severity default_min_severity, bool default_filter_user_data,
|
| 108 |
-
InstanceType instance_type,
|
| 109 |
-
const std::string* default_logger_id = nullptr,
|
| 110 |
-
int default_max_vlog_level = -1);
|
| 111 |
-
|
| 112 |
-
/**
|
| 113 |
-
Creates a new logger instance which will use the provided logger_id and default severity and vlog levels.
|
| 114 |
-
@param logger_id The log identifier.
|
| 115 |
-
@returns A new Logger instance that the caller owns.
|
| 116 |
-
*/
|
| 117 |
-
std::unique_ptr<Logger> CreateLogger(const std::string& logger_id);
|
| 118 |
-
|
| 119 |
-
/**
|
| 120 |
-
Creates a new logger instance which will use the provided logger_id, severity and vlog levels.
|
| 121 |
-
@param logger_id The log identifier.
|
| 122 |
-
@param min_severity The minimum severity. Requests to create messages with lower severity will be ignored.
|
| 123 |
-
@param filter_user_data If set to true ignore messages with DataType::USER.
|
| 124 |
-
@param max_vlog_level Maximum level for VLOG messages to be created.
|
| 125 |
-
@returns A new Logger instance that the caller owns.
|
| 126 |
-
*/
|
| 127 |
-
std::unique_ptr<Logger> CreateLogger(const std::string& logger_id,
|
| 128 |
-
Severity min_severity, bool filter_user_data, int max_vlog_level = -1);
|
| 129 |
-
|
| 130 |
-
/**
|
| 131 |
-
Gets the default logger instance if set. Throws if no default logger is currently registered.
|
| 132 |
-
@remarks
|
| 133 |
-
Creating a LoggingManager instance with instance_type == InstanceType::Default registers a default logger.
|
| 134 |
-
Note that the default logger is only valid until the LoggingManager that registered it is destroyed.
|
| 135 |
-
@returns The default logger if available.
|
| 136 |
-
*/
|
| 137 |
-
static const Logger& DefaultLogger();
|
| 138 |
-
|
| 139 |
-
/**
|
| 140 |
-
Return a boolean indicating if the default logger has been initialized
|
| 141 |
-
*/
|
| 142 |
-
static bool HasDefaultLogger() { return nullptr != s_default_logger_; }
|
| 143 |
-
|
| 144 |
-
/**
|
| 145 |
-
Change the minimum severity level for log messages to be output by the default logger.
|
| 146 |
-
@param severity The severity.
|
| 147 |
-
*/
|
| 148 |
-
static void SetDefaultLoggerSeverity(Severity severity);
|
| 149 |
-
|
| 150 |
-
/**
|
| 151 |
-
Change the maximum verbosity level for log messages to be output by the default logger.
|
| 152 |
-
@remarks
|
| 153 |
-
To activate the verbose log, the logger severity must also be set to kVERBOSE.
|
| 154 |
-
@param vlog_level The verbosity level.
|
| 155 |
-
*/
|
| 156 |
-
static void SetDefaultLoggerVerbosity(int vlog_level);
|
| 157 |
-
|
| 158 |
-
/**
|
| 159 |
-
Logs a FATAL level message and creates an exception that can be thrown with error information.
|
| 160 |
-
@param category The log category.
|
| 161 |
-
@param location The location the log message was generated.
|
| 162 |
-
@param format_str The printf format string.
|
| 163 |
-
@param ... The printf arguments.
|
| 164 |
-
@returns An exception carrying the error information, which the caller can throw.
|
| 165 |
-
*/
|
| 166 |
-
static std::exception LogFatalAndCreateException(const char* category,
|
| 167 |
-
const CodeLocation& location,
|
| 168 |
-
const char* format_str, ...);
|
| 169 |
-
|
| 170 |
-
/**
|
| 171 |
-
Logs the message using the provided logger id.
|
| 172 |
-
@param logger_id The log identifier.
|
| 173 |
-
@param message The log message.
|
| 174 |
-
*/
|
| 175 |
-
void Log(const std::string& logger_id, const Capture& message) const;
|
| 176 |
-
|
| 177 |
-
/**
|
| 178 |
-
Sends a Profiling Event Record to the sink.
|
| 179 |
-
@param eventRecord The profiling event record.
|
| 180 |
-
*/
|
| 181 |
-
void SendProfileEvent(profiling::EventRecord& eventRecord) const;
|
| 182 |
-
~LoggingManager();
|
| 183 |
-
|
| 184 |
-
private:
|
| 185 |
-
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(LoggingManager);
|
| 186 |
-
|
| 187 |
-
Timestamp GetTimestamp() const noexcept;
|
| 188 |
-
void CreateDefaultLogger(const std::string& logger_id);
|
| 189 |
-
|
| 190 |
-
std::unique_ptr<ISink> sink_;
|
| 191 |
-
const Severity default_min_severity_;
|
| 192 |
-
const bool default_filter_user_data_;
|
| 193 |
-
const int default_max_vlog_level_;
|
| 194 |
-
bool owns_default_logger_;
|
| 195 |
-
|
| 196 |
-
static Logger* s_default_logger_;
|
| 197 |
-
|
| 198 |
-
struct Epochs {
|
| 199 |
-
const std::chrono::time_point<std::chrono::high_resolution_clock> high_res;
|
| 200 |
-
const std::chrono::time_point<std::chrono::system_clock> system;
|
| 201 |
-
const std::chrono::minutes localtime_offset_from_utc;
|
| 202 |
-
};
|
| 203 |
-
|
| 204 |
-
static const Epochs& GetEpochs() noexcept;
|
| 205 |
-
};
|
| 206 |
-
|
| 207 |
-
/**
|
| 208 |
-
Logger provides a per-instance log id. Everything else is passed back up to the LoggingManager
|
| 209 |
-
*/
|
| 210 |
-
class Logger {
|
| 211 |
-
public:
|
| 212 |
-
/**
|
| 213 |
-
Initializes a new instance of the Logger class.
|
| 214 |
-
@param loggingManager The logging manager.
|
| 215 |
-
@param id The identifier for messages coming from this Logger.
|
| 216 |
-
@param severity Minimum severity for messages to be created and logged.
|
| 217 |
-
@param filter_user_data Should USER data be filtered from output.
|
| 218 |
-
@param vlog_level Maximum level for VLOG messages to be created. Note that a severity of kVERBOSE must be provided
|
| 219 |
-
for VLOG messages to be logged.
|
| 220 |
-
*/
|
| 221 |
-
Logger(const LoggingManager& loggingManager, std::string id,
|
| 222 |
-
Severity severity, bool filter_user_data, int vlog_level)
|
| 223 |
-
: logging_manager_{&loggingManager},
|
| 224 |
-
id_{id},
|
| 225 |
-
min_severity_{severity},
|
| 226 |
-
filter_user_data_{filter_user_data},
|
| 227 |
-
max_vlog_level_{vlog_level} {
|
| 228 |
-
}
|
| 229 |
-
|
| 230 |
-
/**
|
| 231 |
-
Get the minimum severity level for log messages to be output.
|
| 232 |
-
@returns The severity.
|
| 233 |
-
*/
|
| 234 |
-
Severity GetSeverity() const noexcept { return min_severity_; }
|
| 235 |
-
|
| 236 |
-
/**
|
| 237 |
-
Change the minimum severity level for log messages to be output.
|
| 238 |
-
@param severity The severity.
|
| 239 |
-
*/
|
| 240 |
-
void SetSeverity(Severity severity) noexcept { min_severity_ = severity; }
|
| 241 |
-
|
| 242 |
-
/**
|
| 243 |
-
Change the maximum verbosity level for log messages to be output.
|
| 244 |
-
@remarks
|
| 245 |
-
To activate the verbose log, the logger severity must also be set to kVERBOSE.
|
| 246 |
-
@param vlog_level The verbosity.
|
| 247 |
-
*/
|
| 248 |
-
void SetVerbosity(int vlog_level) noexcept { max_vlog_level_ = vlog_level; }
|
| 249 |
-
|
| 250 |
-
/**
|
| 251 |
-
Check if output is enabled for the provided LogSeverity and DataType values.
|
| 252 |
-
@param severity The severity.
|
| 253 |
-
@param data_type Type of the data.
|
| 254 |
-
@returns True if a message with these values will be logged.
|
| 255 |
-
*/
|
| 256 |
-
bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept {
|
| 257 |
-
return (severity >= min_severity_ && (data_type != DataType::USER || !filter_user_data_));
|
| 258 |
-
}
|
| 259 |
-
|
| 260 |
-
/**
|
| 261 |
-
Return the maximum VLOG level allowed, or -1 (disabled) if the minimum severity is above kVERBOSE.
|
| 262 |
-
*/
|
| 263 |
-
int VLOGMaxLevel() const noexcept {
|
| 264 |
-
return min_severity_ > Severity::kVERBOSE ? -1 : max_vlog_level_;
|
| 265 |
-
}
|
| 266 |
-
|
| 267 |
-
/**
|
| 268 |
-
Logs the captured message.
|
| 269 |
-
@param message The log message.
|
| 270 |
-
*/
|
| 271 |
-
void Log(const Capture& message) const {
|
| 272 |
-
logging_manager_->Log(id_, message);
|
| 273 |
-
}
|
| 274 |
-
|
| 275 |
-
/**
|
| 276 |
-
Sends a Profiling Event Record to the sink.
|
| 277 |
-
@param eventRecord The profiling event record.
|
| 278 |
-
*/
|
| 279 |
-
void SendProfileEvent(profiling::EventRecord& eventRecord) const {
|
| 280 |
-
logging_manager_->SendProfileEvent(eventRecord);
|
| 281 |
-
}
|
| 282 |
-
|
| 283 |
-
private:
|
| 284 |
-
const LoggingManager* logging_manager_;
|
| 285 |
-
const std::string id_;
|
| 286 |
-
Severity min_severity_;
|
| 287 |
-
const bool filter_user_data_;
|
| 288 |
-
int max_vlog_level_;
|
| 289 |
-
};
|
| 290 |
-
|
| 291 |
-
inline const Logger& LoggingManager::DefaultLogger() {
|
| 292 |
-
if (s_default_logger_ == nullptr) {
|
| 293 |
-
// fail early for attempted misuse. don't use logging macros as we have no logger.
|
| 294 |
-
ORT_THROW("Attempt to use DefaultLogger but none has been registered.");
|
| 295 |
-
}
|
| 296 |
-
|
| 297 |
-
return *s_default_logger_;
|
| 298 |
-
}
|
| 299 |
-
|
| 300 |
-
inline void LoggingManager::SetDefaultLoggerSeverity(Severity severity) {
|
| 301 |
-
if (s_default_logger_ == nullptr) {
|
| 302 |
-
// fail early for attempted misuse. don't use logging macros as we have no logger.
|
| 303 |
-
ORT_THROW("Attempt to use DefaultLogger but none has been registered.");
|
| 304 |
-
}
|
| 305 |
-
|
| 306 |
-
s_default_logger_->SetSeverity(severity);
|
| 307 |
-
}
|
| 308 |
-
|
| 309 |
-
inline void LoggingManager::SetDefaultLoggerVerbosity(int vlog_level) {
|
| 310 |
-
if (s_default_logger_ == nullptr) {
|
| 311 |
-
// fail early for attempted misuse. don't use logging macros as we have no logger.
|
| 312 |
-
ORT_THROW("Attempt to use DefaultLogger but none has been registered.");
|
| 313 |
-
}
|
| 314 |
-
|
| 315 |
-
s_default_logger_->SetVerbosity(vlog_level);
|
| 316 |
-
}
|
| 317 |
-
|
| 318 |
-
inline Timestamp LoggingManager::GetTimestamp() const noexcept {
|
| 319 |
-
static const Epochs& epochs = GetEpochs();
|
| 320 |
-
|
| 321 |
-
const auto high_res_now = std::chrono::high_resolution_clock::now();
|
| 322 |
-
return std::chrono::time_point_cast<std::chrono::system_clock::duration>(
|
| 323 |
-
epochs.system + (high_res_now - epochs.high_res) + epochs.localtime_offset_from_utc);
|
| 324 |
-
}
|
| 325 |
-
|
| 326 |
-
/**
|
| 327 |
-
Return the current thread id.
|
| 328 |
-
*/
|
| 329 |
-
unsigned int GetThreadId();
|
| 330 |
-
|
| 331 |
-
/**
|
| 332 |
-
Return the current process id.
|
| 333 |
-
*/
|
| 334 |
-
unsigned int GetProcessId();
|
| 335 |
-
|
| 336 |
-
} // namespace logging
|
| 337 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/logging/macros.h
DELETED
|
@@ -1,278 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
// NOTE: Don't include this file directly. Include logging.h
|
| 6 |
-
|
| 7 |
-
#define CREATE_MESSAGE(logger, severity, category, datatype) \
|
| 8 |
-
::onnxruntime::logging::Capture(logger, ::onnxruntime::logging::Severity::k##severity, category, datatype, ORT_WHERE)
|
| 9 |
-
|
| 10 |
-
/*
|
| 11 |
-
Both printf and stream style logging are supported.
|
| 12 |
-
Note that printf currently has a 2K limit to the message size.
|
| 13 |
-
|
| 14 |
-
LOGS_* macros are for stream style
|
| 15 |
-
LOGF_* macros are for printf style
|
| 16 |
-
|
| 17 |
-
The Message class captures the log input, and pushes it through the logger in its destructor.
|
| 18 |
-
|
| 19 |
-
Use the *FATAL* macros if you want a Severity::kFatal message to also throw.
|
| 20 |
-
|
| 21 |
-
There are a few variants to minimize the length of the macro name required in the calling code.
|
| 22 |
-
They are optimized so the shortest names are for the (expected) most common usage. This can be
|
| 23 |
-
tweaked if needed.
|
| 24 |
-
|
| 25 |
-
Explicit logger vs LoggingManager::DefaultLogger()
|
| 26 |
-
Default is for a logger instance to be explicitly passed in.
|
| 27 |
-
The logger instance provides an identifier so that log messages from different runs can be separated.
|
| 28 |
-
|
| 29 |
-
Variants with DEFAULT in the macro name use the default logger provided by logging manager. This is
|
| 30 |
-
static so accessible from any code, provided a LoggingManager instance created with InstanceType::Default
|
| 31 |
-
exists somewhere. See logging.h for further explanation of the expected setup.
|
| 32 |
-
|
| 33 |
-
DataType
|
| 34 |
-
Default uses DataType::SYSTEM.
|
| 35 |
-
|
| 36 |
-
Variants with USER in the macro name use DataType::USER. This is data that could be PII, and may need to
|
| 37 |
-
be filtered from output. LoggingManager applies this filtering.
|
| 38 |
-
|
| 39 |
-
Category
|
| 40 |
-
Default category is ::onnxruntime::logging::Category::onnxruntime.
|
| 41 |
-
|
| 42 |
-
If you wish to provide a different category, use variants with CATEGORY in the macro name
|
| 43 |
-
|
| 44 |
-
*/
|
| 45 |
-
|
| 46 |
-
/**
|
| 47 |
-
* Note:
|
| 48 |
-
* The stream style logging macros (something like `LOGS() << message`) are designed to be appended to.
|
| 49 |
-
* Normally, we can isolate macro code in a separate scope (e.g., `do {...} while(0)`), but here we need the macro code
|
| 50 |
-
* to interact with subsequent code (i.e., the values to log).
|
| 51 |
-
*
|
| 52 |
-
* When an unisolated conditional is involved, extra care needs to be taken to avoid unexpected parsing behavior.
|
| 53 |
-
* For example:
|
| 54 |
-
*
|
| 55 |
-
* if (enabled)
|
| 56 |
-
* Capture().Stream()
|
| 57 |
-
*
|
| 58 |
-
* is more direct, but
|
| 59 |
-
*
|
| 60 |
-
* if (!enabled) {
|
| 61 |
-
* } else Capture().Stream()
|
| 62 |
-
*
|
| 63 |
-
* ensures that the `if` does not unintentionally associate with a subsequent `else`.
|
| 64 |
-
*/
|
| 65 |
-
|
| 66 |
-
// Logging with explicit category
|
| 67 |
-
|
| 68 |
-
// iostream style logging. Capture log info in Message, and push to the logger in ~Message.
|
| 69 |
-
#define LOGS_CATEGORY(logger, severity, category) \
|
| 70 |
-
if (!(logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \
|
| 71 |
-
::onnxruntime::logging::DataType::SYSTEM)) { \
|
| 72 |
-
/* do nothing */ \
|
| 73 |
-
} else \
|
| 74 |
-
CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM).Stream()
|
| 75 |
-
|
| 76 |
-
#define LOGS_USER_CATEGORY(logger, severity, category) \
|
| 77 |
-
if (!(logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \
|
| 78 |
-
::onnxruntime::logging::DataType::USER)) { \
|
| 79 |
-
/* do nothing */ \
|
| 80 |
-
} else \
|
| 81 |
-
CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::USER).Stream()
|
| 82 |
-
|
| 83 |
-
// printf style logging. Capture log info in Message, and push to the logger in ~Message.
|
| 84 |
-
#define LOGF_CATEGORY(logger, severity, category, format_str, ...) \
|
| 85 |
-
do { \
|
| 86 |
-
if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \
|
| 87 |
-
::onnxruntime::logging::DataType::SYSTEM)) \
|
| 88 |
-
CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM) \
|
| 89 |
-
.CapturePrintf(format_str, ##__VA_ARGS__); \
|
| 90 |
-
} while (0)
|
| 91 |
-
|
| 92 |
-
#define LOGF_USER_CATEGORY(logger, severity, category, format_str, ...) \
|
| 93 |
-
do { \
|
| 94 |
-
if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, \
|
| 95 |
-
::onnxruntime::logging::DataType::USER)) \
|
| 96 |
-
CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::USER) \
|
| 97 |
-
.CapturePrintf(format_str, ##__VA_ARGS__); \
|
| 98 |
-
} while (0)
|
| 99 |
-
|
| 100 |
-
// Logging with category of "onnxruntime"
|
| 101 |
-
|
| 102 |
-
#define LOGS(logger, severity) \
|
| 103 |
-
LOGS_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime)
|
| 104 |
-
|
| 105 |
-
#define LOGS_USER(logger, severity) \
|
| 106 |
-
LOGS_USER_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime)
|
| 107 |
-
|
| 108 |
-
// printf style logging. Capture log info in Message, and push to the logger in ~Message.
|
| 109 |
-
#define LOGF(logger, severity, format_str, ...) \
|
| 110 |
-
LOGF_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
|
| 111 |
-
|
| 112 |
-
#define LOGF_USER(logger, severity, format_str, ...) \
|
| 113 |
-
LOGF_USER_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
|
| 114 |
-
|
| 115 |
-
/*
|
| 116 |
-
Macros that use the default logger.
|
| 117 |
-
A LoggingManager instance must be currently valid for the default logger to be available.
|
| 118 |
-
*/
|
| 119 |
-
|
| 120 |
-
// Logging with explicit category
|
| 121 |
-
|
| 122 |
-
#define LOGS_DEFAULT_CATEGORY(severity, category) \
|
| 123 |
-
LOGS_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category)
|
| 124 |
-
|
| 125 |
-
#define LOGS_USER_DEFAULT_CATEGORY(severity, category) \
|
| 126 |
-
LOGS_USER_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category)
|
| 127 |
-
|
| 128 |
-
#define LOGF_DEFAULT_CATEGORY(severity, category, format_str, ...) \
|
| 129 |
-
LOGF_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category, format_str, ##__VA_ARGS__)
|
| 130 |
-
|
| 131 |
-
#define LOGF_USER_DEFAULT_CATEGORY(severity, category, format_str, ...) \
|
| 132 |
-
LOGF_USER_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category, format_str, ##__VA_ARGS__)
|
| 133 |
-
|
| 134 |
-
// Logging with category of "onnxruntime"
|
| 135 |
-
|
| 136 |
-
#define LOGS_DEFAULT(severity) \
|
| 137 |
-
LOGS_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime)
|
| 138 |
-
|
| 139 |
-
#define LOGS_USER_DEFAULT(severity) \
|
| 140 |
-
LOGS_USER_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime)
|
| 141 |
-
|
| 142 |
-
#define LOGF_DEFAULT(severity, format_str, ...) \
|
| 143 |
-
LOGF_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
|
| 144 |
-
|
| 145 |
-
#define LOGF_USER_DEFAULT(severity, format_str, ...) \
|
| 146 |
-
LOGF_USER_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
|
| 147 |
-
|
| 148 |
-
/*
|
| 149 |
-
Conditional logging
|
| 150 |
-
*/
|
| 151 |
-
|
| 152 |
-
// Logging with explicit category
|
| 153 |
-
|
| 154 |
-
#define LOGS_CATEGORY_IF(boolean_expression, logger, severity, category) \
|
| 155 |
-
if (!((boolean_expression) == true)) { \
|
| 156 |
-
/* do nothing */ \
|
| 157 |
-
} else \
|
| 158 |
-
LOGS_CATEGORY(logger, severity, category)
|
| 159 |
-
|
| 160 |
-
#define LOGS_DEFAULT_CATEGORY_IF(boolean_expression, severity, category) \
|
| 161 |
-
if (!((boolean_expression) == true)) { \
|
| 162 |
-
/* do nothing */ \
|
| 163 |
-
} else \
|
| 164 |
-
LOGS_DEFAULT_CATEGORY(severity, category)
|
| 165 |
-
|
| 166 |
-
#define LOGS_USER_CATEGORY_IF(boolean_expression, logger, severity, category) \
|
| 167 |
-
if (!((boolean_expression) == true)) { \
|
| 168 |
-
/* do nothing */ \
|
| 169 |
-
} else \
|
| 170 |
-
LOGS_USER_CATEGORY(logger, severity, category)
|
| 171 |
-
|
| 172 |
-
#define LOGS_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, category) \
|
| 173 |
-
if (!((boolean_expression) == true)) { \
|
| 174 |
-
/* do nothing */ \
|
| 175 |
-
} else \
|
| 176 |
-
LOGS_USER_DEFAULT_CATEGORY(severity, category)
|
| 177 |
-
|
| 178 |
-
#define LOGF_CATEGORY_IF(boolean_expression, logger, severity, category, format_str, ...) \
|
| 179 |
-
do { \
|
| 180 |
-
if ((boolean_expression) == true) LOGF_CATEGORY(logger, severity, category, format_str, ##__VA_ARGS__); \
|
| 181 |
-
} while (0)
|
| 182 |
-
|
| 183 |
-
#define LOGF_DEFAULT_CATEGORY_IF(boolean_expression, severity, category, format_str, ...) \
|
| 184 |
-
do { \
|
| 185 |
-
if ((boolean_expression) == true) LOGF_DEFAULT_CATEGORY(severity, category, format_str, ##__VA_ARGS__); \
|
| 186 |
-
} while (0)
|
| 187 |
-
|
| 188 |
-
#define LOGF_USER_CATEGORY_IF(boolean_expression, logger, severity, category, format_str, ...) \
|
| 189 |
-
do { \
|
| 190 |
-
if ((boolean_expression) == true) LOGF_USER_CATEGORY(logger, severity, category, format_str, ##__VA_ARGS__); \
|
| 191 |
-
} while (0)
|
| 192 |
-
|
| 193 |
-
#define LOGF_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, category, format_str, ...) \
|
| 194 |
-
do { \
|
| 195 |
-
if ((boolean_expression) == true) LOGF_USER_DEFAULT_CATEGORY(severity, category, format_str, ##__VA_ARGS__); \
|
| 196 |
-
} while (0)
|
| 197 |
-
|
| 198 |
-
// Logging with category of "onnxruntime"
|
| 199 |
-
|
| 200 |
-
#define LOGS_IF(boolean_expression, logger, severity) \
|
| 201 |
-
LOGS_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime)
|
| 202 |
-
|
| 203 |
-
#define LOGS_DEFAULT_IF(boolean_expression, severity) \
|
| 204 |
-
LOGS_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime)
|
| 205 |
-
|
| 206 |
-
#define LOGS_USER_IF(boolean_expression, logger, severity) \
|
| 207 |
-
LOGS_USER_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime)
|
| 208 |
-
|
| 209 |
-
#define LOGS_USER_DEFAULT_IF(boolean_expression, severity) \
|
| 210 |
-
LOGS_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime)
|
| 211 |
-
|
| 212 |
-
#define LOGF_IF(boolean_expression, logger, severity, format_str, ...) \
|
| 213 |
-
LOGF_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
|
| 214 |
-
|
| 215 |
-
#define LOGF_DEFAULT_IF(boolean_expression, severity, format_str, ...) \
|
| 216 |
-
LOGF_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__)
|
| 217 |
-
|
| 218 |
-
#define LOGF_USER_IF(boolean_expression, logger, severity, format_str, ...) \
|
| 219 |
-
LOGF_USER_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime, \
|
| 220 |
-
format_str, ##__VA_ARGS__)
|
| 221 |
-
|
| 222 |
-
#define LOGF_USER_DEFAULT_IF(boolean_expression, severity, format_str, ...) \
|
| 223 |
-
LOGF_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime, \
|
| 224 |
-
format_str, ##__VA_ARGS__)
|
| 225 |
-
|
| 226 |
-
/*
|
| 227 |
-
Debug verbose logging of caller provided level.
|
| 228 |
-
Disabled in Release builds.
|
| 229 |
-
Use the _USER variants for VLOG statements involving user data that may need to be filtered.
|
| 230 |
-
*/
|
| 231 |
-
#ifndef NDEBUG
|
| 232 |
-
#define VLOGS(logger, level) \
|
| 233 |
-
if (!(::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel())) { \
|
| 234 |
-
/* do nothing */ \
|
| 235 |
-
} else \
|
| 236 |
-
LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level)
|
| 237 |
-
|
| 238 |
-
#define VLOGS_USER(logger, level) \
|
| 239 |
-
if (!(::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel())) { \
|
| 240 |
-
/* do nothing */ \
|
| 241 |
-
} else \
|
| 242 |
-
LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level)
|
| 243 |
-
|
| 244 |
-
#define VLOGF(logger, level, format_str, ...) \
|
| 245 |
-
do { \
|
| 246 |
-
if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \
|
| 247 |
-
LOGF_CATEGORY(logger, VERBOSE, "VLOG" #level, format_str, ##__VA_ARGS__); \
|
| 248 |
-
} while (0)
|
| 249 |
-
|
| 250 |
-
#define VLOGF_USER(logger, level, format_str, ...) \
|
| 251 |
-
do { \
|
| 252 |
-
if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \
|
| 253 |
-
LOGF_USER_CATEGORY(logger, VERBOSE, "VLOG" #level, format_str, ##__VA_ARGS__); \
|
| 254 |
-
} while (0)
|
| 255 |
-
#else
|
| 256 |
-
// Disabled in Release builds.
|
| 257 |
-
#define VLOGS(logger, level) \
|
| 258 |
-
if constexpr (true) {} else LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level)
|
| 259 |
-
#define VLOGS_USER(logger, level) \
|
| 260 |
-
if constexpr (true) {} else LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level)
|
| 261 |
-
#define VLOGF(logger, level, format_str, ...)
|
| 262 |
-
#define VLOGF_USER(logger, level, format_str, ...)
|
| 263 |
-
#endif
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
// Default logger variants
|
| 268 |
-
#define VLOGS_DEFAULT(level) \
|
| 269 |
-
VLOGS(::onnxruntime::logging::LoggingManager::DefaultLogger(), level)
|
| 270 |
-
|
| 271 |
-
#define VLOGS_USER_DEFAULT(level) \
|
| 272 |
-
VLOGS_USER(::onnxruntime::logging::LoggingManager::DefaultLogger(), level)
|
| 273 |
-
|
| 274 |
-
#define VLOGF_DEFAULT(level, format_str, ...) \
|
| 275 |
-
VLOGF(::onnxruntime::logging::LoggingManager::DefaultLogger(), level, format_str, ##__VA_ARGS__)
|
| 276 |
-
|
| 277 |
-
#define VLOGF_USER_DEFAULT(level, format_str, ...) \
|
| 278 |
-
VLOGF_USER(::onnxruntime::logging::LoggingManager::DefaultLogger(), level, format_str, ##__VA_ARGS__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/logging/severity.h
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
namespace onnxruntime {
|
| 7 |
-
namespace logging {
|
| 8 |
-
// mild violation of naming convention. the 'k' lets us use token concatenation in the macro
|
| 9 |
-
// ::onnxruntime::logging::Severity::k##severity. It's not legal to have ::onnxruntime::logging::Severity::##severity
|
| 10 |
-
// the uppercase makes the LOG macro usage look as expected for passing an enum value as it will be LOGS(logger, ERROR)
|
| 11 |
-
enum class Severity {
|
| 12 |
-
kVERBOSE = 0,
|
| 13 |
-
kINFO = 1,
|
| 14 |
-
kWARNING = 2,
|
| 15 |
-
kERROR = 3,
|
| 16 |
-
kFATAL = 4
|
| 17 |
-
};
|
| 18 |
-
|
| 19 |
-
constexpr const char* SEVERITY_PREFIX = "VIWEF";
|
| 20 |
-
|
| 21 |
-
} // namespace logging
|
| 22 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/make_string.h
DELETED
|
@@ -1,126 +0,0 @@
|
|
| 1 |
-
/**
|
| 2 |
-
* Copyright (c) 2016-present, Facebook, Inc.
|
| 3 |
-
*
|
| 4 |
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
* you may not use this file except in compliance with the License.
|
| 6 |
-
* You may obtain a copy of the License at
|
| 7 |
-
*
|
| 8 |
-
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
*
|
| 10 |
-
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
* See the License for the specific language governing permissions and
|
| 14 |
-
* limitations under the License.
|
| 15 |
-
*/
|
| 16 |
-
// Portions Copyright (c) Microsoft Corporation
|
| 17 |
-
|
| 18 |
-
#pragma once
|
| 19 |
-
|
| 20 |
-
#include <locale>
|
| 21 |
-
#include <sstream>
|
| 22 |
-
#include <type_traits>
|
| 23 |
-
|
| 24 |
-
namespace onnxruntime {
|
| 25 |
-
|
| 26 |
-
namespace detail {
|
| 27 |
-
|
| 28 |
-
inline void MakeStringImpl(std::ostringstream& /*ss*/) noexcept {
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
template <typename T>
|
| 32 |
-
inline void MakeStringImpl(std::ostringstream& ss, const T& t) noexcept {
|
| 33 |
-
ss << t;
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
-
template <typename T, typename... Args>
|
| 37 |
-
inline void MakeStringImpl(std::ostringstream& ss, const T& t, const Args&... args) noexcept {
|
| 38 |
-
MakeStringImpl(ss, t);
|
| 39 |
-
MakeStringImpl(ss, args...);
|
| 40 |
-
}
|
| 41 |
-
|
| 42 |
-
// see MakeString comments for explanation of why this is necessary
|
| 43 |
-
template <typename... Args>
|
| 44 |
-
inline std::string MakeStringImpl(const Args&... args) noexcept {
|
| 45 |
-
std::ostringstream ss;
|
| 46 |
-
MakeStringImpl(ss, args...);
|
| 47 |
-
return ss.str();
|
| 48 |
-
}
|
| 49 |
-
|
| 50 |
-
//
|
| 51 |
-
// Infrastructure to convert char[n] to char* to reduce binary size
|
| 52 |
-
//
|
| 53 |
-
|
| 54 |
-
// default is to leave the type as is
|
| 55 |
-
template <class T>
|
| 56 |
-
struct if_char_array_make_ptr {
|
| 57 |
-
using type = T;
|
| 58 |
-
};
|
| 59 |
-
|
| 60 |
-
// specialization that matches an array reference, which is what the char array from a string literal
|
| 61 |
-
// used in a call to MakeString will be.
|
| 62 |
-
// if the type is a char[n] array we 'decay' it to a char* so that the usages can be folded.
|
| 63 |
-
template <class T, size_t N>
|
| 64 |
-
struct if_char_array_make_ptr<T (&)[N]> {
|
| 65 |
-
// remove a single extent (T[x] -> T, but T[x][y] -> T[y]) so we only match char[x],
|
| 66 |
-
// and get the type name without the 'const' so both 'const char (&)[n]' and 'char (&)[n]' are matched.
|
| 67 |
-
using element_type = typename std::remove_const<typename std::remove_extent<T>::type>::type;
|
| 68 |
-
using type = typename std::conditional<std::is_same<char, element_type>::value, T*, T (&)[N]>::type;
|
| 69 |
-
};
|
| 70 |
-
|
| 71 |
-
// helper to make usage simpler in MakeString
|
| 72 |
-
template <class T>
|
| 73 |
-
using if_char_array_make_ptr_t = typename if_char_array_make_ptr<T>::type;
|
| 74 |
-
} // namespace detail
|
| 75 |
-
|
| 76 |
-
/**
|
| 77 |
-
* Makes a string by concatenating string representations of the arguments.
|
| 78 |
-
* This version uses the current locale.
|
| 79 |
-
*/
|
| 80 |
-
template <typename... Args>
|
| 81 |
-
std::string MakeString(const Args&... args) {
|
| 82 |
-
// We need to update the types from the MakeString template instantiation to decay any char[n] to char*.
|
| 83 |
-
// e.g. MakeString("in", "out") goes from MakeString<char[2], char[3]> to MakeStringImpl<char*, char*>
|
| 84 |
-
// so that MakeString("out", "in") will also match MakeStringImpl<char*, char*> instead of requiring
|
| 85 |
-
// MakeStringImpl<char[3], char[2]>.
|
| 86 |
-
//
|
| 87 |
-
// We have to do the type processing before any actual work, so this function purely implements the type processing.
|
| 88 |
-
// If we do not do it this way we do not get the full binary size reduction.
|
| 89 |
-
//
|
| 90 |
-
// See https://stackoverflow.com/a/29418212/684911 for overall details of the approach, but note it does not cover
|
| 91 |
-
// the need to do the type processing as a separate step.
|
| 92 |
-
|
| 93 |
-
return detail::MakeStringImpl(detail::if_char_array_make_ptr_t<Args const&>(args)...);
|
| 94 |
-
}
|
| 95 |
-
|
| 96 |
-
/**
|
| 97 |
-
* Makes a string by concatenating string representations of the arguments.
|
| 98 |
-
* This version uses std::locale::classic().
|
| 99 |
-
*/
|
| 100 |
-
template <typename... Args>
|
| 101 |
-
std::string MakeStringWithClassicLocale(const Args&... args) {
|
| 102 |
-
std::ostringstream ss;
|
| 103 |
-
ss.imbue(std::locale::classic());
|
| 104 |
-
detail::MakeStringImpl(ss, args...);
|
| 105 |
-
return ss.str();
|
| 106 |
-
}
|
| 107 |
-
|
| 108 |
-
// MakeString versions for already-a-string types.
|
| 109 |
-
|
| 110 |
-
inline std::string MakeString(const std::string& str) {
|
| 111 |
-
return str;
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
inline std::string MakeString(const char* cstr) {
|
| 115 |
-
return cstr;
|
| 116 |
-
}
|
| 117 |
-
|
| 118 |
-
inline std::string MakeStringWithClassicLocale(const std::string& str) {
|
| 119 |
-
return str;
|
| 120 |
-
}
|
| 121 |
-
|
| 122 |
-
inline std::string MakeStringWithClassicLocale(const char* cstr) {
|
| 123 |
-
return cstr;
|
| 124 |
-
}
|
| 125 |
-
|
| 126 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/narrow.h
DELETED
|
@@ -1,77 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
// onnxruntime::narrow() is like gsl::narrow() but it is also available when exceptions are disabled.
|
| 7 |
-
|
| 8 |
-
#if !defined(ORT_NO_EXCEPTIONS)
|
| 9 |
-
|
| 10 |
-
#include "gsl/narrow"
|
| 11 |
-
|
| 12 |
-
namespace onnxruntime {
|
| 13 |
-
using gsl::narrow;
|
| 14 |
-
} // namespace onnxruntime
|
| 15 |
-
|
| 16 |
-
#else // ^^ !defined(ORT_NO_EXCEPTIONS) ^^ / vv defined(ORT_NO_EXCEPTIONS) vv
|
| 17 |
-
|
| 18 |
-
#include <cstdio> // std::fprintf
|
| 19 |
-
#include <exception> // std::terminate
|
| 20 |
-
#include <type_traits>
|
| 21 |
-
|
| 22 |
-
#include "gsl/util" // gsl::narrow_cast
|
| 23 |
-
|
| 24 |
-
namespace onnxruntime {
|
| 25 |
-
|
| 26 |
-
namespace detail {
|
| 27 |
-
[[noreturn]] inline void OnNarrowingError() noexcept {
|
| 28 |
-
std::fprintf(stderr, "%s", "narrowing error\n");
|
| 29 |
-
std::terminate();
|
| 30 |
-
}
|
| 31 |
-
} // namespace detail
|
| 32 |
-
|
| 33 |
-
// This implementation of onnxruntime::narrow was copied and adapted from:
|
| 34 |
-
// https://github.com/microsoft/GSL/blob/a3534567187d2edc428efd3f13466ff75fe5805c/include/gsl/narrow
|
| 35 |
-
|
| 36 |
-
// narrow() : a checked version of narrow_cast() that terminates if the cast changed the value
|
| 37 |
-
template <class T, class U, typename std::enable_if<std::is_arithmetic<T>::value>::type* = nullptr>
|
| 38 |
-
// clang-format off
|
| 39 |
-
GSL_SUPPRESS(type.1) // NO-FORMAT: attribute
|
| 40 |
-
// clang-format on
|
| 41 |
-
constexpr T narrow(U u) noexcept {
|
| 42 |
-
constexpr const bool is_different_signedness =
|
| 43 |
-
(std::is_signed<T>::value != std::is_signed<U>::value);
|
| 44 |
-
|
| 45 |
-
// clang-format off
|
| 46 |
-
GSL_SUPPRESS(es.103) // NO-FORMAT: attribute // don't overflow
|
| 47 |
-
GSL_SUPPRESS(es.104) // NO-FORMAT: attribute // don't underflow
|
| 48 |
-
GSL_SUPPRESS(p.2) // NO-FORMAT: attribute // don't rely on undefined behavior
|
| 49 |
-
// clang-format on
|
| 50 |
-
const T t = gsl::narrow_cast<T>(u); // While this is technically undefined behavior in some cases (i.e., if the source value is of floating-point type
|
| 51 |
-
// and cannot fit into the destination integral type), the resultant behavior is benign on the platforms
|
| 52 |
-
// that we target (i.e., no hardware trap representations are hit).
|
| 53 |
-
|
| 54 |
-
if (static_cast<U>(t) != u || (is_different_signedness && ((t < T{}) != (u < U{})))) {
|
| 55 |
-
detail::OnNarrowingError();
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
return t;
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
template <class T, class U, typename std::enable_if<!std::is_arithmetic<T>::value>::type* = nullptr>
|
| 62 |
-
// clang-format off
|
| 63 |
-
GSL_SUPPRESS(type.1) // NO-FORMAT: attribute
|
| 64 |
-
// clang-format on
|
| 65 |
-
constexpr T narrow(U u) noexcept {
|
| 66 |
-
const T t = gsl::narrow_cast<T>(u);
|
| 67 |
-
|
| 68 |
-
if (static_cast<U>(t) != u) {
|
| 69 |
-
detail::OnNarrowingError();
|
| 70 |
-
}
|
| 71 |
-
|
| 72 |
-
return t;
|
| 73 |
-
}
|
| 74 |
-
|
| 75 |
-
} // namespace onnxruntime
|
| 76 |
-
|
| 77 |
-
#endif // defined(ORT_NO_EXCEPTIONS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/optional.h
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
#include <optional>
|
| 6 |
-
|
| 7 |
-
namespace onnxruntime {
|
| 8 |
-
|
| 9 |
-
using std::optional;
|
| 10 |
-
|
| 11 |
-
#ifndef ORT_NO_EXCEPTIONS
|
| 12 |
-
using std::bad_optional_access;
|
| 13 |
-
#endif
|
| 14 |
-
|
| 15 |
-
using std::nullopt;
|
| 16 |
-
using std::nullopt_t;
|
| 17 |
-
|
| 18 |
-
using std::in_place;
|
| 19 |
-
using std::in_place_t;
|
| 20 |
-
|
| 21 |
-
using std::make_optional;
|
| 22 |
-
|
| 23 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/parse_string.h
DELETED
|
@@ -1,85 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <locale>
|
| 7 |
-
#include <sstream>
|
| 8 |
-
#include <string_view>
|
| 9 |
-
#include <type_traits>
|
| 10 |
-
|
| 11 |
-
#include "core/common/common.h"
|
| 12 |
-
|
| 13 |
-
namespace onnxruntime {
|
| 14 |
-
|
| 15 |
-
/**
|
| 16 |
-
* Tries to parse a value from an entire string.
|
| 17 |
-
*/
|
| 18 |
-
template <typename T>
|
| 19 |
-
bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
|
| 20 |
-
if constexpr (std::is_integral<T>::value && std::is_unsigned<T>::value) {
|
| 21 |
-
// if T is unsigned integral type, reject negative values which will wrap
|
| 22 |
-
if (!str.empty() && str[0] == '-') {
|
| 23 |
-
return false;
|
| 24 |
-
}
|
| 25 |
-
}
|
| 26 |
-
|
| 27 |
-
// don't allow leading whitespace
|
| 28 |
-
if (!str.empty() && std::isspace(str[0], std::locale::classic())) {
|
| 29 |
-
return false;
|
| 30 |
-
}
|
| 31 |
-
|
| 32 |
-
std::istringstream is{std::string{str}};
|
| 33 |
-
is.imbue(std::locale::classic());
|
| 34 |
-
T parsed_value{};
|
| 35 |
-
|
| 36 |
-
const bool parse_successful =
|
| 37 |
-
is >> parsed_value &&
|
| 38 |
-
is.get() == std::istringstream::traits_type::eof(); // don't allow trailing characters
|
| 39 |
-
if (!parse_successful) {
|
| 40 |
-
return false;
|
| 41 |
-
}
|
| 42 |
-
|
| 43 |
-
value = std::move(parsed_value);
|
| 44 |
-
return true;
|
| 45 |
-
}
|
| 46 |
-
|
| 47 |
-
inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) {
|
| 48 |
-
value = str;
|
| 49 |
-
return true;
|
| 50 |
-
}
|
| 51 |
-
|
| 52 |
-
inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
|
| 53 |
-
if (str == "0" || str == "False" || str == "false") {
|
| 54 |
-
value = false;
|
| 55 |
-
return true;
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
if (str == "1" || str == "True" || str == "true") {
|
| 59 |
-
value = true;
|
| 60 |
-
return true;
|
| 61 |
-
}
|
| 62 |
-
|
| 63 |
-
return false;
|
| 64 |
-
}
|
| 65 |
-
|
| 66 |
-
/**
|
| 67 |
-
* Parses a value from an entire string.
|
| 68 |
-
*/
|
| 69 |
-
template <typename T>
|
| 70 |
-
Status ParseStringWithClassicLocale(std::string_view s, T& value) {
|
| 71 |
-
ORT_RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\"");
|
| 72 |
-
return Status::OK();
|
| 73 |
-
}
|
| 74 |
-
|
| 75 |
-
/**
|
| 76 |
-
* Parses a value from an entire string.
|
| 77 |
-
*/
|
| 78 |
-
template <typename T>
|
| 79 |
-
T ParseStringWithClassicLocale(std::string_view s) {
|
| 80 |
-
T value{};
|
| 81 |
-
ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(s, value));
|
| 82 |
-
return value;
|
| 83 |
-
}
|
| 84 |
-
|
| 85 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/profiler_common.h
DELETED
|
@@ -1,93 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include "core/common/common.h"
|
| 7 |
-
|
| 8 |
-
#include <string>
|
| 9 |
-
#include <unordered_map>
|
| 10 |
-
|
| 11 |
-
namespace onnxruntime {
|
| 12 |
-
namespace profiling {
|
| 13 |
-
|
| 14 |
-
enum EventCategory {
|
| 15 |
-
SESSION_EVENT = 0,
|
| 16 |
-
NODE_EVENT,
|
| 17 |
-
KERNEL_EVENT,
|
| 18 |
-
API_EVENT,
|
| 19 |
-
EVENT_CATEGORY_MAX
|
| 20 |
-
};
|
| 21 |
-
|
| 22 |
-
// Event descriptions for the above session events.
|
| 23 |
-
static constexpr const char* event_category_names_[EVENT_CATEGORY_MAX] = {
|
| 24 |
-
"Session",
|
| 25 |
-
"Node",
|
| 26 |
-
"Kernel",
|
| 27 |
-
"Api"};
|
| 28 |
-
|
| 29 |
-
// Timing record for all events.
|
| 30 |
-
struct EventRecord {
|
| 31 |
-
EventRecord() = default;
|
| 32 |
-
EventRecord(EventCategory category,
|
| 33 |
-
int process_id,
|
| 34 |
-
int thread_id,
|
| 35 |
-
std::string&& event_name,
|
| 36 |
-
long long time_stamp,
|
| 37 |
-
long long duration,
|
| 38 |
-
std::unordered_map<std::string, std::string>&& event_args)
|
| 39 |
-
: cat(category),
|
| 40 |
-
pid(process_id),
|
| 41 |
-
tid(thread_id),
|
| 42 |
-
name(std::move(event_name)),
|
| 43 |
-
ts(time_stamp),
|
| 44 |
-
dur(duration),
|
| 45 |
-
args(std::move(event_args)) {}
|
| 46 |
-
|
| 47 |
-
EventRecord(EventCategory category,
|
| 48 |
-
int process_id,
|
| 49 |
-
int thread_id,
|
| 50 |
-
const std::string& event_name,
|
| 51 |
-
long long time_stamp,
|
| 52 |
-
long long duration,
|
| 53 |
-
const std::unordered_map<std::string, std::string>& event_args)
|
| 54 |
-
: cat(category),
|
| 55 |
-
pid(process_id),
|
| 56 |
-
tid(thread_id),
|
| 57 |
-
name(event_name),
|
| 58 |
-
ts(time_stamp),
|
| 59 |
-
dur(duration),
|
| 60 |
-
args(event_args) {}
|
| 61 |
-
|
| 62 |
-
EventRecord(const EventRecord& other) = default;
|
| 63 |
-
EventRecord(EventRecord&& other) noexcept = default;
|
| 64 |
-
EventRecord& operator=(const EventRecord& other) = default;
|
| 65 |
-
EventRecord& operator=(EventRecord&& other) = default;
|
| 66 |
-
|
| 67 |
-
EventCategory cat = EventCategory::API_EVENT;
|
| 68 |
-
int pid = -1;
|
| 69 |
-
int tid = -1;
|
| 70 |
-
std::string name{};
|
| 71 |
-
long long ts = 0;
|
| 72 |
-
long long dur = 0;
|
| 73 |
-
std::unordered_map<std::string, std::string> args{};
|
| 74 |
-
};
|
| 75 |
-
|
| 76 |
-
using Events = std::vector<EventRecord>;
|
| 77 |
-
|
| 78 |
-
//Execution Provider Profiler
|
| 79 |
-
class EpProfiler {
|
| 80 |
-
public:
|
| 81 |
-
virtual ~EpProfiler() = default;
|
| 82 |
-
virtual bool StartProfiling(TimePoint profiling_start_time) = 0; // called when profiling starts
|
| 83 |
-
virtual void EndProfiling(TimePoint start_time, Events& events) = 0; // called when profiling ends, save all captures numbers to "events"
|
| 84 |
-
virtual void Start(uint64_t){}; // called before op start, accept an id as argument to identify the op
|
| 85 |
-
virtual void Stop(uint64_t){}; // called after op stop, accept an id as argument to identify the op
|
| 86 |
-
};
|
| 87 |
-
|
| 88 |
-
// Demangle C++ symbols
|
| 89 |
-
std::string demangle(const char* name);
|
| 90 |
-
std::string demangle(const std::string& name);
|
| 91 |
-
|
| 92 |
-
} // namespace profiling
|
| 93 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/span_utils.h
DELETED
|
@@ -1,88 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <algorithm>
|
| 7 |
-
|
| 8 |
-
#include "core/common/gsl.h"
|
| 9 |
-
|
| 10 |
-
namespace onnxruntime {
|
| 11 |
-
|
| 12 |
-
// AsSpan inspired by Fekir's Blog https://fekir.info/post/span-the-missing-constructor/
|
| 13 |
-
// Used under MIT license
|
| 14 |
-
|
| 15 |
-
// Use AsSpan for less typing on any container including initializer list to create a span
|
| 16 |
-
// (unnamed, untyped initializer list does not automatically convert to gsl::span).
|
| 17 |
-
// {1, 2, 3} as such does not have a type
|
| 18 |
-
// (see https://scottmeyers.blogspot.com/2014/03/if-braced-initializers-have-no-type-why.html)
|
| 19 |
-
//
|
| 20 |
-
// Example: AsSpan({1, 2, 3}) results in gsl::span<const int>
|
| 21 |
-
//
|
| 22 |
-
// The above would deduce to std::initializer_list<int> and the result is gsl::span<const int>
|
| 23 |
-
//
|
| 24 |
-
// AsSpan<int64_t>({1, 2, 3}) produces gsl::span<const int64_t>
|
| 25 |
-
//
|
| 26 |
-
// We can also do std::array<int64_t, 3>{1, 2, 3} that can be automatically converted to span
|
| 27 |
-
// without memory allocation.
|
| 28 |
-
//
|
| 29 |
-
// If type conversion is not required, then for C++17 std::array template parameters are
|
| 30 |
-
// auto-deduced. Example: std::array{1, 2, 3}.
|
| 31 |
-
// We are aiming at not allocating memory dynamically.
|
| 32 |
-
|
| 33 |
-
namespace details {
|
| 34 |
-
template <class P>
|
| 35 |
-
constexpr auto AsSpanImpl(P* p, size_t s) {
|
| 36 |
-
return gsl::span<P>(p, s);
|
| 37 |
-
}
|
| 38 |
-
} // namespace details
|
| 39 |
-
|
| 40 |
-
template <class C>
|
| 41 |
-
constexpr auto AsSpan(C& c) {
|
| 42 |
-
return details::AsSpanImpl(c.data(), c.size());
|
| 43 |
-
}
|
| 44 |
-
|
| 45 |
-
template <class C>
|
| 46 |
-
constexpr auto AsSpan(const C& c) {
|
| 47 |
-
return details::AsSpanImpl(c.data(), c.size());
|
| 48 |
-
}
|
| 49 |
-
|
| 50 |
-
template <class C>
|
| 51 |
-
constexpr auto AsSpan(C&& c) {
|
| 52 |
-
return details::AsSpanImpl(c.data(), c.size());
|
| 53 |
-
}
|
| 54 |
-
|
| 55 |
-
template <class T>
|
| 56 |
-
constexpr auto AsSpan(std::initializer_list<T> c) {
|
| 57 |
-
return details::AsSpanImpl(c.begin(), c.size());
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
template <class T, size_t N>
|
| 61 |
-
constexpr auto AsSpan(T (&arr)[N]) {
|
| 62 |
-
return details::AsSpanImpl(arr, N);
|
| 63 |
-
}
|
| 64 |
-
|
| 65 |
-
template <class T, size_t N>
|
| 66 |
-
constexpr auto AsSpan(const T (&arr)[N]) {
|
| 67 |
-
return details::AsSpanImpl(arr, N);
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
template <class T>
|
| 71 |
-
inline gsl::span<const T> EmptySpan() { return gsl::span<const T>(); }
|
| 72 |
-
|
| 73 |
-
template <class U, class T>
|
| 74 |
-
[[nodiscard]] inline gsl::span<U> ReinterpretAsSpan(gsl::span<T> src) {
|
| 75 |
-
// adapted from gsl-lite span::as_span():
|
| 76 |
-
// https://github.com/gsl-lite/gsl-lite/blob/4720a2980a30da085b4ddb4a0ea2a71af7351a48/include/gsl/gsl-lite.hpp#L4102-L4108
|
| 77 |
-
Expects(src.size_bytes() % sizeof(U) == 0);
|
| 78 |
-
return gsl::span<U>(reinterpret_cast<U*>(src.data()), src.size_bytes() / sizeof(U));
|
| 79 |
-
}
|
| 80 |
-
|
| 81 |
-
template <class T1, size_t Extent1, class T2, size_t Extent2>
|
| 82 |
-
[[nodiscard]] inline bool SpanEq(gsl::span<T1, Extent1> a, gsl::span<T2, Extent2> b) {
|
| 83 |
-
static_assert(std::is_same_v<std::remove_const_t<T1>, std::remove_const_t<T2>>,
|
| 84 |
-
"T1 and T2 should be the same type except for const qualification");
|
| 85 |
-
return std::equal(a.begin(), a.end(), b.begin(), b.end());
|
| 86 |
-
}
|
| 87 |
-
|
| 88 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/spin_pause.h
DELETED
|
@@ -1,28 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#if defined(_M_AMD64)
|
| 7 |
-
#include <intrin.h>
|
| 8 |
-
#endif
|
| 9 |
-
|
| 10 |
-
#if defined(__x86_64__)
|
| 11 |
-
#include <xmmintrin.h>
|
| 12 |
-
#endif
|
| 13 |
-
|
| 14 |
-
namespace onnxruntime {
|
| 15 |
-
|
| 16 |
-
namespace concurrency {
|
| 17 |
-
|
| 18 |
-
// Intrinsic to use in spin-loops
|
| 19 |
-
|
| 20 |
-
inline void SpinPause() {
|
| 21 |
-
#if defined(_M_AMD64) || defined(__x86_64__)
|
| 22 |
-
_mm_pause();
|
| 23 |
-
#endif
|
| 24 |
-
}
|
| 25 |
-
|
| 26 |
-
} // namespace concurrency
|
| 27 |
-
|
| 28 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/status.h
DELETED
|
@@ -1,195 +0,0 @@
|
|
| 1 |
-
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
-
Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
-
you may not use this file except in compliance with the License.
|
| 4 |
-
You may obtain a copy of the License at
|
| 5 |
-
http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
-
Unless required by applicable law or agreed to in writing, software
|
| 7 |
-
distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
-
See the License for the specific language governing permissions and
|
| 10 |
-
limitations under the License.
|
| 11 |
-
==============================================================================*/
|
| 12 |
-
// Modifications Copyright (c) Microsoft.
|
| 13 |
-
|
| 14 |
-
#pragma once
|
| 15 |
-
|
| 16 |
-
#include <memory>
|
| 17 |
-
#include <ostream>
|
| 18 |
-
#include <string>
|
| 19 |
-
#ifdef _WIN32
|
| 20 |
-
#include <winerror.h>
|
| 21 |
-
#endif
|
| 22 |
-
#include "core/common/gsl.h"
|
| 23 |
-
namespace onnxruntime {
|
| 24 |
-
namespace common {
|
| 25 |
-
|
| 26 |
-
enum StatusCategory {
|
| 27 |
-
NONE = 0,
|
| 28 |
-
SYSTEM = 1,
|
| 29 |
-
ONNXRUNTIME = 2,
|
| 30 |
-
};
|
| 31 |
-
|
| 32 |
-
/**
|
| 33 |
-
Error code for ONNXRuntime.
|
| 34 |
-
*/
|
| 35 |
-
enum StatusCode {
|
| 36 |
-
OK = 0,
|
| 37 |
-
FAIL = 1,
|
| 38 |
-
INVALID_ARGUMENT = 2,
|
| 39 |
-
NO_SUCHFILE = 3,
|
| 40 |
-
NO_MODEL = 4,
|
| 41 |
-
ENGINE_ERROR = 5,
|
| 42 |
-
RUNTIME_EXCEPTION = 6,
|
| 43 |
-
INVALID_PROTOBUF = 7,
|
| 44 |
-
MODEL_LOADED = 8,
|
| 45 |
-
NOT_IMPLEMENTED = 9,
|
| 46 |
-
INVALID_GRAPH = 10,
|
| 47 |
-
EP_FAIL = 11
|
| 48 |
-
};
|
| 49 |
-
|
| 50 |
-
constexpr const char* StatusCodeToString(StatusCode status) noexcept {
|
| 51 |
-
switch (status) {
|
| 52 |
-
case StatusCode::OK:
|
| 53 |
-
return "SUCCESS";
|
| 54 |
-
case StatusCode::FAIL:
|
| 55 |
-
return "FAIL";
|
| 56 |
-
case StatusCode::INVALID_ARGUMENT:
|
| 57 |
-
return "INVALID_ARGUMENT";
|
| 58 |
-
case StatusCode::NO_SUCHFILE:
|
| 59 |
-
return "NO_SUCHFILE";
|
| 60 |
-
case StatusCode::NO_MODEL:
|
| 61 |
-
return "NO_MODEL";
|
| 62 |
-
case StatusCode::ENGINE_ERROR:
|
| 63 |
-
return "ENGINE_ERROR";
|
| 64 |
-
case StatusCode::RUNTIME_EXCEPTION:
|
| 65 |
-
return "RUNTIME_EXCEPTION";
|
| 66 |
-
case StatusCode::INVALID_PROTOBUF:
|
| 67 |
-
return "INVALID_PROTOBUF";
|
| 68 |
-
case StatusCode::MODEL_LOADED:
|
| 69 |
-
return "MODEL_LOADED";
|
| 70 |
-
case StatusCode::NOT_IMPLEMENTED:
|
| 71 |
-
return "NOT_IMPLEMENTED";
|
| 72 |
-
case StatusCode::INVALID_GRAPH:
|
| 73 |
-
return "INVALID_GRAPH";
|
| 74 |
-
case StatusCode::EP_FAIL:
|
| 75 |
-
return "EP_FAIL";
|
| 76 |
-
default:
|
| 77 |
-
return "GENERAL ERROR";
|
| 78 |
-
}
|
| 79 |
-
}
|
| 80 |
-
|
| 81 |
-
#ifdef _WIN32
|
| 82 |
-
constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept {
|
| 83 |
-
switch (status) {
|
| 84 |
-
case StatusCode::OK:
|
| 85 |
-
return S_OK;
|
| 86 |
-
case StatusCode::FAIL:
|
| 87 |
-
return E_FAIL;
|
| 88 |
-
case StatusCode::INVALID_ARGUMENT:
|
| 89 |
-
return E_INVALIDARG;
|
| 90 |
-
case StatusCode::NO_SUCHFILE:
|
| 91 |
-
return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
|
| 92 |
-
case StatusCode::NO_MODEL:
|
| 93 |
-
return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
|
| 94 |
-
case StatusCode::ENGINE_ERROR:
|
| 95 |
-
return E_FAIL;
|
| 96 |
-
case StatusCode::RUNTIME_EXCEPTION:
|
| 97 |
-
return E_FAIL;
|
| 98 |
-
case StatusCode::INVALID_PROTOBUF:
|
| 99 |
-
return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
|
| 100 |
-
case StatusCode::MODEL_LOADED:
|
| 101 |
-
return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
|
| 102 |
-
case StatusCode::NOT_IMPLEMENTED:
|
| 103 |
-
return E_NOTIMPL;
|
| 104 |
-
case StatusCode::INVALID_GRAPH:
|
| 105 |
-
return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
|
| 106 |
-
case StatusCode::EP_FAIL:
|
| 107 |
-
return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
|
| 108 |
-
default:
|
| 109 |
-
return E_FAIL;
|
| 110 |
-
}
|
| 111 |
-
}
|
| 112 |
-
#endif
|
| 113 |
-
|
| 114 |
-
class [[nodiscard]] Status {
|
| 115 |
-
public:
|
| 116 |
-
Status() noexcept = default;
|
| 117 |
-
|
| 118 |
-
Status(StatusCategory category, int code, const std::string& msg);
|
| 119 |
-
|
| 120 |
-
Status(StatusCategory category, int code, const char* msg);
|
| 121 |
-
|
| 122 |
-
Status(StatusCategory category, int code);
|
| 123 |
-
|
| 124 |
-
GSL_SUPPRESS(r.11)
|
| 125 |
-
Status(const Status& other)
|
| 126 |
-
: state_((other.state_ == nullptr) ? nullptr : new State(*other.state_)) {}
|
| 127 |
-
GSL_SUPPRESS(r.11)
|
| 128 |
-
Status& operator=(const Status& other) {
|
| 129 |
-
if (state_ != other.state_) {
|
| 130 |
-
if (other.state_ == nullptr) {
|
| 131 |
-
state_.reset();
|
| 132 |
-
} else {
|
| 133 |
-
state_.reset(new State(*other.state_));
|
| 134 |
-
}
|
| 135 |
-
}
|
| 136 |
-
return *this;
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
-
Status(Status&&) = default;
|
| 140 |
-
Status& operator=(Status&&) = default;
|
| 141 |
-
~Status() = default;
|
| 142 |
-
|
| 143 |
-
bool IsOK() const {
|
| 144 |
-
return (state_ == nullptr);
|
| 145 |
-
}
|
| 146 |
-
|
| 147 |
-
int Code() const noexcept;
|
| 148 |
-
|
| 149 |
-
StatusCategory Category() const noexcept;
|
| 150 |
-
|
| 151 |
-
const std::string& ErrorMessage() const noexcept;
|
| 152 |
-
|
| 153 |
-
std::string ToString() const;
|
| 154 |
-
|
| 155 |
-
bool operator==(const Status& other) const {
|
| 156 |
-
return (this->state_ == other.state_) || (ToString() == other.ToString());
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
bool operator!=(const Status& other) const {
|
| 160 |
-
return !(*this == other);
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
static Status OK() {
|
| 164 |
-
return Status();
|
| 165 |
-
}
|
| 166 |
-
|
| 167 |
-
private:
|
| 168 |
-
static const std::string& EmptyString() noexcept;
|
| 169 |
-
|
| 170 |
-
struct State {
|
| 171 |
-
State(StatusCategory cat0, int code0, const std::string& msg0)
|
| 172 |
-
: category(cat0), code(code0), msg(msg0) {}
|
| 173 |
-
|
| 174 |
-
State(StatusCategory cat0, int code0, const char* msg0)
|
| 175 |
-
: category(cat0), code(code0), msg(msg0) {}
|
| 176 |
-
|
| 177 |
-
const StatusCategory category;
|
| 178 |
-
const int code;
|
| 179 |
-
const std::string msg;
|
| 180 |
-
};
|
| 181 |
-
|
| 182 |
-
// As long as Code() is OK, state_ == nullptr.
|
| 183 |
-
std::unique_ptr<State> state_;
|
| 184 |
-
};
|
| 185 |
-
|
| 186 |
-
inline std::ostream& operator<<(std::ostream& out, const Status& status) {
|
| 187 |
-
return out << status.ToString();
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
} // namespace common
|
| 191 |
-
|
| 192 |
-
// make Status directly available in the onnxruntime namespace as it is widely used
|
| 193 |
-
using common::Status;
|
| 194 |
-
|
| 195 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/common/string_helper.h
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
#include <string>
|
| 6 |
-
|
| 7 |
-
// forward declaration
|
| 8 |
-
struct OrtAllocator;
|
| 9 |
-
namespace onnxruntime {
|
| 10 |
-
char* StrDup(const std::string& str, OrtAllocator* allocator);
|
| 11 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/alloc_kind.h
DELETED
|
@@ -1,36 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
#include <iosfwd>
|
| 6 |
-
|
| 7 |
-
namespace onnxruntime {
|
| 8 |
-
// The ml-Values fall into the following categories with respect to their
|
| 9 |
-
// memory management:
|
| 10 |
-
// - inference inputs: owned (allocated and freed) by caller, and is by
|
| 11 |
-
// default read-only by the runtime.
|
| 12 |
-
// - inference outputs: allocated by runtime, ownership transferred to
|
| 13 |
-
// caller. TODO: Make sure this semantics is clear in InferenceSession API.
|
| 14 |
-
// - weights (constant tensors): can be allocated once (statically), and
|
| 15 |
-
// reused by all inference calls within an InferenceSession.
|
| 16 |
-
// - tensor values: The lifetimes of these tensor-values are statically
|
| 17 |
-
// determined, which is used for memory reuse/sharing optimizations. The
|
| 18 |
-
// runtime allocates/frees these values at the right time (as determined
|
| 19 |
-
// by the static allocation plan). Note that this is simplified since we
|
| 20 |
-
// do not try to optimize for "slice" like ops, where we may be able to
|
| 21 |
-
// conditionally reuse memory/data in some cases but not others.
|
| 22 |
-
// Generalizing this is future work.
|
| 23 |
-
|
| 24 |
-
enum class AllocKind {
|
| 25 |
-
kNotSet = -1,
|
| 26 |
-
kAllocate = 0,
|
| 27 |
-
kReuse = 1,
|
| 28 |
-
kPreExisting = 2,
|
| 29 |
-
kAllocateStatically = 3,
|
| 30 |
-
kAllocateOutput = 4,
|
| 31 |
-
kShare = 5,
|
| 32 |
-
kAllocatedExternally = 6
|
| 33 |
-
};
|
| 34 |
-
|
| 35 |
-
std::ostream& operator<<(std::ostream& out, AllocKind alloc_kind);
|
| 36 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/allocator.h
DELETED
|
@@ -1,194 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include "core/common/common.h"
|
| 7 |
-
#include "core/framework/allocator_stats.h"
|
| 8 |
-
#include "core/session/onnxruntime_c_api.h"
|
| 9 |
-
#include "ortdevice.h"
|
| 10 |
-
#include "ortmemoryinfo.h"
|
| 11 |
-
|
| 12 |
-
// This configures the arena based allocator used by ORT
|
| 13 |
-
// See docs/C_API.md for details on what these mean and how to choose these values
|
| 14 |
-
struct OrtArenaCfg {
|
| 15 |
-
OrtArenaCfg() : max_mem(0),
|
| 16 |
-
arena_extend_strategy(-1),
|
| 17 |
-
initial_chunk_size_bytes(-1),
|
| 18 |
-
max_dead_bytes_per_chunk(-1),
|
| 19 |
-
initial_growth_chunk_size_bytes(-1) {}
|
| 20 |
-
OrtArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes,
|
| 21 |
-
int max_dead_bytes_per_chunk, int initial_growth_chunk_size_bytes)
|
| 22 |
-
: max_mem(max_mem),
|
| 23 |
-
arena_extend_strategy(arena_extend_strategy),
|
| 24 |
-
initial_chunk_size_bytes(initial_chunk_size_bytes),
|
| 25 |
-
max_dead_bytes_per_chunk(max_dead_bytes_per_chunk),
|
| 26 |
-
initial_growth_chunk_size_bytes(initial_growth_chunk_size_bytes) {}
|
| 27 |
-
|
| 28 |
-
size_t max_mem; // use 0 to allow ORT to choose the default
|
| 29 |
-
int arena_extend_strategy; // use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
|
| 30 |
-
int initial_chunk_size_bytes; // use -1 to allow ORT to choose the default
|
| 31 |
-
int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default
|
| 32 |
-
int initial_growth_chunk_size_bytes; // use -1 to allow ORT to choose the default
|
| 33 |
-
};
|
| 34 |
-
|
| 35 |
-
namespace onnxruntime {
|
| 36 |
-
constexpr const char* CPU = "Cpu";
|
| 37 |
-
constexpr const char* CUDA = "Cuda";
|
| 38 |
-
constexpr const char* CUDA_PINNED = "CudaPinned";
|
| 39 |
-
constexpr const char* CANN = "Cann";
|
| 40 |
-
constexpr const char* CANN_PINNED = "CannPinned";
|
| 41 |
-
constexpr const char* DML = "DML";
|
| 42 |
-
constexpr const char* HIP = "Hip";
|
| 43 |
-
constexpr const char* HIP_PINNED = "HipPinned";
|
| 44 |
-
constexpr const char* OpenVINO_CPU = "OpenVINO_CPU";
|
| 45 |
-
constexpr const char* OpenVINO_GPU = "OpenVINO_GPU";
|
| 46 |
-
|
| 47 |
-
constexpr size_t kAllocAlignment = 256;
|
| 48 |
-
|
| 49 |
-
class IAllocator;
|
| 50 |
-
class Stream;
|
| 51 |
-
namespace synchronize {
|
| 52 |
-
class Notification;
|
| 53 |
-
}
|
| 54 |
-
using WaitNotificationFn = std::function<void(Stream&, synchronize::Notification&)>;
|
| 55 |
-
void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn);
|
| 56 |
-
|
| 57 |
-
template <typename T>
|
| 58 |
-
using IAllocatorUniquePtr = std::unique_ptr<T, std::function<void(T*)>>;
|
| 59 |
-
|
| 60 |
-
class IAllocator {
|
| 61 |
-
public:
|
| 62 |
-
IAllocator(const OrtMemoryInfo& info) : memory_info_(info) {}
|
| 63 |
-
virtual ~IAllocator() = default;
|
| 64 |
-
/**
|
| 65 |
-
@remarks Use SafeInt when calculating the size of memory to allocate using Alloc.
|
| 66 |
-
*/
|
| 67 |
-
virtual void* Alloc(size_t size) = 0;
|
| 68 |
-
|
| 69 |
-
virtual void Free(void* p) = 0;
|
| 70 |
-
|
| 71 |
-
// TODO: Find a better name than Reserve() and update in all places.
|
| 72 |
-
// Reserve() is an interface exposed for an implementation of IAllocator
|
| 73 |
-
// to optionally implement some allocation logic that by-passes any arena-based
|
| 74 |
-
// logic that may be housed in the Alloc() implementation.
|
| 75 |
-
// There are SessionOptions config(s) that allow users to allocate some memory
|
| 76 |
-
// by-passing arena-based logic.
|
| 77 |
-
// By default, the base implementation just calls Alloc().
|
| 78 |
-
virtual void* Reserve(size_t size) { return Alloc(size); }
|
| 79 |
-
|
| 80 |
-
const OrtMemoryInfo& Info() const { return memory_info_; };
|
| 81 |
-
|
| 82 |
-
// Each implementation of IAllocator can override and provide their own implementation
|
| 83 |
-
virtual void GetStats(AllocatorStats* /*stats*/) { return; }
|
| 84 |
-
|
| 85 |
-
static bool CalcMemSizeForArray(size_t nmemb, size_t size, size_t* out) noexcept {
|
| 86 |
-
return CalcMemSizeForArrayWithAlignment(nmemb, size, 0, out);
|
| 87 |
-
}
|
| 88 |
-
|
| 89 |
-
/**
|
| 90 |
-
* Calculate the memory size for an array. The size is bounds checked using SafeInt.
|
| 91 |
-
* \tparam alignment must be power of 2
|
| 92 |
-
* \param nmemb Number of members or elements in the array
|
| 93 |
-
* \param size Size of each element
|
| 94 |
-
* \param out Total size required after any alignment is applied
|
| 95 |
-
* \return true, successful. false, overflow
|
| 96 |
-
*/
|
| 97 |
-
[[nodiscard]] static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept;
|
| 98 |
-
|
| 99 |
-
/**
|
| 100 |
-
* https://cwe.mitre.org/data/definitions/190.html
|
| 101 |
-
* \param alignment must be power of 2
|
| 102 |
-
* \param nmemb Number of members or elements in the array
|
| 103 |
-
* \param size Size of each element
|
| 104 |
-
* \param out Total size required after any alignment is applied
|
| 105 |
-
* \return true, successful. false, overflow
|
| 106 |
-
* \remarks This was the original API and was implemented in the header. Replaced with the above version
|
| 107 |
-
* implemented in the .cc file so that the SafeInt dependency is internal.
|
| 108 |
-
*/
|
| 109 |
-
template <size_t alignment>
|
| 110 |
-
[[nodiscard]] static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept;
|
| 111 |
-
|
| 112 |
-
/**
|
| 113 |
-
* allocate memory for an array which has nmemb items of data, each size bytes long
|
| 114 |
-
*/
|
| 115 |
-
void* AllocArray(size_t nmemb, size_t size) {
|
| 116 |
-
size_t len;
|
| 117 |
-
if (!CalcMemSizeForArray(nmemb, size, &len))
|
| 118 |
-
return nullptr;
|
| 119 |
-
return Alloc(len);
|
| 120 |
-
}
|
| 121 |
-
|
| 122 |
-
/**
|
| 123 |
-
* allocate memory for an array which has nmemb items of data, each size bytes long
|
| 124 |
-
*/
|
| 125 |
-
template <size_t alignment>
|
| 126 |
-
void* AllocArrayWithAlignment(size_t nmemb, size_t size) {
|
| 127 |
-
size_t len;
|
| 128 |
-
if (!CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, &len))
|
| 129 |
-
return nullptr;
|
| 130 |
-
return Alloc(len);
|
| 131 |
-
}
|
| 132 |
-
|
| 133 |
-
/**
|
| 134 |
-
Create a std::unique_ptr that is allocated and freed by the provided IAllocator.
|
| 135 |
-
@param allocator The allocator.
|
| 136 |
-
@param count_or_bytes The exact bytes to allocate if T is void, otherwise the number of elements to allocate.
|
| 137 |
-
@param use_reserve If true, call Reserve() instead of Alloc() to allocate memory.
|
| 138 |
-
@param stream Which stream instance allocated chunk will be used with.
|
| 139 |
-
@param wait_fn If the allocator want to dynamic reuse a chunk from another stream, use this wait_fn to sync on
|
| 140 |
-
the target stream to make the reuse safe.
|
| 141 |
-
@returns std::unique_ptr with allocated memory and deleter.
|
| 142 |
-
*/
|
| 143 |
-
template <typename T>
|
| 144 |
-
static IAllocatorUniquePtr<T> MakeUniquePtr(std::shared_ptr<IAllocator> allocator, size_t count_or_bytes,
|
| 145 |
-
bool use_reserve = false,
|
| 146 |
-
Stream* stream = nullptr, WaitNotificationFn wait_fn = nullptr) {
|
| 147 |
-
if (allocator == nullptr) return nullptr;
|
| 148 |
-
// for now limit to fundamental types. we could support others, but to do so either we or the caller
|
| 149 |
-
// needs to call the dtor for the objects, for buffers allocated on device we don't have destructor
|
| 150 |
-
// static_assert(std::is_fundamental<T>::value, "Fundamental type required as no destructors are called.");
|
| 151 |
-
|
| 152 |
-
size_t alloc_size = count_or_bytes;
|
| 153 |
-
|
| 154 |
-
// if T is not void, 'count_or_bytes' == number of items so allow for that
|
| 155 |
-
if constexpr (!std::is_void<T>::value) {
|
| 156 |
-
// sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't
|
| 157 |
-
// reachable if T is void. use std::conditional to 'use' void* in the sizeof call
|
| 158 |
-
if (!CalcMemSizeForArray(
|
| 159 |
-
count_or_bytes, sizeof(typename std::conditional<std::is_void<T>::value, void*, T>::type), &alloc_size)) {
|
| 160 |
-
return nullptr;
|
| 161 |
-
}
|
| 162 |
-
}
|
| 163 |
-
|
| 164 |
-
// allocate
|
| 165 |
-
T* p = static_cast<T*>(AllocateBufferWithOptions(*allocator, alloc_size, use_reserve, stream, std::move(wait_fn)));
|
| 166 |
-
return IAllocatorUniquePtr<T>{
|
| 167 |
-
p,
|
| 168 |
-
[allocator = std::move(allocator)](T* p) { allocator->Free(p); }};
|
| 169 |
-
}
|
| 170 |
-
|
| 171 |
-
private:
|
| 172 |
-
OrtMemoryInfo memory_info_;
|
| 173 |
-
};
|
| 174 |
-
|
| 175 |
-
template <size_t alignment>
|
| 176 |
-
bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept {
|
| 177 |
-
return CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, out);
|
| 178 |
-
}
|
| 179 |
-
|
| 180 |
-
class CPUAllocator : public IAllocator {
|
| 181 |
-
public:
|
| 182 |
-
explicit CPUAllocator(const OrtMemoryInfo& memory_info) : IAllocator(memory_info) {}
|
| 183 |
-
|
| 184 |
-
CPUAllocator() : IAllocator(OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)) {}
|
| 185 |
-
|
| 186 |
-
void* Alloc(size_t size) override;
|
| 187 |
-
void Free(void* p) override;
|
| 188 |
-
};
|
| 189 |
-
|
| 190 |
-
using AllocatorPtr = std::shared_ptr<IAllocator>;
|
| 191 |
-
|
| 192 |
-
void* AllocatorDefaultAlloc(size_t size);
|
| 193 |
-
void AllocatorDefaultFree(void* p);
|
| 194 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/buffer_deleter.h
DELETED
|
@@ -1,36 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include "core/framework/allocator.h"
|
| 7 |
-
|
| 8 |
-
namespace onnxruntime {
|
| 9 |
-
|
| 10 |
-
// TODO: Do we need this class or is IAllocator::MakeUniquePtr sufficient/better
|
| 11 |
-
class BufferDeleter {
|
| 12 |
-
public:
|
| 13 |
-
BufferDeleter() = default;
|
| 14 |
-
explicit BufferDeleter(AllocatorPtr alloc)
|
| 15 |
-
: alloc_(std::move(alloc)) {}
|
| 16 |
-
|
| 17 |
-
void operator()(void* p) const {
|
| 18 |
-
if (alloc_)
|
| 19 |
-
alloc_->Free(p);
|
| 20 |
-
}
|
| 21 |
-
|
| 22 |
-
private:
|
| 23 |
-
// TODO: we may need consider the lifetime of alloc carefully
|
| 24 |
-
// The alloc_ here is the allocator that used to allocate the buffer
|
| 25 |
-
// And need go with the unique_ptr together. If it is using our internal
|
| 26 |
-
// allocator, it is ok as our allocators are global managed. But if it
|
| 27 |
-
// is provide by user, user need to be very careful about it.
|
| 28 |
-
// A weak_ptr may be a choice to reduce the impact, but that require to
|
| 29 |
-
// change our current allocator mgr to use shared_ptr. Will revisit it
|
| 30 |
-
// later.
|
| 31 |
-
AllocatorPtr alloc_{nullptr};
|
| 32 |
-
};
|
| 33 |
-
|
| 34 |
-
using BufferUniquePtr = std::unique_ptr<void, BufferDeleter>;
|
| 35 |
-
using BufferNakedPtr = void*;
|
| 36 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/customregistry.h
DELETED
|
@@ -1,60 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include "core/common/status.h"
|
| 7 |
-
#include "core/common/logging/logging.h"
|
| 8 |
-
#include "core/framework/op_kernel.h"
|
| 9 |
-
#include "core/framework/kernel_def_builder.h"
|
| 10 |
-
#include "core/framework/kernel_registry.h"
|
| 11 |
-
|
| 12 |
-
#if !defined(ORT_MINIMAL_BUILD)
|
| 13 |
-
#include "core/graph/schema_registry.h"
|
| 14 |
-
#endif
|
| 15 |
-
|
| 16 |
-
namespace onnxruntime {
|
| 17 |
-
|
| 18 |
-
/**
|
| 19 |
-
Represents a registry that contains both custom kernels and custom schemas.
|
| 20 |
-
*/
|
| 21 |
-
class CustomRegistry final {
|
| 22 |
-
public:
|
| 23 |
-
CustomRegistry()
|
| 24 |
-
: kernel_registry_(std::make_shared<KernelRegistry>())
|
| 25 |
-
#if !defined(ORT_MINIMAL_BUILD)
|
| 26 |
-
,
|
| 27 |
-
opschema_registry_(std::make_shared<onnxruntime::OnnxRuntimeOpSchemaRegistry>())
|
| 28 |
-
#endif
|
| 29 |
-
{
|
| 30 |
-
}
|
| 31 |
-
|
| 32 |
-
/**
|
| 33 |
-
* Register a kernel definition together with kernel factory method to this session.
|
| 34 |
-
* If any conflict happened between registered kernel def and built-in kernel def,
|
| 35 |
-
* registered kernel will have higher priority.
|
| 36 |
-
* Call this before invoking Initialize().
|
| 37 |
-
* @return OK if success.
|
| 38 |
-
*/
|
| 39 |
-
common::Status RegisterCustomKernel(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator);
|
| 40 |
-
|
| 41 |
-
common::Status RegisterCustomKernel(KernelCreateInfo&);
|
| 42 |
-
|
| 43 |
-
const std::shared_ptr<KernelRegistry>& GetKernelRegistry();
|
| 44 |
-
|
| 45 |
-
#if !defined(ORT_MINIMAL_BUILD)
|
| 46 |
-
common::Status RegisterOpSet(std::vector<ONNX_NAMESPACE::OpSchema>& schemas, const std::string& domain,
|
| 47 |
-
int baseline_opset_version, int opset_version);
|
| 48 |
-
|
| 49 |
-
const std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry>& GetOpschemaRegistry();
|
| 50 |
-
#endif
|
| 51 |
-
|
| 52 |
-
private:
|
| 53 |
-
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomRegistry);
|
| 54 |
-
std::shared_ptr<KernelRegistry> kernel_registry_;
|
| 55 |
-
#if !defined(ORT_MINIMAL_BUILD)
|
| 56 |
-
std::shared_ptr<onnxruntime::OnnxRuntimeOpSchemaRegistry> opschema_registry_;
|
| 57 |
-
#endif
|
| 58 |
-
};
|
| 59 |
-
|
| 60 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/data_types.h
DELETED
|
@@ -1,1062 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <cstdint>
|
| 7 |
-
#include <cstring>
|
| 8 |
-
#include <string>
|
| 9 |
-
#include <type_traits>
|
| 10 |
-
#include <map>
|
| 11 |
-
#include <unordered_map>
|
| 12 |
-
#include "core/common/gsl.h"
|
| 13 |
-
#include "core/common/common.h"
|
| 14 |
-
#include "core/common/exceptions.h"
|
| 15 |
-
#include "core/framework/endian.h"
|
| 16 |
-
#include "core/framework/float16.h"
|
| 17 |
-
#include "core/framework/to_tensor_proto_element_type.h"
|
| 18 |
-
#if !defined(ORT_MINIMAL_BUILD)
|
| 19 |
-
#include "onnx/defs/schema.h"
|
| 20 |
-
#else
|
| 21 |
-
#include "onnx/defs/data_type_utils.h"
|
| 22 |
-
#endif
|
| 23 |
-
#include "onnx/onnx_pb.h"
|
| 24 |
-
#include "onnx/onnx-operators_pb.h"
|
| 25 |
-
|
| 26 |
-
struct OrtValue;
|
| 27 |
-
|
| 28 |
-
namespace ONNX_NAMESPACE {
|
| 29 |
-
class TypeProto;
|
| 30 |
-
} // namespace ONNX_NAMESPACE
|
| 31 |
-
|
| 32 |
-
namespace onnxruntime {
|
| 33 |
-
/// Predefined registered types
|
| 34 |
-
|
| 35 |
-
#if !defined(DISABLE_ML_OPS)
|
| 36 |
-
|
| 37 |
-
// maps (only used by ML ops)
|
| 38 |
-
using MapStringToString = std::map<std::string, std::string>;
|
| 39 |
-
using MapStringToInt64 = std::map<std::string, int64_t>;
|
| 40 |
-
using MapStringToFloat = std::map<std::string, float>;
|
| 41 |
-
using MapStringToDouble = std::map<std::string, double>;
|
| 42 |
-
using MapInt64ToString = std::map<int64_t, std::string>;
|
| 43 |
-
using MapInt64ToInt64 = std::map<int64_t, int64_t>;
|
| 44 |
-
using MapInt64ToFloat = std::map<int64_t, float>;
|
| 45 |
-
using MapInt64ToDouble = std::map<int64_t, double>;
|
| 46 |
-
|
| 47 |
-
// vectors/sequences
|
| 48 |
-
using VectorMapStringToFloat = std::vector<MapStringToFloat>;
|
| 49 |
-
using VectorMapInt64ToFloat = std::vector<MapInt64ToFloat>;
|
| 50 |
-
|
| 51 |
-
#endif
|
| 52 |
-
|
| 53 |
-
using VectorString = std::vector<std::string>;
|
| 54 |
-
using VectorInt64 = std::vector<int64_t>;
|
| 55 |
-
|
| 56 |
-
// Forward declarations
|
| 57 |
-
class DataTypeImpl;
|
| 58 |
-
class TensorTypeBase;
|
| 59 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 60 |
-
class SparseTensorTypeBase;
|
| 61 |
-
#endif
|
| 62 |
-
class SequenceTensorTypeBase;
|
| 63 |
-
class NonTensorTypeBase;
|
| 64 |
-
#if !defined(DISABLE_OPTIONAL_TYPE)
|
| 65 |
-
class OptionalTypeBase;
|
| 66 |
-
#endif
|
| 67 |
-
class PrimitiveDataTypeBase;
|
| 68 |
-
class Tensor;
|
| 69 |
-
class TensorSeq;
|
| 70 |
-
|
| 71 |
-
// DataTypeImpl pointer as unique DataTypeImpl identifier.
|
| 72 |
-
using MLDataType = const DataTypeImpl*;
|
| 73 |
-
// be used with class MLValue
|
| 74 |
-
using DeleteFunc = void (*)(void*);
|
| 75 |
-
using CreateFunc = void* (*)();
|
| 76 |
-
|
| 77 |
-
/**
|
| 78 |
-
* \brief Base class for MLDataType
|
| 79 |
-
*
|
| 80 |
-
*/
|
| 81 |
-
class DataTypeImpl {
|
| 82 |
-
public:
|
| 83 |
-
enum class GeneralType {
|
| 84 |
-
kInvalid = 0,
|
| 85 |
-
kNonTensor = 1,
|
| 86 |
-
kTensor = 2,
|
| 87 |
-
kTensorSequence = 3,
|
| 88 |
-
kSparseTensor = 4,
|
| 89 |
-
kOptional = 5,
|
| 90 |
-
kPrimitive = 6,
|
| 91 |
-
};
|
| 92 |
-
|
| 93 |
-
const GeneralType type_;
|
| 94 |
-
const size_t size_;
|
| 95 |
-
|
| 96 |
-
protected:
|
| 97 |
-
DataTypeImpl(GeneralType type, size_t size) : type_{type}, size_{size} {}
|
| 98 |
-
|
| 99 |
-
public:
|
| 100 |
-
virtual ~DataTypeImpl() = default;
|
| 101 |
-
|
| 102 |
-
/**
|
| 103 |
-
* \brief this API will be used to check type compatibility at runtime
|
| 104 |
-
*
|
| 105 |
-
* \param type_proto a TypeProto instance that is constructed for a specific type
|
| 106 |
-
* will be checked against a TypeProto instance contained within a corresponding
|
| 107 |
-
* MLDataType instance.
|
| 108 |
-
*/
|
| 109 |
-
virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const = 0;
|
| 110 |
-
|
| 111 |
-
size_t Size() const { return size_; }
|
| 112 |
-
|
| 113 |
-
virtual DeleteFunc GetDeleteFunc() const = 0;
|
| 114 |
-
|
| 115 |
-
/**
|
| 116 |
-
* \brief Retrieves an instance of TypeProto for
|
| 117 |
-
* a given MLDataType
|
| 118 |
-
* \returns optional TypeProto. Only ONNX types
|
| 119 |
-
has type proto, non-ONNX types will return nullptr.
|
| 120 |
-
*/
|
| 121 |
-
virtual const ONNX_NAMESPACE::TypeProto* GetTypeProto() const = 0;
|
| 122 |
-
|
| 123 |
-
bool IsTensorType() const {
|
| 124 |
-
return type_ == GeneralType::kTensor;
|
| 125 |
-
}
|
| 126 |
-
|
| 127 |
-
bool IsTensorSequenceType() const {
|
| 128 |
-
return type_ == GeneralType::kTensorSequence;
|
| 129 |
-
}
|
| 130 |
-
|
| 131 |
-
bool IsSparseTensorType() const {
|
| 132 |
-
return type_ == GeneralType::kSparseTensor;
|
| 133 |
-
}
|
| 134 |
-
|
| 135 |
-
bool IsOptionalType() const {
|
| 136 |
-
return type_ == GeneralType::kOptional;
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
-
bool IsNonTensorType() const {
|
| 140 |
-
return type_ == GeneralType::kNonTensor;
|
| 141 |
-
}
|
| 142 |
-
|
| 143 |
-
bool IsPrimitiveDataType() const {
|
| 144 |
-
return type_ == GeneralType::kPrimitive;
|
| 145 |
-
}
|
| 146 |
-
|
| 147 |
-
// Returns this if this is of tensor-type and null otherwise
|
| 148 |
-
const TensorTypeBase* AsTensorType() const;
|
| 149 |
-
|
| 150 |
-
const SequenceTensorTypeBase* AsSequenceTensorType() const;
|
| 151 |
-
|
| 152 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 153 |
-
// Returns this if this is of sparse-tensor-type and null otherwise
|
| 154 |
-
const SparseTensorTypeBase* AsSparseTensorType() const;
|
| 155 |
-
#endif
|
| 156 |
-
|
| 157 |
-
#if !defined(DISABLE_OPTIONAL_TYPE)
|
| 158 |
-
const OptionalTypeBase* AsOptionalType() const;
|
| 159 |
-
#endif
|
| 160 |
-
|
| 161 |
-
const NonTensorTypeBase* AsNonTensorType() const;
|
| 162 |
-
|
| 163 |
-
// Returns this if this is one of the primitive data types (specialization of PrimitiveDataTypeBase)
|
| 164 |
-
// and null otherwise
|
| 165 |
-
const PrimitiveDataTypeBase* AsPrimitiveDataType() const;
|
| 166 |
-
|
| 167 |
-
// Return the type meta that we are using in the runtime.
|
| 168 |
-
template <typename T>
|
| 169 |
-
static MLDataType GetType();
|
| 170 |
-
|
| 171 |
-
// Return the types for a concrete tensor type, like Tensor_Float
|
| 172 |
-
template <typename elemT>
|
| 173 |
-
static MLDataType GetTensorType();
|
| 174 |
-
|
| 175 |
-
template <typename elemT>
|
| 176 |
-
static MLDataType GetSequenceTensorType();
|
| 177 |
-
|
| 178 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 179 |
-
// Return the MLDataType for a concrete sparse tensor type.
|
| 180 |
-
template <typename elemT>
|
| 181 |
-
static MLDataType GetSparseTensorType();
|
| 182 |
-
#endif
|
| 183 |
-
|
| 184 |
-
template <typename T, typename elemT>
|
| 185 |
-
static MLDataType GetOptionalType();
|
| 186 |
-
|
| 187 |
-
/**
|
| 188 |
-
* Convert an ONNX TypeProto to onnxruntime DataTypeImpl.
|
| 189 |
-
* However, this conversion is lossy. Don't try to use 'this->GetTypeProto()' converting it back.
|
| 190 |
-
* Even though GetTypeProto() will not have the original information, it will still have enough to correctly
|
| 191 |
-
* map to MLDataType.
|
| 192 |
-
* \param proto
|
| 193 |
-
*/
|
| 194 |
-
static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto);
|
| 195 |
-
|
| 196 |
-
static const TensorTypeBase* TensorTypeFromONNXEnum(int type);
|
| 197 |
-
static const SequenceTensorTypeBase* SequenceTensorTypeFromONNXEnum(int type);
|
| 198 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 199 |
-
static const SparseTensorTypeBase* SparseTensorTypeFromONNXEnum(int type);
|
| 200 |
-
#endif
|
| 201 |
-
|
| 202 |
-
static const char* ToString(MLDataType type);
|
| 203 |
-
static std::vector<std::string> ToString(const std::vector<MLDataType>& types);
|
| 204 |
-
// Registers ONNX_NAMESPACE::DataType (internalized string) with
|
| 205 |
-
// MLDataType. DataType is produced by internalizing an instance of
|
| 206 |
-
// TypeProto contained within MLDataType
|
| 207 |
-
static void RegisterDataType(MLDataType);
|
| 208 |
-
static MLDataType GetDataType(const std::string&);
|
| 209 |
-
|
| 210 |
-
static const std::vector<MLDataType>& AllTensorTypes();
|
| 211 |
-
static const std::vector<MLDataType>& AllFixedSizeTensorTypes();
|
| 212 |
-
static const std::vector<MLDataType>& AllSequenceTensorTypes();
|
| 213 |
-
static const std::vector<MLDataType>& AllFixedSizeSequenceTensorTypes();
|
| 214 |
-
static const std::vector<MLDataType>& AllNumericTensorTypes();
|
| 215 |
-
static const std::vector<MLDataType>& AllIEEEFloatTensorTypes();
|
| 216 |
-
static const std::vector<MLDataType>& AllFixedSizeTensorExceptHalfTypes();
|
| 217 |
-
static const std::vector<MLDataType>& AllIEEEFloatTensorExceptHalfTypes();
|
| 218 |
-
static const std::vector<MLDataType>& AllTensorAndSequenceTensorTypes();
|
| 219 |
-
static const std::vector<MLDataType>& AllFixedSizeTensorAndSequenceTensorTypes();
|
| 220 |
-
static const std::vector<MLDataType>& AllOptionalTypes();
|
| 221 |
-
static const std::vector<MLDataType>& AllTensorAndSequenceTensorAndOptionalTypes();
|
| 222 |
-
};
|
| 223 |
-
|
| 224 |
-
std::ostream& operator<<(std::ostream& out, MLDataType data_type);
|
| 225 |
-
|
| 226 |
-
/*
|
| 227 |
-
* Type registration helpers
|
| 228 |
-
*/
|
| 229 |
-
namespace data_types_internal {
|
| 230 |
-
/// TensorType helpers
|
| 231 |
-
///
|
| 232 |
-
|
| 233 |
-
/// Is a given type on the list of types?
|
| 234 |
-
/// Accepts a list of types and the first argument is the type
|
| 235 |
-
/// We are checking if it is listed among those that follow
|
| 236 |
-
template <typename T, typename... Types>
|
| 237 |
-
struct IsAnyOf;
|
| 238 |
-
|
| 239 |
-
/// Two types remaining, end of the list
|
| 240 |
-
template <typename T, typename Tail>
|
| 241 |
-
struct IsAnyOf<T, Tail> : public std::is_same<T, Tail> {
|
| 242 |
-
};
|
| 243 |
-
|
| 244 |
-
template <typename T, typename H, typename... Tail>
|
| 245 |
-
struct IsAnyOf<T, H, Tail...> {
|
| 246 |
-
static constexpr bool value = (std::is_same<T, H>::value ||
|
| 247 |
-
IsAnyOf<T, Tail...>::value);
|
| 248 |
-
};
|
| 249 |
-
|
| 250 |
-
/// Tells if the specified type is one of fundamental types
|
| 251 |
-
/// that can be contained within a tensor.
|
| 252 |
-
/// We do not have raw fundamental types, rather a subset
|
| 253 |
-
/// of fundamental types is contained within tensors.
|
| 254 |
-
template <typename T>
|
| 255 |
-
struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
|
| 256 |
-
int32_t, int64_t, std::string, bool, MLFloat16,
|
| 257 |
-
double, uint32_t, uint64_t, BFloat16> {
|
| 258 |
-
};
|
| 259 |
-
|
| 260 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 261 |
-
/// Use "IsSparseTensorContainedType<T>::value" to test if a type T
|
| 262 |
-
/// is permitted as the element-type of a sparse-tensor.
|
| 263 |
-
|
| 264 |
-
template <typename T>
|
| 265 |
-
struct IsSparseTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
|
| 266 |
-
int32_t, int64_t, std::string, bool, MLFloat16,
|
| 267 |
-
double, uint32_t, uint64_t, BFloat16> {
|
| 268 |
-
};
|
| 269 |
-
#endif
|
| 270 |
-
|
| 271 |
-
#if !defined(DISABLE_OPTIONAL_TYPE)
|
| 272 |
-
/// Tells if the specified type is one of ORT types
|
| 273 |
-
/// that can be contained within an optional struct.
|
| 274 |
-
template <typename T>
|
| 275 |
-
struct IsOptionalOrtType : public IsAnyOf<T, Tensor, TensorSeq> {
|
| 276 |
-
};
|
| 277 |
-
#endif
|
| 278 |
-
|
| 279 |
-
/// This template's Get() returns a corresponding MLDataType
|
| 280 |
-
/// It dispatches the call to either GetTensorType<>() or
|
| 281 |
-
/// GetType<>()
|
| 282 |
-
template <typename T, bool TensorContainedType>
|
| 283 |
-
struct GetMLDataType;
|
| 284 |
-
|
| 285 |
-
template <typename T>
|
| 286 |
-
struct GetMLDataType<T, true> {
|
| 287 |
-
static MLDataType Get() {
|
| 288 |
-
return DataTypeImpl::GetTensorType<T>();
|
| 289 |
-
}
|
| 290 |
-
};
|
| 291 |
-
|
| 292 |
-
template <typename T>
|
| 293 |
-
struct GetMLDataType<T, false> {
|
| 294 |
-
static MLDataType Get() {
|
| 295 |
-
return DataTypeImpl::GetType<T>();
|
| 296 |
-
}
|
| 297 |
-
};
|
| 298 |
-
|
| 299 |
-
struct TensorTypeHelper {
|
| 300 |
-
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
|
| 301 |
-
ONNX_NAMESPACE::TypeProto& proto) {
|
| 302 |
-
proto.mutable_tensor_type()->set_elem_type(element_type);
|
| 303 |
-
}
|
| 304 |
-
};
|
| 305 |
-
|
| 306 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 307 |
-
struct SparseTensorTypeHelper {
|
| 308 |
-
static void Set(ONNX_NAMESPACE::TensorProto_DataType element_type,
|
| 309 |
-
ONNX_NAMESPACE::TypeProto& proto) {
|
| 310 |
-
proto.mutable_sparse_tensor_type()->set_elem_type(element_type);
|
| 311 |
-
}
|
| 312 |
-
};
|
| 313 |
-
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
| 314 |
-
|
| 315 |
-
#if !defined(DISABLE_ML_OPS)
|
| 316 |
-
/// Map helpers
|
| 317 |
-
|
| 318 |
-
void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto&,
|
| 319 |
-
ONNX_NAMESPACE::TypeProto&);
|
| 320 |
-
|
| 321 |
-
struct MapTypeHelper {
|
| 322 |
-
// V can be either a primitive type (in which case it is a tensor)
|
| 323 |
-
// or other preregistered types
|
| 324 |
-
template <typename V>
|
| 325 |
-
static MLDataType GetValueType() {
|
| 326 |
-
return GetMLDataType<V, IsTensorContainedType<V>::value>::Get();
|
| 327 |
-
}
|
| 328 |
-
|
| 329 |
-
static void Set(ONNX_NAMESPACE::TensorProto_DataType key_type, const ONNX_NAMESPACE::TypeProto* value_proto,
|
| 330 |
-
ONNX_NAMESPACE::TypeProto& proto) {
|
| 331 |
-
ORT_ENFORCE(value_proto != nullptr, "expected a registered ONNX type");
|
| 332 |
-
proto.mutable_map_type()->set_key_type(key_type);
|
| 333 |
-
CopyMutableMapValue(*value_proto, proto);
|
| 334 |
-
}
|
| 335 |
-
};
|
| 336 |
-
#endif
|
| 337 |
-
|
| 338 |
-
/// Sequence helpers
|
| 339 |
-
|
| 340 |
-
// Element type is a primitive type so we set it to a tensor<elemT>
|
| 341 |
-
void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto&,
|
| 342 |
-
ONNX_NAMESPACE::TypeProto&);
|
| 343 |
-
|
| 344 |
-
// helper to create TypeProto with minimal binary size impact
|
| 345 |
-
struct SequenceTypeHelper {
|
| 346 |
-
template <typename T>
|
| 347 |
-
static MLDataType GetElemType() {
|
| 348 |
-
return GetMLDataType<T, IsTensorContainedType<T>::value>::Get();
|
| 349 |
-
}
|
| 350 |
-
|
| 351 |
-
static void Set(const ONNX_NAMESPACE::TypeProto* elem_proto,
|
| 352 |
-
ONNX_NAMESPACE::TypeProto& proto) {
|
| 353 |
-
ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
|
| 354 |
-
CopyMutableSeqElement(*elem_proto, proto);
|
| 355 |
-
}
|
| 356 |
-
};
|
| 357 |
-
|
| 358 |
-
/// Optional helpers
|
| 359 |
-
|
| 360 |
-
void CopyMutableOptionalElement(const ONNX_NAMESPACE::TypeProto&,
|
| 361 |
-
ONNX_NAMESPACE::TypeProto&);
|
| 362 |
-
|
| 363 |
-
// helper to create TypeProto with minimal binary size impact
|
| 364 |
-
struct OptionalTypeHelper {
|
| 365 |
-
template <typename T, typename elemT>
|
| 366 |
-
static MLDataType GetElemType() {
|
| 367 |
-
if constexpr (std::is_same<T, Tensor>::value) {
|
| 368 |
-
return DataTypeImpl::GetTensorType<elemT>();
|
| 369 |
-
} else {
|
| 370 |
-
static_assert(std::is_same<T, TensorSeq>::value, "Unsupported element type for optional type");
|
| 371 |
-
return DataTypeImpl::GetSequenceTensorType<elemT>();
|
| 372 |
-
}
|
| 373 |
-
}
|
| 374 |
-
|
| 375 |
-
static void Set(const onnx::TypeProto* elem_proto, ONNX_NAMESPACE::TypeProto& proto) {
|
| 376 |
-
ORT_ENFORCE(elem_proto != nullptr, "expected a registered ONNX type");
|
| 377 |
-
CopyMutableOptionalElement(*elem_proto, proto);
|
| 378 |
-
}
|
| 379 |
-
};
|
| 380 |
-
|
| 381 |
-
/// OpaqueTypes helpers
|
| 382 |
-
|
| 383 |
-
void AssignOpaqueDomainName(const char* domain, const char* name,
|
| 384 |
-
ONNX_NAMESPACE::TypeProto& proto);
|
| 385 |
-
|
| 386 |
-
} // namespace data_types_internal
|
| 387 |
-
|
| 388 |
-
//The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
|
| 389 |
-
//However, we do not allocate this type on heap.
|
| 390 |
-
#if defined(_MSC_VER) && !defined(__clang__)
|
| 391 |
-
#pragma warning(push)
|
| 392 |
-
#pragma warning(disable : 26436)
|
| 393 |
-
#endif
|
| 394 |
-
/// All tensors base
|
| 395 |
-
class TensorTypeBase : public DataTypeImpl {
|
| 396 |
-
public:
|
| 397 |
-
static MLDataType Type();
|
| 398 |
-
|
| 399 |
-
/// We first compare type_proto pointers and then
|
| 400 |
-
/// if they do not match try to account for the case
|
| 401 |
-
/// where TypeProto was created ad-hoc and not queried from MLDataType
|
| 402 |
-
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
|
| 403 |
-
|
| 404 |
-
DeleteFunc GetDeleteFunc() const override;
|
| 405 |
-
|
| 406 |
-
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
|
| 407 |
-
|
| 408 |
-
virtual MLDataType GetElementType() const {
|
| 409 |
-
// should never reach here.
|
| 410 |
-
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
|
| 411 |
-
}
|
| 412 |
-
|
| 413 |
-
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorTypeBase);
|
| 414 |
-
|
| 415 |
-
protected:
|
| 416 |
-
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
|
| 417 |
-
|
| 418 |
-
TensorTypeBase();
|
| 419 |
-
~TensorTypeBase() override;
|
| 420 |
-
|
| 421 |
-
private:
|
| 422 |
-
struct Impl;
|
| 423 |
-
Impl* impl_;
|
| 424 |
-
};
|
| 425 |
-
|
| 426 |
-
/**
|
| 427 |
-
* \brief Tensor type. This type does not have a C++ type associated with
|
| 428 |
-
* it at registration time except the element type. One of the types mentioned
|
| 429 |
-
* above at IsTensorContainedType<> list is acceptable.
|
| 430 |
-
*
|
| 431 |
-
* \details
|
| 432 |
-
* Usage:
|
| 433 |
-
* ORT_REGISTER_TENSOR(ELEMENT_TYPE)
|
| 434 |
-
* Currently all of the Tensors irrespective of the dimensions are mapped to Tensor<type>
|
| 435 |
-
* type. IsCompatible() currently ignores shape.
|
| 436 |
-
*/
|
| 437 |
-
|
| 438 |
-
template <typename elemT>
|
| 439 |
-
class TensorType : public TensorTypeBase {
|
| 440 |
-
public:
|
| 441 |
-
static_assert(data_types_internal::IsTensorContainedType<elemT>::value,
|
| 442 |
-
"Requires one of the tensor fundamental types");
|
| 443 |
-
|
| 444 |
-
static MLDataType Type();
|
| 445 |
-
|
| 446 |
-
/// Tensors only can contain basic data types
|
| 447 |
-
/// that have been previously registered with ONNXRuntime
|
| 448 |
-
MLDataType GetElementType() const override {
|
| 449 |
-
return DataTypeImpl::GetType<elemT>();
|
| 450 |
-
}
|
| 451 |
-
|
| 452 |
-
private:
|
| 453 |
-
TensorType() {
|
| 454 |
-
using namespace data_types_internal;
|
| 455 |
-
TensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(), MutableTypeProto());
|
| 456 |
-
}
|
| 457 |
-
};
|
| 458 |
-
|
| 459 |
-
#if defined(DISABLE_OPTIONAL_TYPE)
|
| 460 |
-
|
| 461 |
-
// TODO is this still needed after removing kernel def hashes?
|
| 462 |
-
/// Common base-class for all disabled types. We need DataTypeImpl::ToString to work in a minimal build
|
| 463 |
-
/// with disabled types to keep the ORT format model kernel hashes stable.
|
| 464 |
-
class DisabledTypeBase : public DataTypeImpl {
|
| 465 |
-
public:
|
| 466 |
-
static MLDataType Type();
|
| 467 |
-
|
| 468 |
-
bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
|
| 469 |
-
// We always want to return false for the IsCompatible() for a disabled type
|
| 470 |
-
// because this will ensure that no kernel supporting the disabled type will
|
| 471 |
-
// be matched to a model node requiring that type and the model load will
|
| 472 |
-
// result in failure.
|
| 473 |
-
return false;
|
| 474 |
-
}
|
| 475 |
-
|
| 476 |
-
DeleteFunc GetDeleteFunc() const override {
|
| 477 |
-
ORT_THROW("Type is disabled in this build.");
|
| 478 |
-
}
|
| 479 |
-
|
| 480 |
-
// This must work
|
| 481 |
-
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
|
| 482 |
-
|
| 483 |
-
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DisabledTypeBase);
|
| 484 |
-
|
| 485 |
-
protected:
|
| 486 |
-
// This must work
|
| 487 |
-
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
|
| 488 |
-
|
| 489 |
-
DisabledTypeBase(DataTypeImpl::GeneralType type, size_t size);
|
| 490 |
-
~DisabledTypeBase() override;
|
| 491 |
-
|
| 492 |
-
private:
|
| 493 |
-
struct Impl;
|
| 494 |
-
Impl* impl_;
|
| 495 |
-
};
|
| 496 |
-
|
| 497 |
-
#endif
|
| 498 |
-
|
| 499 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 500 |
-
/// Common base-class for all sparse-tensors (with different element types).
|
| 501 |
-
class SparseTensorTypeBase : public DataTypeImpl {
|
| 502 |
-
public:
|
| 503 |
-
static MLDataType Type();
|
| 504 |
-
|
| 505 |
-
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
|
| 506 |
-
|
| 507 |
-
DeleteFunc GetDeleteFunc() const override;
|
| 508 |
-
|
| 509 |
-
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
|
| 510 |
-
|
| 511 |
-
virtual MLDataType GetElementType() const {
|
| 512 |
-
// should never reach here.
|
| 513 |
-
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
|
| 514 |
-
}
|
| 515 |
-
|
| 516 |
-
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SparseTensorTypeBase);
|
| 517 |
-
|
| 518 |
-
protected:
|
| 519 |
-
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
|
| 520 |
-
|
| 521 |
-
SparseTensorTypeBase();
|
| 522 |
-
~SparseTensorTypeBase() override;
|
| 523 |
-
|
| 524 |
-
private:
|
| 525 |
-
struct Impl;
|
| 526 |
-
Impl* impl_;
|
| 527 |
-
};
|
| 528 |
-
|
| 529 |
-
template <typename elemT>
|
| 530 |
-
class SparseTensorType : public SparseTensorTypeBase {
|
| 531 |
-
public:
|
| 532 |
-
static_assert(data_types_internal::IsSparseTensorContainedType<elemT>::value,
|
| 533 |
-
"Requires one of the sparse-tensor fundamental types");
|
| 534 |
-
|
| 535 |
-
static MLDataType Type();
|
| 536 |
-
|
| 537 |
-
/// Return a MLDataType representing the element-type
|
| 538 |
-
MLDataType GetElementType() const override {
|
| 539 |
-
return DataTypeImpl::GetType<elemT>();
|
| 540 |
-
}
|
| 541 |
-
|
| 542 |
-
private:
|
| 543 |
-
SparseTensorType() {
|
| 544 |
-
using namespace data_types_internal;
|
| 545 |
-
SparseTensorTypeHelper::Set(utils::ToTensorProtoElementType<elemT>(), MutableTypeProto());
|
| 546 |
-
}
|
| 547 |
-
};
|
| 548 |
-
|
| 549 |
-
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
| 550 |
-
|
| 551 |
-
/// Common base-class for all optional types.
|
| 552 |
-
|
| 553 |
-
#if !defined(DISABLE_OPTIONAL_TYPE)
|
| 554 |
-
class OptionalTypeBase : public DataTypeImpl {
|
| 555 |
-
public:
|
| 556 |
-
static MLDataType Type();
|
| 557 |
-
|
| 558 |
-
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
|
| 559 |
-
|
| 560 |
-
DeleteFunc GetDeleteFunc() const override {
|
| 561 |
-
// should never reach here.
|
| 562 |
-
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
|
| 563 |
-
}
|
| 564 |
-
|
| 565 |
-
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
|
| 566 |
-
|
| 567 |
-
virtual MLDataType GetElementType() const {
|
| 568 |
-
// should never reach here.
|
| 569 |
-
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
|
| 570 |
-
}
|
| 571 |
-
|
| 572 |
-
OptionalTypeBase(const OptionalTypeBase&) = delete;
|
| 573 |
-
OptionalTypeBase& operator=(const OptionalTypeBase&) = delete;
|
| 574 |
-
|
| 575 |
-
protected:
|
| 576 |
-
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
|
| 577 |
-
|
| 578 |
-
OptionalTypeBase();
|
| 579 |
-
~OptionalTypeBase() override;
|
| 580 |
-
|
| 581 |
-
private:
|
| 582 |
-
struct Impl;
|
| 583 |
-
Impl* impl_;
|
| 584 |
-
};
|
| 585 |
-
#endif
|
| 586 |
-
|
| 587 |
-
// Derive from OptionalTypeBase if the Optional type support is enabled,
|
| 588 |
-
// else derive from DisabledTypeBase
|
| 589 |
-
template <typename T, typename elemT>
|
| 590 |
-
class OptionalType :
|
| 591 |
-
#if !defined(DISABLE_OPTIONAL_TYPE)
|
| 592 |
-
public OptionalTypeBase
|
| 593 |
-
#else
|
| 594 |
-
public DisabledTypeBase
|
| 595 |
-
#endif
|
| 596 |
-
{
|
| 597 |
-
public:
|
| 598 |
-
static MLDataType Type();
|
| 599 |
-
|
| 600 |
-
#if !defined(DISABLE_OPTIONAL_TYPE)
|
| 601 |
-
static_assert(data_types_internal::IsOptionalOrtType<T>::value,
|
| 602 |
-
"Requires one of the supported types: Tensor or TensorSeq");
|
| 603 |
-
|
| 604 |
-
static_assert(data_types_internal::IsTensorContainedType<elemT>::value,
|
| 605 |
-
"Requires one of the tensor fundamental types");
|
| 606 |
-
|
| 607 |
-
MLDataType GetElementType() const override {
|
| 608 |
-
return data_types_internal::OptionalTypeHelper::GetElemType<T, elemT>();
|
| 609 |
-
}
|
| 610 |
-
#endif
|
| 611 |
-
|
| 612 |
-
private:
|
| 613 |
-
#if !defined(DISABLE_OPTIONAL_TYPE)
|
| 614 |
-
OptionalType()
|
| 615 |
-
#else
|
| 616 |
-
OptionalType() : DisabledTypeBase { DataTypeImpl::GeneralType::kOptional, 0 }
|
| 617 |
-
#endif
|
| 618 |
-
{
|
| 619 |
-
using namespace data_types_internal;
|
| 620 |
-
OptionalTypeHelper::Set(OptionalTypeHelper::GetElemType<T, elemT>()->GetTypeProto(), MutableTypeProto());
|
| 621 |
-
}
|
| 622 |
-
}; // namespace onnxruntime
|
| 623 |
-
|
| 624 |
-
/**
|
| 625 |
-
* \brief Provide a specialization for your C++ Non-tensor type
|
| 626 |
-
* so your implementation FromDataTypeContainer/ToDataTypeContainer
|
| 627 |
-
* functions correctly. Otherwise you get a default implementation
|
| 628 |
-
* which may not be what you need/want.
|
| 629 |
-
*
|
| 630 |
-
* This class is used to create OrtValue, fetch data from OrtValue via
|
| 631 |
-
* C/C++ APIs
|
| 632 |
-
*/
|
| 633 |
-
template <class T>
|
| 634 |
-
struct NonTensorTypeConverter {
|
| 635 |
-
static void FromContainer(MLDataType /*dtype*/, const void* /*data*/, size_t /*data_size*/, OrtValue& /*output*/) {
|
| 636 |
-
ORT_THROW("Not implemented");
|
| 637 |
-
}
|
| 638 |
-
static void ToContainer(const OrtValue& /*input*/, size_t /*data_size*/, void* /*data*/) {
|
| 639 |
-
ORT_THROW("Not implemented");
|
| 640 |
-
}
|
| 641 |
-
};
|
| 642 |
-
|
| 643 |
-
/**
|
| 644 |
-
* \brief Base type for all non-tensors, maps, sequences and opaques
|
| 645 |
-
*/
|
| 646 |
-
class NonTensorTypeBase : public DataTypeImpl {
|
| 647 |
-
public:
|
| 648 |
-
DeleteFunc GetDeleteFunc() const override = 0;
|
| 649 |
-
|
| 650 |
-
virtual CreateFunc GetCreateFunc() const = 0;
|
| 651 |
-
|
| 652 |
-
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
|
| 653 |
-
|
| 654 |
-
// \brief Override for Non-tensor types to initialize non-tensor CPP
|
| 655 |
-
// data representation from data. The caller of the interface
|
| 656 |
-
// should have a shared definition of the data which is used to initialize
|
| 657 |
-
// CPP data representation. This is used from C API.
|
| 658 |
-
//
|
| 659 |
-
// \param data - pointer to a data container structure non_tensor type specific
|
| 660 |
-
// \param data_size - size of the data container structure, used for rudimentary checks
|
| 661 |
-
// \param output - reference to a default constructed non-tensor type
|
| 662 |
-
// \returns OrtValue
|
| 663 |
-
// \throw if there is an error
|
| 664 |
-
virtual void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const;
|
| 665 |
-
|
| 666 |
-
// \brief Override for Non-tensor types to fetch data from the internal CPP data representation
|
| 667 |
-
// The caller of the interface should have a shared definition of the data which is used to initialize
|
| 668 |
-
// CPP data representation. This is used from C API.
|
| 669 |
-
//
|
| 670 |
-
// \param input - OrtValue containing data
|
| 671 |
-
// \param data_size - size of the structure that is being passed for receiving data, used for
|
| 672 |
-
// validation
|
| 673 |
-
// \param data - pointer to receiving data structure
|
| 674 |
-
virtual void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const;
|
| 675 |
-
|
| 676 |
-
NonTensorTypeBase(const NonTensorTypeBase&) = delete;
|
| 677 |
-
NonTensorTypeBase& operator=(const NonTensorTypeBase&) = delete;
|
| 678 |
-
|
| 679 |
-
protected:
|
| 680 |
-
NonTensorTypeBase(size_t size);
|
| 681 |
-
~NonTensorTypeBase() override;
|
| 682 |
-
|
| 683 |
-
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
|
| 684 |
-
|
| 685 |
-
bool IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
|
| 686 |
-
|
| 687 |
-
bool IsSequenceCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
|
| 688 |
-
|
| 689 |
-
bool IsOpaqueCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const;
|
| 690 |
-
|
| 691 |
-
private:
|
| 692 |
-
struct Impl;
|
| 693 |
-
Impl* impl_;
|
| 694 |
-
};
|
| 695 |
-
|
| 696 |
-
// This is where T is the actual CPPRuntimeType
|
| 697 |
-
template <typename T>
|
| 698 |
-
class NonTensorType : public NonTensorTypeBase {
|
| 699 |
-
private:
|
| 700 |
-
static void Delete(void* p) {
|
| 701 |
-
delete static_cast<T*>(p);
|
| 702 |
-
}
|
| 703 |
-
|
| 704 |
-
public:
|
| 705 |
-
DeleteFunc GetDeleteFunc() const override {
|
| 706 |
-
return &Delete;
|
| 707 |
-
}
|
| 708 |
-
|
| 709 |
-
CreateFunc GetCreateFunc() const override {
|
| 710 |
-
return []() -> void* { return new T(); };
|
| 711 |
-
}
|
| 712 |
-
|
| 713 |
-
protected:
|
| 714 |
-
NonTensorType() : NonTensorTypeBase(sizeof(T)) {}
|
| 715 |
-
};
|
| 716 |
-
|
| 717 |
-
#if !defined(DISABLE_ML_OPS)
|
| 718 |
-
/**
|
| 719 |
-
* \brief MapType. Use this type to register
|
| 720 |
-
* mapping types.
|
| 721 |
-
*
|
| 722 |
-
* \param T - cpp type that you wish to register as runtime MapType
|
| 723 |
-
*
|
| 724 |
-
* \details Usage: ORT_REGISTER_MAP(C++Type)
|
| 725 |
-
* The type is required to have mapped_type and
|
| 726 |
-
* key_type defined
|
| 727 |
-
*/
|
| 728 |
-
template <typename CPPType>
|
| 729 |
-
class MapType : public NonTensorType<CPPType> {
|
| 730 |
-
public:
|
| 731 |
-
static_assert(data_types_internal::IsTensorContainedType<typename CPPType::key_type>::value,
|
| 732 |
-
"Requires one of the tensor fundamental types as key");
|
| 733 |
-
|
| 734 |
-
static MLDataType Type();
|
| 735 |
-
|
| 736 |
-
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
|
| 737 |
-
return this->IsMapCompatible(type_proto);
|
| 738 |
-
}
|
| 739 |
-
|
| 740 |
-
private:
|
| 741 |
-
MapType() {
|
| 742 |
-
using namespace data_types_internal;
|
| 743 |
-
MapTypeHelper::Set(utils::ToTensorProtoElementType<typename CPPType::key_type>(),
|
| 744 |
-
MapTypeHelper::GetValueType<typename CPPType::mapped_type>()->GetTypeProto(),
|
| 745 |
-
this->MutableTypeProto());
|
| 746 |
-
}
|
| 747 |
-
};
|
| 748 |
-
#endif
|
| 749 |
-
|
| 750 |
-
/**
|
| 751 |
-
* \brief SequenceType. Use to register sequence for non-tensor types.
|
| 752 |
-
*
|
| 753 |
-
* \param T - CPP type that you wish to register as Sequence
|
| 754 |
-
* runtime type.
|
| 755 |
-
*
|
| 756 |
-
* \details Usage: ORT_REGISTER_SEQ(C++Type)
|
| 757 |
-
* The type is required to have value_type defined
|
| 758 |
-
*/
|
| 759 |
-
template <typename CPPType>
|
| 760 |
-
class SequenceType : public NonTensorType<CPPType> {
|
| 761 |
-
public:
|
| 762 |
-
static MLDataType Type();
|
| 763 |
-
|
| 764 |
-
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
|
| 765 |
-
return this->IsSequenceCompatible(type_proto);
|
| 766 |
-
}
|
| 767 |
-
|
| 768 |
-
private:
|
| 769 |
-
SequenceType() {
|
| 770 |
-
using namespace data_types_internal;
|
| 771 |
-
SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<typename CPPType::value_type>()->GetTypeProto(),
|
| 772 |
-
this->MutableTypeProto());
|
| 773 |
-
}
|
| 774 |
-
};
|
| 775 |
-
|
| 776 |
-
/**
|
| 777 |
-
* \brief SequenceTensorTypeBase serves as a base type class for
|
| 778 |
-
* Tensor sequences. Akin to TensorTypeBase.
|
| 779 |
-
* Runtime representation is always TensorSeq.
|
| 780 |
-
*/
|
| 781 |
-
class SequenceTensorTypeBase : public DataTypeImpl {
|
| 782 |
-
public:
|
| 783 |
-
static MLDataType Type();
|
| 784 |
-
|
| 785 |
-
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override;
|
| 786 |
-
|
| 787 |
-
virtual MLDataType GetElementType() const {
|
| 788 |
-
// should never reach here.
|
| 789 |
-
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
|
| 790 |
-
}
|
| 791 |
-
|
| 792 |
-
DeleteFunc GetDeleteFunc() const override;
|
| 793 |
-
|
| 794 |
-
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;
|
| 795 |
-
|
| 796 |
-
SequenceTensorTypeBase(const SequenceTensorTypeBase&) = delete;
|
| 797 |
-
SequenceTensorTypeBase& operator=(const SequenceTensorTypeBase&) = delete;
|
| 798 |
-
|
| 799 |
-
protected:
|
| 800 |
-
SequenceTensorTypeBase();
|
| 801 |
-
~SequenceTensorTypeBase();
|
| 802 |
-
|
| 803 |
-
ONNX_NAMESPACE::TypeProto& MutableTypeProto();
|
| 804 |
-
|
| 805 |
-
private:
|
| 806 |
-
struct Impl;
|
| 807 |
-
Impl* impl_;
|
| 808 |
-
};
|
| 809 |
-
#if defined(_MSC_VER) && !defined(__clang__)
|
| 810 |
-
#pragma warning(pop)
|
| 811 |
-
#endif
|
| 812 |
-
/**
|
| 813 |
-
* \brief SequenceTensorType. Use to register sequence for non-tensor types.
|
| 814 |
-
*
|
| 815 |
-
* \param CPPRuntime - We always use TensorSeq
|
| 816 |
-
*
|
| 817 |
-
* \param TensorElemType - one of the primitive types
|
| 818 |
-
*
|
| 819 |
-
* \details Usage: ORT_REGISTER_SEQ_TENSOR_TYPE()
|
| 820 |
-
* The type is required to have value_type defined
|
| 821 |
-
*/
|
| 822 |
-
template <typename TensorElemType>
|
| 823 |
-
class SequenceTensorType : public SequenceTensorTypeBase {
|
| 824 |
-
public:
|
| 825 |
-
static_assert(data_types_internal::IsTensorContainedType<TensorElemType>::value,
|
| 826 |
-
"Requires one of the tensor fundamental types");
|
| 827 |
-
|
| 828 |
-
static MLDataType Type();
|
| 829 |
-
|
| 830 |
-
/// Return a MLDataType representing the element-type
|
| 831 |
-
MLDataType GetElementType() const override {
|
| 832 |
-
return DataTypeImpl::GetType<TensorElemType>();
|
| 833 |
-
}
|
| 834 |
-
|
| 835 |
-
private:
|
| 836 |
-
SequenceTensorType() {
|
| 837 |
-
using namespace data_types_internal;
|
| 838 |
-
SequenceTypeHelper::Set(SequenceTypeHelper::GetElemType<TensorElemType>()->GetTypeProto(),
|
| 839 |
-
MutableTypeProto());
|
| 840 |
-
}
|
| 841 |
-
};
|
| 842 |
-
|
| 843 |
-
/**
|
| 844 |
-
* \brief OpaqueType
|
| 845 |
-
*
|
| 846 |
-
* \tparam T - cpp runtume that implements the Opaque type
|
| 847 |
-
*
|
| 848 |
-
* \tparam const char D[] - domain must be extern to be unique
|
| 849 |
-
*
|
| 850 |
-
* \tparam const char N[] - name must be extern to be unique
|
| 851 |
-
*
|
| 852 |
-
* \details Only one CPP type can be associated with a particular
|
| 853 |
-
* OpaqueType registration
|
| 854 |
-
*
|
| 855 |
-
*/
|
| 856 |
-
template <typename T, const char D[], const char N[]>
|
| 857 |
-
class OpaqueType : public NonTensorType<T> {
|
| 858 |
-
public:
|
| 859 |
-
static MLDataType Type();
|
| 860 |
-
|
| 861 |
-
bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override {
|
| 862 |
-
return this->IsOpaqueCompatible(type_proto);
|
| 863 |
-
}
|
| 864 |
-
|
| 865 |
-
void FromDataContainer(const void* data, size_t data_size, OrtValue& output) const override {
|
| 866 |
-
NonTensorTypeConverter<T>::FromContainer(this, data, data_size, output);
|
| 867 |
-
}
|
| 868 |
-
|
| 869 |
-
void ToDataContainer(const OrtValue& input, size_t data_size, void* data) const override {
|
| 870 |
-
NonTensorTypeConverter<T>::ToContainer(input, data_size, data);
|
| 871 |
-
}
|
| 872 |
-
|
| 873 |
-
private:
|
| 874 |
-
OpaqueType() {
|
| 875 |
-
data_types_internal::AssignOpaqueDomainName(D, N, this->MutableTypeProto());
|
| 876 |
-
}
|
| 877 |
-
};
|
| 878 |
-
|
| 879 |
-
/**
|
| 880 |
-
* \brief PrimitiveDataTypeBase
|
| 881 |
-
* Base class for primitive Tensor contained types
|
| 882 |
-
*
|
| 883 |
-
* \details This class contains an integer constant that can be
|
| 884 |
-
* used for input data type dispatching
|
| 885 |
-
*
|
| 886 |
-
*/
|
| 887 |
-
class PrimitiveDataTypeBase : public DataTypeImpl {
|
| 888 |
-
public:
|
| 889 |
-
bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
|
| 890 |
-
return false;
|
| 891 |
-
}
|
| 892 |
-
|
| 893 |
-
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const final {
|
| 894 |
-
return nullptr;
|
| 895 |
-
}
|
| 896 |
-
|
| 897 |
-
int32_t GetDataType() const {
|
| 898 |
-
return data_type_;
|
| 899 |
-
}
|
| 900 |
-
|
| 901 |
-
protected:
|
| 902 |
-
PrimitiveDataTypeBase(size_t size, int32_t data_type)
|
| 903 |
-
: DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type} {}
|
| 904 |
-
|
| 905 |
-
private:
|
| 906 |
-
const int32_t data_type_;
|
| 907 |
-
};
|
| 908 |
-
|
| 909 |
-
/**
|
| 910 |
-
* \brief PrimitiveDataType
|
| 911 |
-
* Typed specialization for primitive types.
|
| 912 |
-
* Concrete instances of this class are used by Tensor.
|
| 913 |
-
*
|
| 914 |
-
* \param T - primitive data type
|
| 915 |
-
*
|
| 916 |
-
*/
|
| 917 |
-
template <typename T>
|
| 918 |
-
class PrimitiveDataType : public PrimitiveDataTypeBase {
|
| 919 |
-
private:
|
| 920 |
-
static void Delete(void* p) {
|
| 921 |
-
delete static_cast<T*>(p);
|
| 922 |
-
}
|
| 923 |
-
|
| 924 |
-
public:
|
| 925 |
-
static MLDataType Type();
|
| 926 |
-
|
| 927 |
-
DeleteFunc GetDeleteFunc() const override {
|
| 928 |
-
return &Delete;
|
| 929 |
-
}
|
| 930 |
-
|
| 931 |
-
private:
|
| 932 |
-
PrimitiveDataType()
|
| 933 |
-
: PrimitiveDataTypeBase{sizeof(T),
|
| 934 |
-
utils::ToTensorProtoElementType<T>()} {
|
| 935 |
-
}
|
| 936 |
-
};
|
| 937 |
-
|
| 938 |
-
inline const TensorTypeBase* DataTypeImpl::AsTensorType() const {
|
| 939 |
-
return IsTensorType() ? static_cast<const TensorTypeBase*>(this) : nullptr;
|
| 940 |
-
}
|
| 941 |
-
|
| 942 |
-
inline const SequenceTensorTypeBase* DataTypeImpl::AsSequenceTensorType() const {
|
| 943 |
-
return IsTensorSequenceType() ? static_cast<const SequenceTensorTypeBase*>(this) : nullptr;
|
| 944 |
-
}
|
| 945 |
-
|
| 946 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 947 |
-
inline const SparseTensorTypeBase* DataTypeImpl::AsSparseTensorType() const {
|
| 948 |
-
return IsSparseTensorType() ? static_cast<const SparseTensorTypeBase*>(this) : nullptr;
|
| 949 |
-
}
|
| 950 |
-
#endif
|
| 951 |
-
|
| 952 |
-
#if !defined(DISABLE_OPTIONAL_TYPE)
|
| 953 |
-
inline const OptionalTypeBase* DataTypeImpl::AsOptionalType() const {
|
| 954 |
-
return IsOptionalType() ? static_cast<const OptionalTypeBase*>(this) : nullptr;
|
| 955 |
-
}
|
| 956 |
-
#endif
|
| 957 |
-
|
| 958 |
-
inline const NonTensorTypeBase* DataTypeImpl::AsNonTensorType() const {
|
| 959 |
-
return IsNonTensorType() ? static_cast<const NonTensorTypeBase*>(this) : nullptr;
|
| 960 |
-
}
|
| 961 |
-
|
| 962 |
-
inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const {
|
| 963 |
-
return IsPrimitiveDataType() ? static_cast<const PrimitiveDataTypeBase*>(this) : nullptr;
|
| 964 |
-
}
|
| 965 |
-
|
| 966 |
-
// Explicit specialization of base class template function
|
| 967 |
-
// is only possible within the enclosing namespace scope,
|
| 968 |
-
// thus a simple way to pre-instantiate a given template
|
| 969 |
-
// at a registration time does not currently work and the macro
|
| 970 |
-
// is needed.
|
| 971 |
-
#define ORT_REGISTER_TENSOR_TYPE(ELEM_TYPE) \
|
| 972 |
-
template <> \
|
| 973 |
-
MLDataType TensorType<ELEM_TYPE>::Type() { \
|
| 974 |
-
static TensorType<ELEM_TYPE> tensor_type; \
|
| 975 |
-
return &tensor_type; \
|
| 976 |
-
} \
|
| 977 |
-
template <> \
|
| 978 |
-
MLDataType DataTypeImpl::GetTensorType<ELEM_TYPE>() { \
|
| 979 |
-
return TensorType<ELEM_TYPE>::Type(); \
|
| 980 |
-
}
|
| 981 |
-
|
| 982 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 983 |
-
#define ORT_REGISTER_SPARSE_TENSOR_TYPE(ELEM_TYPE) \
|
| 984 |
-
template <> \
|
| 985 |
-
MLDataType SparseTensorType<ELEM_TYPE>::Type() { \
|
| 986 |
-
static SparseTensorType<ELEM_TYPE> tensor_type; \
|
| 987 |
-
return &tensor_type; \
|
| 988 |
-
} \
|
| 989 |
-
template <> \
|
| 990 |
-
MLDataType DataTypeImpl::GetSparseTensorType<ELEM_TYPE>() { \
|
| 991 |
-
return SparseTensorType<ELEM_TYPE>::Type(); \
|
| 992 |
-
}
|
| 993 |
-
#endif
|
| 994 |
-
|
| 995 |
-
#define ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, TYPE) \
|
| 996 |
-
template <> \
|
| 997 |
-
MLDataType OptionalType<ORT_TYPE, TYPE>::Type() { \
|
| 998 |
-
static OptionalType<ORT_TYPE, TYPE> optional_type; \
|
| 999 |
-
return &optional_type; \
|
| 1000 |
-
} \
|
| 1001 |
-
template <> \
|
| 1002 |
-
MLDataType DataTypeImpl::GetOptionalType<ORT_TYPE, TYPE>() { \
|
| 1003 |
-
return OptionalType<ORT_TYPE, TYPE>::Type(); \
|
| 1004 |
-
}
|
| 1005 |
-
|
| 1006 |
-
#if !defined(DISABLE_ML_OPS)
|
| 1007 |
-
#define ORT_REGISTER_MAP(TYPE) \
|
| 1008 |
-
template <> \
|
| 1009 |
-
MLDataType MapType<TYPE>::Type() { \
|
| 1010 |
-
static MapType<TYPE> map_type; \
|
| 1011 |
-
return &map_type; \
|
| 1012 |
-
} \
|
| 1013 |
-
template <> \
|
| 1014 |
-
MLDataType DataTypeImpl::GetType<TYPE>() { \
|
| 1015 |
-
return MapType<TYPE>::Type(); \
|
| 1016 |
-
}
|
| 1017 |
-
#endif
|
| 1018 |
-
|
| 1019 |
-
#define ORT_REGISTER_SEQ(TYPE) \
|
| 1020 |
-
template <> \
|
| 1021 |
-
MLDataType SequenceType<TYPE>::Type() { \
|
| 1022 |
-
static SequenceType<TYPE> sequence_type; \
|
| 1023 |
-
return &sequence_type; \
|
| 1024 |
-
} \
|
| 1025 |
-
template <> \
|
| 1026 |
-
MLDataType DataTypeImpl::GetType<TYPE>() { \
|
| 1027 |
-
return SequenceType<TYPE>::Type(); \
|
| 1028 |
-
}
|
| 1029 |
-
|
| 1030 |
-
#define ORT_REGISTER_SEQ_TENSOR_TYPE(ELEM_TYPE) \
|
| 1031 |
-
template <> \
|
| 1032 |
-
MLDataType SequenceTensorType<ELEM_TYPE>::Type() { \
|
| 1033 |
-
static SequenceTensorType<ELEM_TYPE> sequence_tensor_type; \
|
| 1034 |
-
return &sequence_tensor_type; \
|
| 1035 |
-
} \
|
| 1036 |
-
template <> \
|
| 1037 |
-
MLDataType DataTypeImpl::GetSequenceTensorType<ELEM_TYPE>() { \
|
| 1038 |
-
return SequenceTensorType<ELEM_TYPE>::Type(); \
|
| 1039 |
-
}
|
| 1040 |
-
|
| 1041 |
-
#define ORT_REGISTER_PRIM_TYPE(TYPE) \
|
| 1042 |
-
template <> \
|
| 1043 |
-
MLDataType PrimitiveDataType<TYPE>::Type() { \
|
| 1044 |
-
static PrimitiveDataType<TYPE> prim_data_type; \
|
| 1045 |
-
return &prim_data_type; \
|
| 1046 |
-
} \
|
| 1047 |
-
template <> \
|
| 1048 |
-
MLDataType DataTypeImpl::GetType<TYPE>() { \
|
| 1049 |
-
return PrimitiveDataType<TYPE>::Type(); \
|
| 1050 |
-
}
|
| 1051 |
-
|
| 1052 |
-
#define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name) \
|
| 1053 |
-
template <> \
|
| 1054 |
-
MLDataType OpaqueType<CPPType, Domain, Name>::Type() { \
|
| 1055 |
-
static OpaqueType<CPPType, Domain, Name> opaque_type; \
|
| 1056 |
-
return &opaque_type; \
|
| 1057 |
-
} \
|
| 1058 |
-
template <> \
|
| 1059 |
-
MLDataType DataTypeImpl::GetType<CPPType>() { \
|
| 1060 |
-
return OpaqueType<CPPType, Domain, Name>::Type(); \
|
| 1061 |
-
}
|
| 1062 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/data_types_internal.h
DELETED
|
@@ -1,569 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <array>
|
| 7 |
-
#include <cassert>
|
| 8 |
-
#include <cstdint>
|
| 9 |
-
#include <string>
|
| 10 |
-
#include <type_traits>
|
| 11 |
-
#include <vector>
|
| 12 |
-
|
| 13 |
-
#include "boost/mp11.hpp"
|
| 14 |
-
|
| 15 |
-
#include "core/common/common.h"
|
| 16 |
-
#include "core/framework/to_tensor_proto_element_type.h"
|
| 17 |
-
#ifndef SHARED_PROVIDER
|
| 18 |
-
#include "core/common/type_list.h"
|
| 19 |
-
#include "core/framework/data_types.h"
|
| 20 |
-
#if !defined(ORT_MINIMAL_BUILD)
|
| 21 |
-
#include "onnx/defs/schema.h"
|
| 22 |
-
#else
|
| 23 |
-
#include "onnx/defs/data_type_utils.h"
|
| 24 |
-
#endif
|
| 25 |
-
#include "onnx/onnx_pb.h"
|
| 26 |
-
#include "onnx/onnx-operators_pb.h"
|
| 27 |
-
#endif
|
| 28 |
-
|
| 29 |
-
namespace onnxruntime {
|
| 30 |
-
namespace utils {
|
| 31 |
-
|
| 32 |
-
// The following primitives are strongly recommended for switching on tensor input datatypes for
|
| 33 |
-
// kernel implementations.
|
| 34 |
-
//
|
| 35 |
-
// 1) If you need to handle all of the primitive tensor contained datatypes, the best choice would be macros
|
| 36 |
-
// DispatchOnTensorType or DispatchOnTensorTypeWithReturn. Use inline wrappers so your function can be invoked as function<T>().
|
| 37 |
-
// 2) if you have a few types, use Tensor.IsDataType<T>()/IsDataTypeString() or use utils::IsPrimitiveDataType<T>()
|
| 38 |
-
// if you have a standalone MLDatatType with a sequence of if/else statements.
|
| 39 |
-
// 3) For something in between, we suggest to use CallDispatcher pattern.
|
| 40 |
-
//
|
| 41 |
-
// Invoking DataTypeImpl::GetType<T>() for switching on input types is discouraged and should be avoided.
|
| 42 |
-
// Every primitive type carries with it an integer constant that can be used for quick switching on types.
|
| 43 |
-
|
| 44 |
-
#define DispatchOnTensorType(tensor_type, function, ...) \
|
| 45 |
-
switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
|
| 46 |
-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
|
| 47 |
-
function<float>(__VA_ARGS__); \
|
| 48 |
-
break; \
|
| 49 |
-
case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
|
| 50 |
-
function<bool>(__VA_ARGS__); \
|
| 51 |
-
break; \
|
| 52 |
-
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
|
| 53 |
-
function<double>(__VA_ARGS__); \
|
| 54 |
-
break; \
|
| 55 |
-
case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
|
| 56 |
-
function<std::string>(__VA_ARGS__); \
|
| 57 |
-
break; \
|
| 58 |
-
case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
|
| 59 |
-
function<int8_t>(__VA_ARGS__); \
|
| 60 |
-
break; \
|
| 61 |
-
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
|
| 62 |
-
function<uint8_t>(__VA_ARGS__); \
|
| 63 |
-
break; \
|
| 64 |
-
case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
|
| 65 |
-
function<int16_t>(__VA_ARGS__); \
|
| 66 |
-
break; \
|
| 67 |
-
case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
|
| 68 |
-
function<uint16_t>(__VA_ARGS__); \
|
| 69 |
-
break; \
|
| 70 |
-
case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
|
| 71 |
-
function<int32_t>(__VA_ARGS__); \
|
| 72 |
-
break; \
|
| 73 |
-
case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
|
| 74 |
-
function<uint32_t>(__VA_ARGS__); \
|
| 75 |
-
break; \
|
| 76 |
-
case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
|
| 77 |
-
function<int64_t>(__VA_ARGS__); \
|
| 78 |
-
break; \
|
| 79 |
-
case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
|
| 80 |
-
function<uint64_t>(__VA_ARGS__); \
|
| 81 |
-
break; \
|
| 82 |
-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
|
| 83 |
-
function<MLFloat16>(__VA_ARGS__); \
|
| 84 |
-
break; \
|
| 85 |
-
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
|
| 86 |
-
function<BFloat16>(__VA_ARGS__); \
|
| 87 |
-
break; \
|
| 88 |
-
default: \
|
| 89 |
-
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
|
| 90 |
-
}
|
| 91 |
-
|
| 92 |
-
#define DispatchOnTensorTypeWithReturn(tensor_type, retval, function, ...) \
|
| 93 |
-
switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
|
| 94 |
-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: \
|
| 95 |
-
retval = function<float>(__VA_ARGS__); \
|
| 96 |
-
break; \
|
| 97 |
-
case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
|
| 98 |
-
retval = function<bool>(__VA_ARGS__); \
|
| 99 |
-
break; \
|
| 100 |
-
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: \
|
| 101 |
-
retval = function<double>(__VA_ARGS__); \
|
| 102 |
-
break; \
|
| 103 |
-
case ONNX_NAMESPACE::TensorProto_DataType_STRING: \
|
| 104 |
-
retval = function<std::string>(__VA_ARGS__); \
|
| 105 |
-
break; \
|
| 106 |
-
case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
|
| 107 |
-
retval = function<int8_t>(__VA_ARGS__); \
|
| 108 |
-
break; \
|
| 109 |
-
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
|
| 110 |
-
retval = function<uint8_t>(__VA_ARGS__); \
|
| 111 |
-
break; \
|
| 112 |
-
case ONNX_NAMESPACE::TensorProto_DataType_UINT16: \
|
| 113 |
-
retval = function<uint16_t>(__VA_ARGS__); \
|
| 114 |
-
break; \
|
| 115 |
-
case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
|
| 116 |
-
retval = function<int16_t>(__VA_ARGS__); \
|
| 117 |
-
break; \
|
| 118 |
-
case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
|
| 119 |
-
retval = function<int32_t>(__VA_ARGS__); \
|
| 120 |
-
break; \
|
| 121 |
-
case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
|
| 122 |
-
retval = function<uint32_t>(__VA_ARGS__); \
|
| 123 |
-
break; \
|
| 124 |
-
case ONNX_NAMESPACE::TensorProto_DataType_INT64: \
|
| 125 |
-
retval = function<int64_t>(__VA_ARGS__); \
|
| 126 |
-
break; \
|
| 127 |
-
case ONNX_NAMESPACE::TensorProto_DataType_UINT64: \
|
| 128 |
-
retval = function<uint64_t>(__VA_ARGS__); \
|
| 129 |
-
break; \
|
| 130 |
-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: \
|
| 131 |
-
retval = function<MLFloat16>(__VA_ARGS__); \
|
| 132 |
-
break; \
|
| 133 |
-
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
|
| 134 |
-
retval = function<BFloat16>(__VA_ARGS__); \
|
| 135 |
-
break; \
|
| 136 |
-
default: \
|
| 137 |
-
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
|
| 138 |
-
}
|
| 139 |
-
|
| 140 |
-
////////////////////////////////////////////////////////////////////////////////
|
| 141 |
-
/// Use the following primitives if you have a few types to switch on so you
|
| 142 |
-
// can write a short sequence of if/else statements.
|
| 143 |
-
|
| 144 |
-
// This is a frequently used check so we make a separate utility function.
|
| 145 |
-
inline bool IsDataTypeString(MLDataType dt_type) {
|
| 146 |
-
auto prim_type = dt_type->AsPrimitiveDataType();
|
| 147 |
-
return (prim_type != nullptr && prim_type->GetDataType() == ONNX_NAMESPACE::TensorProto_DataType_STRING);
|
| 148 |
-
}
|
| 149 |
-
|
| 150 |
-
// Test if MLDataType is a concrete type of PrimitiveDataTypeBase
|
| 151 |
-
// and it is T
|
| 152 |
-
template <class T>
|
| 153 |
-
inline bool IsPrimitiveDataType(MLDataType dt_type) {
|
| 154 |
-
auto prim_type = dt_type->AsPrimitiveDataType();
|
| 155 |
-
return (prim_type != nullptr && prim_type->GetDataType() == ToTensorProtoElementType<T>());
|
| 156 |
-
}
|
| 157 |
-
|
| 158 |
-
// Use after AsPrimitiveDataType() is successful
|
| 159 |
-
// Check if PrimitiveDataTypeBase is of type T
|
| 160 |
-
template <class T>
|
| 161 |
-
inline bool IsPrimitiveDataType(const PrimitiveDataTypeBase* prim_type) {
|
| 162 |
-
assert(prim_type != nullptr);
|
| 163 |
-
return prim_type->GetDataType() == ToTensorProtoElementType<T>();
|
| 164 |
-
}
|
| 165 |
-
|
| 166 |
-
// This implementation contains a workaround for GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=47226
|
| 167 |
-
// GCC until very recently does not support template parameter pack expansion within lambda context.
|
| 168 |
-
namespace mltype_dispatcher_internal {
|
| 169 |
-
|
| 170 |
-
// T - type handled by this helper
|
| 171 |
-
class CallableDispatchableHelper {
|
| 172 |
-
int32_t dt_type_; // Type currently dispatched
|
| 173 |
-
size_t called_;
|
| 174 |
-
|
| 175 |
-
public:
|
| 176 |
-
explicit CallableDispatchableHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0) {}
|
| 177 |
-
|
| 178 |
-
// Must return integer to be in a expandable context
|
| 179 |
-
template <class T, class Fn, class... Args>
|
| 180 |
-
int Invoke(Fn&& fn, Args&&... args) {
|
| 181 |
-
if (utils::ToTensorProtoElementType<T>() == dt_type_) {
|
| 182 |
-
std::forward<Fn>(fn)(std::forward<Args>(args)...);
|
| 183 |
-
++called_;
|
| 184 |
-
}
|
| 185 |
-
return 0;
|
| 186 |
-
}
|
| 187 |
-
|
| 188 |
-
void CheckCalledOnce() {
|
| 189 |
-
ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
|
| 190 |
-
}
|
| 191 |
-
};
|
| 192 |
-
|
| 193 |
-
// Default policy is to throw an exception.
|
| 194 |
-
// Other policies may set the second result argument accordingly.
|
| 195 |
-
template <class Ret>
|
| 196 |
-
struct UnsupportedTypeDefaultPolicy {
|
| 197 |
-
void operator()(int32_t dt_type, Ret& /*result*/) const {
|
| 198 |
-
ORT_THROW("Unsupported data type: ", dt_type);
|
| 199 |
-
}
|
| 200 |
-
};
|
| 201 |
-
|
| 202 |
-
// Helper with the result type
|
| 203 |
-
template <class Ret, class UnsupportedPolicy>
|
| 204 |
-
class CallableDispatchableRetHelper {
|
| 205 |
-
int32_t dt_type_; // Type currently dispatched
|
| 206 |
-
size_t called_;
|
| 207 |
-
Ret result_;
|
| 208 |
-
|
| 209 |
-
public:
|
| 210 |
-
explicit CallableDispatchableRetHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0), result_() {}
|
| 211 |
-
|
| 212 |
-
Ret Get() {
|
| 213 |
-
// No type was invoked
|
| 214 |
-
if (called_ == 0) {
|
| 215 |
-
UnsupportedPolicy()(dt_type_, result_);
|
| 216 |
-
}
|
| 217 |
-
return result_;
|
| 218 |
-
}
|
| 219 |
-
|
| 220 |
-
// Must return integer to be in a expandable context
|
| 221 |
-
template <class T, class Fn, class... Args>
|
| 222 |
-
int Invoke(Fn&& fn, Args&&... args) {
|
| 223 |
-
if (utils::ToTensorProtoElementType<T>() == dt_type_) {
|
| 224 |
-
result_ = std::forward<Fn>(fn)(std::forward<Args>(args)...);
|
| 225 |
-
++called_;
|
| 226 |
-
}
|
| 227 |
-
return 0;
|
| 228 |
-
}
|
| 229 |
-
};
|
| 230 |
-
|
| 231 |
-
template <typename T>
|
| 232 |
-
using TensorProtoElementTypeConstant =
|
| 233 |
-
std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType<T>()>;
|
| 234 |
-
|
| 235 |
-
using UndefinedTensorProtoElementTypeConstant =
|
| 236 |
-
std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED>;
|
| 237 |
-
|
| 238 |
-
} // namespace mltype_dispatcher_internal
|
| 239 |
-
|
| 240 |
-
/**
|
| 241 |
-
* This class helps to efficiently dispatch calls to implementation function
|
| 242 |
-
* objects with a tensor element type template argument.
|
| 243 |
-
*
|
| 244 |
-
* The constructor accepts a value corresponding to a tensor element type.
|
| 245 |
-
* For example, it can be obtained from:
|
| 246 |
-
* input_tensor->GetElementType()
|
| 247 |
-
*
|
| 248 |
-
* The Invoke member functions will instantiate and invoke the provided
|
| 249 |
-
* function object template, Fn. Fn must be default constructible. Fn must also
|
| 250 |
-
* have a tensor element type template argument. This type template argument
|
| 251 |
-
* will be the type that corresponds to the value given in the constructor.
|
| 252 |
-
* These functions accept and forward arbitrary function arguments. They ensure
|
| 253 |
-
* that Fn is called once with the type specified in the constructor.
|
| 254 |
-
*
|
| 255 |
-
* @tparam Types The types supported by the implementation. This should be a
|
| 256 |
-
* set of ONNX tensor element types that are supported by ORT.
|
| 257 |
-
*/
|
| 258 |
-
template <typename... Types>
|
| 259 |
-
class MLTypeCallDispatcher {
|
| 260 |
-
using SupportedTypeList = TypeList<Types...>;
|
| 261 |
-
using SupportedTensorProtoElementTypeList =
|
| 262 |
-
boost::mp11::mp_transform<
|
| 263 |
-
mltype_dispatcher_internal::TensorProtoElementTypeConstant, SupportedTypeList>;
|
| 264 |
-
|
| 265 |
-
static_assert(
|
| 266 |
-
boost::mp11::mp_and<
|
| 267 |
-
boost::mp11::mp_is_set<SupportedTensorProtoElementTypeList>,
|
| 268 |
-
boost::mp11::mp_not<
|
| 269 |
-
boost::mp11::mp_set_contains<
|
| 270 |
-
SupportedTensorProtoElementTypeList,
|
| 271 |
-
mltype_dispatcher_internal::UndefinedTensorProtoElementTypeConstant>>>::value,
|
| 272 |
-
"Types must map to a unique set of ONNX tensor element data types supported by ORT.");
|
| 273 |
-
|
| 274 |
-
int32_t dt_type_;
|
| 275 |
-
|
| 276 |
-
public:
|
| 277 |
-
/**
|
| 278 |
-
* Constructor.
|
| 279 |
-
* @param dt_type The value corresponding to the tensor element type to be
|
| 280 |
-
* dispatched to. This can be obtained from
|
| 281 |
-
* input_tensor->GetElementType() or
|
| 282 |
-
* utils::ToTensorProtoElementType<T>().
|
| 283 |
-
*/
|
| 284 |
-
explicit MLTypeCallDispatcher(int32_t dt_type) noexcept : dt_type_(dt_type) {}
|
| 285 |
-
|
| 286 |
-
/**
|
| 287 |
-
* Invokes Fn<T> with the specified arguments.
|
| 288 |
-
*
|
| 289 |
-
* @tparam Fn The function object template.
|
| 290 |
-
* @tparam Args The argument types.
|
| 291 |
-
*/
|
| 292 |
-
template <template <typename...> class Fn, typename... Args>
|
| 293 |
-
void Invoke(Args&&... args) const {
|
| 294 |
-
InvokeWithLeadingTemplateArgs<Fn, TypeList<>>(std::forward<Args>(args)...);
|
| 295 |
-
}
|
| 296 |
-
|
| 297 |
-
/**
|
| 298 |
-
* Invokes Fn<..., T> with leading template arguments and the specified
|
| 299 |
-
* arguments.
|
| 300 |
-
*
|
| 301 |
-
* @tparam Fn The function object template.
|
| 302 |
-
* @tparam LeadingTemplateArgTypeList A type list of the leading template
|
| 303 |
-
* arguments.
|
| 304 |
-
* @tparam Args The argument types.
|
| 305 |
-
*/
|
| 306 |
-
template <template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
|
| 307 |
-
void InvokeWithLeadingTemplateArgs(Args&&... args) const {
|
| 308 |
-
static_assert(
|
| 309 |
-
boost::mp11::mp_is_list<LeadingTemplateArgTypeList>::value,
|
| 310 |
-
"LeadingTemplateArgTypeList must be a type list (e.g., onnxruntime::TypeList<T1, T2, ...>).");
|
| 311 |
-
|
| 312 |
-
mltype_dispatcher_internal::CallableDispatchableHelper helper(dt_type_);
|
| 313 |
-
|
| 314 |
-
// given LeadingTemplateArgTypeList is a type list L<U1, U2, ...>,
|
| 315 |
-
// call helper.Invoke() with Fn<U1, U2, ..., T> for each T in Types
|
| 316 |
-
static_cast<void>(std::array<int, sizeof...(Types)>{
|
| 317 |
-
helper.template Invoke<Types>(
|
| 318 |
-
boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
|
| 319 |
-
std::forward<Args>(args)...)...});
|
| 320 |
-
|
| 321 |
-
// avoid "unused parameter" warning for the case where Types is empty
|
| 322 |
-
static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
|
| 323 |
-
|
| 324 |
-
helper.CheckCalledOnce();
|
| 325 |
-
}
|
| 326 |
-
|
| 327 |
-
/**
|
| 328 |
-
* Invokes Fn<T> with the specified arguments and returns the result.
|
| 329 |
-
*
|
| 330 |
-
* @tparam Ret The return type. Fn should return a type convertible to Ret.
|
| 331 |
-
* @tparam Fn The function object template.
|
| 332 |
-
* @tparam Args The argument types.
|
| 333 |
-
*/
|
| 334 |
-
template <class Ret, template <typename...> class Fn, typename... Args>
|
| 335 |
-
Ret InvokeRet(Args&&... args) const {
|
| 336 |
-
return InvokeRetWithUnsupportedPolicy<
|
| 337 |
-
Ret, Fn, mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy<Ret>>(
|
| 338 |
-
std::forward<Args>(args)...);
|
| 339 |
-
}
|
| 340 |
-
|
| 341 |
-
/**
|
| 342 |
-
* Invokes Fn<T> with the specified arguments and returns the result.
|
| 343 |
-
*
|
| 344 |
-
* @tparam Ret The return type. Fn should return a type convertible to Ret.
|
| 345 |
-
* @tparam Fn The function object template.
|
| 346 |
-
* @tparam UnsupportedPolicy The policy used to handle unsupported types.
|
| 347 |
-
* See mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy
|
| 348 |
-
* for an example.
|
| 349 |
-
* @tparam Args The argument types.
|
| 350 |
-
*/
|
| 351 |
-
template <class Ret, template <typename...> class Fn, class UnsupportedPolicy, typename... Args>
|
| 352 |
-
Ret InvokeRetWithUnsupportedPolicy(Args&&... args) const {
|
| 353 |
-
return InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs<
|
| 354 |
-
Ret, Fn, UnsupportedPolicy, TypeList<>>(
|
| 355 |
-
std::forward<Args>(args)...);
|
| 356 |
-
}
|
| 357 |
-
|
| 358 |
-
/**
|
| 359 |
-
* Invokes Fn<..., T> with leading template arguments and the specified
|
| 360 |
-
* arguments and returns the result.
|
| 361 |
-
*
|
| 362 |
-
* @tparam Ret The return type. Fn should return a type convertible to Ret.
|
| 363 |
-
* @tparam Fn The function object template.
|
| 364 |
-
* @tparam LeadingTemplateArgTypeList A type list of the leading template
|
| 365 |
-
* arguments.
|
| 366 |
-
* @tparam Args The argument types.
|
| 367 |
-
*/
|
| 368 |
-
template <class Ret, template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
|
| 369 |
-
Ret InvokeRetWithLeadingTemplateArgs(Args&&... args) const {
|
| 370 |
-
return InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs<
|
| 371 |
-
Ret, Fn, mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy<Ret>, LeadingTemplateArgTypeList>(
|
| 372 |
-
std::forward<Args>(args)...);
|
| 373 |
-
}
|
| 374 |
-
|
| 375 |
-
/**
|
| 376 |
-
* Invokes Fn<..., T> with leading template arguments and the specified
|
| 377 |
-
* arguments and returns the result.
|
| 378 |
-
*
|
| 379 |
-
* @tparam Ret The return type. Fn should return a type convertible to Ret.
|
| 380 |
-
* @tparam Fn The function object template.
|
| 381 |
-
* @tparam UnsupportedPolicy The policy used to handle unsupported types.
|
| 382 |
-
* See mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy
|
| 383 |
-
* for an example.
|
| 384 |
-
* @tparam LeadingTemplateArgTypeList A type list of the leading template
|
| 385 |
-
* arguments.
|
| 386 |
-
* @tparam Args The argument types.
|
| 387 |
-
*/
|
| 388 |
-
template <class Ret,
|
| 389 |
-
template <typename...> class Fn,
|
| 390 |
-
class UnsupportedPolicy,
|
| 391 |
-
typename LeadingTemplateArgTypeList,
|
| 392 |
-
typename... Args>
|
| 393 |
-
Ret InvokeRetWithUnsupportedPolicyAndLeadingTemplateArgs(Args&&... args) const {
|
| 394 |
-
mltype_dispatcher_internal::CallableDispatchableRetHelper<Ret, UnsupportedPolicy> helper(dt_type_);
|
| 395 |
-
|
| 396 |
-
// given LeadingTemplateArgTypeList is a type list L<U1, U2, ...>,
|
| 397 |
-
// call helper.Invoke() with Fn<U1, U2, ..., T> for each T in Types
|
| 398 |
-
static_cast<void>(std::array<int, sizeof...(Types)>{
|
| 399 |
-
helper.template Invoke<Types>(
|
| 400 |
-
boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
|
| 401 |
-
std::forward<Args>(args)...)...});
|
| 402 |
-
|
| 403 |
-
// avoid "unused parameter" warning for the case where Types is empty
|
| 404 |
-
static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});
|
| 405 |
-
|
| 406 |
-
return helper.Get();
|
| 407 |
-
}
|
| 408 |
-
};
|
| 409 |
-
|
| 410 |
-
// the type MLTypeCallDispatcher<T...> given a type list L<T...>
|
| 411 |
-
template <typename L>
|
| 412 |
-
using MLTypeCallDispatcherFromTypeList = boost::mp11::mp_apply<MLTypeCallDispatcher, L>;
|
| 413 |
-
|
| 414 |
-
namespace data_types_internal {
|
| 415 |
-
|
| 416 |
-
enum class ContainerType : uint16_t {
|
| 417 |
-
kUndefined = 0,
|
| 418 |
-
kTensor = 1,
|
| 419 |
-
kMap = 2,
|
| 420 |
-
kSequence = 3,
|
| 421 |
-
kOpaque = 4,
|
| 422 |
-
kOptional = 5
|
| 423 |
-
};
|
| 424 |
-
|
| 425 |
-
class TypeNode {
|
| 426 |
-
// type_ is a TypeProto value case enum
|
| 427 |
-
// that may be a kTypeTensor, kTypeMap, kTypeSequence
|
| 428 |
-
// prim_type_ is a TypeProto_DataType enum that has meaning
|
| 429 |
-
// - for Tensor then prim_type_ is the contained type
|
| 430 |
-
// - for Map prim_type is the key type. Next entry describes map value
|
| 431 |
-
// - For sequence prim_type_ is not used and has no meaning. Next entry
|
| 432 |
-
// describes the value for the sequence
|
| 433 |
-
// Tensor is always the last entry as it describes a contained primitive type.
|
| 434 |
-
ContainerType type_;
|
| 435 |
-
uint16_t prim_type_;
|
| 436 |
-
|
| 437 |
-
public:
|
| 438 |
-
TypeNode(ContainerType type, int32_t prim_type) noexcept {
|
| 439 |
-
type_ = type;
|
| 440 |
-
prim_type_ = static_cast<uint16_t>(prim_type);
|
| 441 |
-
}
|
| 442 |
-
|
| 443 |
-
bool IsType(ContainerType type) const noexcept {
|
| 444 |
-
return type_ == type;
|
| 445 |
-
}
|
| 446 |
-
|
| 447 |
-
bool IsPrimType(int32_t prim_type) const noexcept {
|
| 448 |
-
return prim_type_ == static_cast<uint16_t>(prim_type);
|
| 449 |
-
}
|
| 450 |
-
};
|
| 451 |
-
|
| 452 |
-
} // namespace data_types_internal
|
| 453 |
-
|
| 454 |
-
////////////////////////////////////////////////////////////////////
|
| 455 |
-
/// Provides generic interface to test whether MLDataType is a Sequence,
|
| 456 |
-
/// Map or an Opaque type including arbitrary recursive definitions
|
| 457 |
-
/// without querying DataTypeImpl::GetType<T> for all known complex types
|
| 458 |
-
|
| 459 |
-
// T is a sequence contained element type
|
| 460 |
-
// If returns true then we know that the runtime
|
| 461 |
-
// representation is std::vector<T>
|
| 462 |
-
// T itself can be a runtime representation of another
|
| 463 |
-
// sequence, map, opaque type or a tensor
|
| 464 |
-
//
|
| 465 |
-
// That is it can be std::vector or a std::map
|
| 466 |
-
// If T is a primitive type sequence is tested whether it contains
|
| 467 |
-
// tensors of that type
|
| 468 |
-
//
|
| 469 |
-
// If T is an opaque type, then it is only tested to be opaque but not exactly
|
| 470 |
-
// a specific opaque type. To Test for a specific Opaque type use IsOpaqueType() below
|
| 471 |
-
//
|
| 472 |
-
// This class examines the supplied MLDataType and records
|
| 473 |
-
// its information in a vector so any subsequent checks for Sequences and Maps
|
| 474 |
-
// are quick.
|
| 475 |
-
class ContainerChecker {
|
| 476 |
-
using Cont = std::vector<data_types_internal::TypeNode>;
|
| 477 |
-
Cont types_;
|
| 478 |
-
|
| 479 |
-
// Default IsContainerOfType is for Opaque type
|
| 480 |
-
template <class T>
|
| 481 |
-
struct IsContainerOfType {
|
| 482 |
-
static bool check(const Cont& c, size_t index) {
|
| 483 |
-
if (index >= c.size()) {
|
| 484 |
-
return false;
|
| 485 |
-
}
|
| 486 |
-
return c[index].IsType(data_types_internal::ContainerType::kOpaque);
|
| 487 |
-
}
|
| 488 |
-
};
|
| 489 |
-
|
| 490 |
-
// Handles the case where sequence element is also a sequence
|
| 491 |
-
template <class T>
|
| 492 |
-
struct IsContainerOfType<std::vector<T>> {
|
| 493 |
-
static bool check(const Cont& c, size_t index) {
|
| 494 |
-
if (index >= c.size()) {
|
| 495 |
-
return false;
|
| 496 |
-
}
|
| 497 |
-
if (c[index].IsType(data_types_internal::ContainerType::kSequence)) {
|
| 498 |
-
ORT_ENFORCE(++index < c.size(), "Sequence is missing type entry for its element");
|
| 499 |
-
constexpr int32_t prim_type = ToTensorProtoElementType<T>();
|
| 500 |
-
// Check if this is a primitive type and it matches
|
| 501 |
-
if constexpr(prim_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
|
| 502 |
-
return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
|
| 503 |
-
c[index].IsPrimType(prim_type);
|
| 504 |
-
}
|
| 505 |
-
else {
|
| 506 |
-
// T is not primitive, check next entry for non-primitive proto
|
| 507 |
-
return IsContainerOfType<T>::check(c, index);
|
| 508 |
-
}
|
| 509 |
-
}
|
| 510 |
-
return false;
|
| 511 |
-
}
|
| 512 |
-
};
|
| 513 |
-
|
| 514 |
-
template <class K, class V>
|
| 515 |
-
struct IsContainerOfType<std::map<K, V>> {
|
| 516 |
-
static bool check(const Cont& c, size_t index) {
|
| 517 |
-
static_assert(ToTensorProtoElementType<K>() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED,
|
| 518 |
-
"Map Key can not be a non-primitive type");
|
| 519 |
-
if (index >= c.size()) {
|
| 520 |
-
return false;
|
| 521 |
-
}
|
| 522 |
-
if (!c[index].IsType(data_types_internal::ContainerType::kMap)) {
|
| 523 |
-
return false;
|
| 524 |
-
}
|
| 525 |
-
constexpr int32_t key_type = ToTensorProtoElementType<K>();
|
| 526 |
-
if (!c[index].IsPrimType(key_type)) {
|
| 527 |
-
return false;
|
| 528 |
-
}
|
| 529 |
-
ORT_ENFORCE(++index < c.size(), "Map is missing type entry for its value");
|
| 530 |
-
constexpr int32_t val_type = ToTensorProtoElementType<V>();
|
| 531 |
-
if constexpr(val_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
|
| 532 |
-
return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
|
| 533 |
-
c[index].IsPrimType(val_type);
|
| 534 |
-
}
|
| 535 |
-
else return IsContainerOfType<V>::check(c, index);
|
| 536 |
-
}
|
| 537 |
-
};
|
| 538 |
-
|
| 539 |
-
public:
|
| 540 |
-
explicit ContainerChecker(MLDataType);
|
| 541 |
-
~ContainerChecker() = default;
|
| 542 |
-
|
| 543 |
-
bool IsMap() const noexcept {
|
| 544 |
-
assert(!types_.empty());
|
| 545 |
-
return types_[0].IsType(data_types_internal::ContainerType::kMap);
|
| 546 |
-
}
|
| 547 |
-
|
| 548 |
-
bool IsSequence() const noexcept {
|
| 549 |
-
assert(!types_.empty());
|
| 550 |
-
return types_[0].IsType(data_types_internal::ContainerType::kSequence);
|
| 551 |
-
}
|
| 552 |
-
|
| 553 |
-
template <class T>
|
| 554 |
-
bool IsSequenceOf() const {
|
| 555 |
-
assert(!types_.empty());
|
| 556 |
-
return IsContainerOfType<std::vector<T>>::check(types_, 0);
|
| 557 |
-
}
|
| 558 |
-
|
| 559 |
-
template <class K, class V>
|
| 560 |
-
bool IsMapOf() const {
|
| 561 |
-
assert(!types_.empty());
|
| 562 |
-
return IsContainerOfType<std::map<K, V>>::check(types_, 0);
|
| 563 |
-
}
|
| 564 |
-
};
|
| 565 |
-
|
| 566 |
-
bool IsOpaqueType(MLDataType ml_type, const char* domain, const char* name);
|
| 567 |
-
|
| 568 |
-
} // namespace utils
|
| 569 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/endian.h
DELETED
|
@@ -1,27 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
namespace onnxruntime {
|
| 7 |
-
|
| 8 |
-
// the semantics of this enum should match std::endian from C++20
|
| 9 |
-
enum class endian {
|
| 10 |
-
#if defined(_WIN32)
|
| 11 |
-
little = 0,
|
| 12 |
-
big = 1,
|
| 13 |
-
native = little,
|
| 14 |
-
#elif defined(__GNUC__) || defined(__clang__)
|
| 15 |
-
little = __ORDER_LITTLE_ENDIAN__,
|
| 16 |
-
big = __ORDER_BIG_ENDIAN__,
|
| 17 |
-
native = __BYTE_ORDER__,
|
| 18 |
-
#else
|
| 19 |
-
#error onnxruntime::endian is not implemented in this environment.
|
| 20 |
-
#endif
|
| 21 |
-
};
|
| 22 |
-
|
| 23 |
-
static_assert(
|
| 24 |
-
endian::native == endian::little || endian::native == endian::big,
|
| 25 |
-
"Only little-endian or big-endian native byte orders are supported.");
|
| 26 |
-
|
| 27 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/execution_provider.h
DELETED
|
@@ -1,340 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#ifndef SHARED_PROVIDER
|
| 7 |
-
#include <memory>
|
| 8 |
-
#include <unordered_map>
|
| 9 |
-
#include <unordered_set>
|
| 10 |
-
|
| 11 |
-
#include "core/common/logging/logging.h"
|
| 12 |
-
#include "core/common/status.h"
|
| 13 |
-
#include "core/framework/data_transfer.h"
|
| 14 |
-
#include "core/framework/tensor.h"
|
| 15 |
-
|
| 16 |
-
namespace onnxruntime {
|
| 17 |
-
class GraphViewer;
|
| 18 |
-
struct ComputeCapability;
|
| 19 |
-
class KernelRegistry;
|
| 20 |
-
struct KernelCreateInfo;
|
| 21 |
-
class Node;
|
| 22 |
-
} // namespace onnxruntime
|
| 23 |
-
#else
|
| 24 |
-
#include <memory>
|
| 25 |
-
#endif
|
| 26 |
-
|
| 27 |
-
#include "core/common/basic_types.h"
|
| 28 |
-
#include "core/common/profiler_common.h"
|
| 29 |
-
#include "core/framework/allocatormgr.h"
|
| 30 |
-
#include "core/framework/func_api.h"
|
| 31 |
-
#include "core/framework/provider_options.h"
|
| 32 |
-
#include "core/framework/stream_handles.h"
|
| 33 |
-
#include "core/framework/tuning_context.h"
|
| 34 |
-
|
| 35 |
-
namespace onnxruntime {
|
| 36 |
-
|
| 37 |
-
/**
|
| 38 |
-
Logical device representation.
|
| 39 |
-
*/
|
| 40 |
-
|
| 41 |
-
// if we are export the fused function to dll, the function will still in the same binary as onnxruntime
|
| 42 |
-
// use std function to give execution provider some chance to capture some state.
|
| 43 |
-
using CreateFunctionStateFunc = std::function<int(ComputeContext*, FunctionState*)>;
|
| 44 |
-
using ComputeFunc = std::function<Status(FunctionState, const OrtApi*, OrtKernelContext*)>;
|
| 45 |
-
using DestroyFunctionStateFunc = std::function<void(FunctionState)>;
|
| 46 |
-
|
| 47 |
-
struct NodeComputeInfo {
|
| 48 |
-
CreateFunctionStateFunc create_state_func;
|
| 49 |
-
ComputeFunc compute_func;
|
| 50 |
-
DestroyFunctionStateFunc release_state_func;
|
| 51 |
-
};
|
| 52 |
-
|
| 53 |
-
enum class DataLayout {
|
| 54 |
-
NCHW,
|
| 55 |
-
NHWC,
|
| 56 |
-
NCHWC,
|
| 57 |
-
};
|
| 58 |
-
|
| 59 |
-
class IExecutionProvider {
|
| 60 |
-
protected:
|
| 61 |
-
IExecutionProvider(const std::string& type, bool use_metadef_id_creator = false)
|
| 62 |
-
: type_{type} {
|
| 63 |
-
if (use_metadef_id_creator) {
|
| 64 |
-
metadef_id_generator_ = std::make_unique<ModelMetadefIdGenerator>();
|
| 65 |
-
}
|
| 66 |
-
}
|
| 67 |
-
|
| 68 |
-
public:
|
| 69 |
-
virtual ~IExecutionProvider() = default;
|
| 70 |
-
|
| 71 |
-
/**
|
| 72 |
-
Get all IAllocators for <*this> execution provider.
|
| 73 |
-
*/
|
| 74 |
-
const std::vector<AllocatorPtr>& GetAllocators() const {
|
| 75 |
-
return allocator_list_;
|
| 76 |
-
}
|
| 77 |
-
|
| 78 |
-
/**
|
| 79 |
-
* Get an allocator with specified device id and MemType. Return nullptr if it doesn't exist
|
| 80 |
-
*/
|
| 81 |
-
virtual AllocatorPtr GetAllocator(OrtMemType mem_type) const;
|
| 82 |
-
|
| 83 |
-
/**
|
| 84 |
-
* Returns a data transfer object that implements methods to copy to and
|
| 85 |
-
* from this device.
|
| 86 |
-
* If no copy is required for the successful operation of this provider,
|
| 87 |
-
* return a nullptr.
|
| 88 |
-
*/
|
| 89 |
-
virtual std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const {
|
| 90 |
-
return nullptr;
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
/**
|
| 94 |
-
* Interface for performing kernel lookup within kernel registries.
|
| 95 |
-
* Abstracts away lower-level details about kernel registries and kernel matching.
|
| 96 |
-
*/
|
| 97 |
-
class IKernelLookup {
|
| 98 |
-
public:
|
| 99 |
-
/**
|
| 100 |
-
* Given `node`, try to find a matching kernel for this EP.
|
| 101 |
-
* The return value is non-null if and only if a matching kernel was found.
|
| 102 |
-
*/
|
| 103 |
-
virtual const KernelCreateInfo* LookUpKernel(const Node& node) const = 0;
|
| 104 |
-
|
| 105 |
-
protected:
|
| 106 |
-
~IKernelLookup() = default;
|
| 107 |
-
};
|
| 108 |
-
|
| 109 |
-
/**
|
| 110 |
-
Get execution provider's capability for the specified <graph>.
|
| 111 |
-
Return a bunch of IndexedSubGraphs <*this> execution provider can run if
|
| 112 |
-
the sub-graph contains only one node or can fuse to run if the sub-graph
|
| 113 |
-
contains more than one node. The node indexes contained in sub-graphs may
|
| 114 |
-
have overlap, and it's ONNXRuntime's responsibility to do the partition
|
| 115 |
-
and decide whether a node will be assigned to <*this> execution provider.
|
| 116 |
-
For kernels registered in a kernel registry, `kernel_lookup` must be used
|
| 117 |
-
to find a matching kernel for this EP.
|
| 118 |
-
*/
|
| 119 |
-
virtual std::vector<std::unique_ptr<ComputeCapability>>
|
| 120 |
-
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
| 121 |
-
const IKernelLookup& kernel_lookup) const;
|
| 122 |
-
|
| 123 |
-
/**
|
| 124 |
-
Get kernel registry per execution provider type.
|
| 125 |
-
The KernelRegistry share pointer returned is shared across sessions.
|
| 126 |
-
|
| 127 |
-
NOTE: this approach was taken to achieve the following goals,
|
| 128 |
-
1. The execution provider type based kernel registry should be shared
|
| 129 |
-
across sessions.
|
| 130 |
-
Only one copy of this kind of kernel registry exists in ONNXRuntime
|
| 131 |
-
with multiple sessions/models.
|
| 132 |
-
2. Adding an execution provider into ONNXRuntime does not need to touch ONNXRuntime
|
| 133 |
-
framework/session code.
|
| 134 |
-
3. onnxruntime (framework/session) does not depend on any specific
|
| 135 |
-
execution provider lib.
|
| 136 |
-
*/
|
| 137 |
-
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const { return nullptr; }
|
| 138 |
-
|
| 139 |
-
/**
|
| 140 |
-
Get the device id of current execution provider
|
| 141 |
-
*/
|
| 142 |
-
virtual int GetDeviceId() const { return 0; };
|
| 143 |
-
|
| 144 |
-
/**
|
| 145 |
-
Get execution provider's configuration options.
|
| 146 |
-
*/
|
| 147 |
-
virtual ProviderOptions GetProviderOptions() const { return {}; }
|
| 148 |
-
|
| 149 |
-
/**
|
| 150 |
-
Returns an opaque handle whose exact type varies based on the provider
|
| 151 |
-
and is interpreted accordingly by the corresponding kernel implementation.
|
| 152 |
-
For Direct3D operator kernels, this may return an IUnknown supporting
|
| 153 |
-
QueryInterface to ID3D12GraphicsCommandList1.
|
| 154 |
-
*/
|
| 155 |
-
virtual const void* GetExecutionHandle() const noexcept {
|
| 156 |
-
return nullptr;
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
/**
|
| 160 |
-
@return type of the execution provider; should match that set in the node
|
| 161 |
-
through the SetExecutionProvider API. Example valid return values are:
|
| 162 |
-
kCpuExecutionProvider, kCudaExecutionProvider
|
| 163 |
-
*/
|
| 164 |
-
const std::string& Type() const { return type_; }
|
| 165 |
-
|
| 166 |
-
/**
|
| 167 |
-
Blocks until the device has completed all preceding requested tasks.
|
| 168 |
-
Currently this is primarily used by the IOBinding object to ensure that all
|
| 169 |
-
inputs have been copied to the device before execution begins.
|
| 170 |
-
*/
|
| 171 |
-
virtual common::Status Sync() const { return Status::OK(); }
|
| 172 |
-
|
| 173 |
-
/**
|
| 174 |
-
Called when InferenceSession::Run started
|
| 175 |
-
NOTE that due to async execution in provider, the actual work of previous
|
| 176 |
-
Run may not be finished on device. This function should be regarded as the
|
| 177 |
-
point after which a new Run would start to submit commands from CPU
|
| 178 |
-
*/
|
| 179 |
-
virtual common::Status OnRunStart() { return Status::OK(); }
|
| 180 |
-
|
| 181 |
-
/**
|
| 182 |
-
Called when InferenceSession::Run ended
|
| 183 |
-
NOTE that due to async execution in provider, the actual work of this Run
|
| 184 |
-
may not be finished on device. This function should be regarded as the point
|
| 185 |
-
that all commands of current Run have been submitted by CPU
|
| 186 |
-
*/
|
| 187 |
-
virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
|
| 188 |
-
|
| 189 |
-
/**
|
| 190 |
-
Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
|
| 191 |
-
the provider. Currently only CUDA execution provider supports it.
|
| 192 |
-
*/
|
| 193 |
-
virtual bool IsGraphCaptureEnabled() const { return false; }
|
| 194 |
-
|
| 195 |
-
/**
|
| 196 |
-
Indicate whether the graph has been captured and instantiated. Currently
|
| 197 |
-
only CUDA execution provider supports it.
|
| 198 |
-
*/
|
| 199 |
-
virtual bool IsGraphCaptured() const { return false; }
|
| 200 |
-
|
| 201 |
-
/**
|
| 202 |
-
Run the instantiated graph. Currently only CUDA execution provider supports
|
| 203 |
-
it.
|
| 204 |
-
*/
|
| 205 |
-
virtual common::Status ReplayGraph() { return Status::OK(); }
|
| 206 |
-
|
| 207 |
-
/**
|
| 208 |
-
Called when session creation is complete
|
| 209 |
-
This provides an opportunity for execution providers to optionally synchronize and
|
| 210 |
-
clean up its temporary resources to reduce memory and ensure the first run is fast.
|
| 211 |
-
*/
|
| 212 |
-
virtual common::Status OnSessionInitializationEnd() { return Status::OK(); }
|
| 213 |
-
|
| 214 |
-
void InsertAllocator(AllocatorPtr allocator);
|
| 215 |
-
void ReplaceAllocator(AllocatorPtr allocator);
|
| 216 |
-
|
| 217 |
-
struct FusedNodeAndGraph {
|
| 218 |
-
const std::reference_wrapper<onnxruntime::Node> fused_node;
|
| 219 |
-
// GraphViewer that filters the full graph to the nodes that are covered by 'node'
|
| 220 |
-
const std::reference_wrapper<GraphViewer> filtered_graph;
|
| 221 |
-
};
|
| 222 |
-
|
| 223 |
-
// Fusion approach that is supported
|
| 224 |
-
// !!! The "Function" FusionStyle is deprecated.
|
| 225 |
-
// !!! If your EP is using this fusion style, please migrate it to "FilteredGraphViewer" style.
|
| 226 |
-
enum class FusionStyle {
|
| 227 |
-
// The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance
|
| 228 |
-
// in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body().
|
| 229 |
-
// A GraphProto can be produced from the Node body.
|
| 230 |
-
Function,
|
| 231 |
-
|
| 232 |
-
// The node fusion will create a new Node that defines the inputs and outputs using the IndexedSubGraph
|
| 233 |
-
// that GetCapability returned. The Node will not be onnxruntime::Function based so will have no Body().
|
| 234 |
-
// Instead a GraphViewer that filters the full Graph to the fused Nodes will be created.
|
| 235 |
-
// This is significantly cheaper as it doesn't incur the cost of creating a new Graph instance,
|
| 236 |
-
// and can be supported in a minimal build.
|
| 237 |
-
FilteredGraphViewer
|
| 238 |
-
};
|
| 239 |
-
|
| 240 |
-
virtual FusionStyle GetFusionStyle() const {
|
| 241 |
-
// All the ORT built-in EPs have migrated to FilteredGraphViewer style.
|
| 242 |
-
// For newer EPs, please avoid using the Function style as it is deprecated.
|
| 243 |
-
return FusionStyle::FilteredGraphViewer;
|
| 244 |
-
}
|
| 245 |
-
|
| 246 |
-
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
| 247 |
-
/**
|
| 248 |
-
Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
|
| 249 |
-
return create_state/compute/release_state func for each node.
|
| 250 |
-
@remarks This is now the default interface when execution provider wants to compile nodes
|
| 251 |
-
for both minimal build and complete ort build.
|
| 252 |
-
|
| 253 |
-
Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
|
| 254 |
-
as it is only valid for the duration of the call to Compile.
|
| 255 |
-
*/
|
| 256 |
-
virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
|
| 257 |
-
std::vector<NodeComputeInfo>& node_compute_funcs);
|
| 258 |
-
|
| 259 |
-
#endif
|
| 260 |
-
|
| 261 |
-
void SetLogger(const logging::Logger* logger) {
|
| 262 |
-
logger_ = logger;
|
| 263 |
-
}
|
| 264 |
-
|
| 265 |
-
const logging::Logger* GetLogger() const {
|
| 266 |
-
return logger_;
|
| 267 |
-
}
|
| 268 |
-
|
| 269 |
-
/** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance.
|
| 270 |
-
The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models.
|
| 271 |
-
@param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or nested graph.
|
| 272 |
-
@param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model.
|
| 273 |
-
This is created using the model path if available,
|
| 274 |
-
or the model input names and the output names from all nodes in the main graph.
|
| 275 |
-
@remarks e.g. the TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches
|
| 276 |
-
compiled kernels, so the name must be unique and deterministic across models and sessions.
|
| 277 |
-
NOTE: Ideally this would be a protected method, but to work across the EP bridge it has to be public and
|
| 278 |
-
virtual, and ModelMetadefIdGenerator must be defined in the header as well.
|
| 279 |
-
*/
|
| 280 |
-
virtual int GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const;
|
| 281 |
-
|
| 282 |
-
/**
|
| 283 |
-
Register allocators for EP, potentially re-using existing allocators for a device from allocator_manager.
|
| 284 |
-
If the EP implements this it should generally delay creating any allocators until this is called.
|
| 285 |
-
*/
|
| 286 |
-
virtual void RegisterAllocator(AllocatorManager& /*allocator_manager*/);
|
| 287 |
-
|
| 288 |
-
virtual std::unique_ptr<profiling::EpProfiler> GetProfiler() {
|
| 289 |
-
return {};
|
| 290 |
-
}
|
| 291 |
-
|
| 292 |
-
virtual DataLayout GetPreferredLayout() const {
|
| 293 |
-
// NCHW is the default ONNX standard data layout. So default to it.
|
| 294 |
-
// EPs which prefer a different layout should override to return their preferred layout.
|
| 295 |
-
return DataLayout::NCHW;
|
| 296 |
-
}
|
| 297 |
-
|
| 298 |
-
virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/) const {}
|
| 299 |
-
|
| 300 |
-
/** Does the EP support concurrent calls to InferenceSession::Run to execute the model.
|
| 301 |
-
*/
|
| 302 |
-
virtual bool ConcurrentRunSupported() const { return true; }
|
| 303 |
-
|
| 304 |
-
/**
|
| 305 |
-
* Return the tuning context which holds all TunableOp state.
|
| 306 |
-
*/
|
| 307 |
-
virtual ITuningContext* GetTuningContext() const {
|
| 308 |
-
return nullptr;
|
| 309 |
-
}
|
| 310 |
-
|
| 311 |
-
private:
|
| 312 |
-
const std::string type_;
|
| 313 |
-
|
| 314 |
-
// allocator lookup is done by combining the device id and OrtMemType.
|
| 315 |
-
// there's also an implicit connection to the underlying OrtDevice involved that is dependent on the EP.
|
| 316 |
-
// e.g. for a CPU based EP, 'default' memory is a CPU device, and for a GPU based EP 'default' memory is a
|
| 317 |
-
// GPU device.
|
| 318 |
-
using AllocatorMap = std::unordered_map<int, AllocatorPtr>;
|
| 319 |
-
AllocatorMap allocators_;
|
| 320 |
-
|
| 321 |
-
// It will be set when this object is registered to a session
|
| 322 |
-
const logging::Logger* logger_ = nullptr;
|
| 323 |
-
// convenience list of the allocators so GetAllocatorList doesn't have to build a new vector each time
|
| 324 |
-
// contains the same instances as allocators_
|
| 325 |
-
std::vector<AllocatorPtr> allocator_list_;
|
| 326 |
-
|
| 327 |
-
// helper to generate ids that are unique to model and deterministic, even if the execution provider is shared across
|
| 328 |
-
// multiple sessions.
|
| 329 |
-
class ModelMetadefIdGenerator {
|
| 330 |
-
public:
|
| 331 |
-
int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash);
|
| 332 |
-
|
| 333 |
-
private:
|
| 334 |
-
std::unordered_map<HashValue, HashValue> main_graph_hash_; // map graph instance hash to model contents hash
|
| 335 |
-
std::unordered_map<HashValue, int> model_metadef_id_; // current unique id for model
|
| 336 |
-
};
|
| 337 |
-
|
| 338 |
-
std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
|
| 339 |
-
};
|
| 340 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/float16.h
DELETED
|
@@ -1,159 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
#pragma once
|
| 4 |
-
|
| 5 |
-
#include "endian.h"
|
| 6 |
-
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
| 7 |
-
#include "cuda_bf16.h"
|
| 8 |
-
#endif
|
| 9 |
-
|
| 10 |
-
#if !defined(__CUDACC__) && !defined(__HIPCC__)
|
| 11 |
-
#include "core/common/narrow.h"
|
| 12 |
-
#endif
|
| 13 |
-
|
| 14 |
-
#include "core/common/common.h"
|
| 15 |
-
|
| 16 |
-
namespace onnxruntime {
|
| 17 |
-
|
| 18 |
-
#if defined(__CUDACC__) || defined(__HIPCC__)
|
| 19 |
-
#define ORT_HOST_DEVICE __host__ __device__
|
| 20 |
-
#else
|
| 21 |
-
#define ORT_HOST_DEVICE
|
| 22 |
-
#endif
|
| 23 |
-
|
| 24 |
-
// MLFloat16
|
| 25 |
-
struct MLFloat16 {
|
| 26 |
-
uint16_t val{0};
|
| 27 |
-
|
| 28 |
-
MLFloat16() = default;
|
| 29 |
-
explicit constexpr MLFloat16(uint16_t x) : val(x) {}
|
| 30 |
-
explicit MLFloat16(float f);
|
| 31 |
-
|
| 32 |
-
float ToFloat() const;
|
| 33 |
-
|
| 34 |
-
operator float() const { return ToFloat(); }
|
| 35 |
-
};
|
| 36 |
-
|
| 37 |
-
inline bool operator==(const MLFloat16& left, const MLFloat16& right) { return left.val == right.val; }
|
| 38 |
-
inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { return left.val != right.val; }
|
| 39 |
-
inline bool operator<(const MLFloat16& left, const MLFloat16& right) { return left.val < right.val; }
|
| 40 |
-
|
| 41 |
-
// BFloat16
|
| 42 |
-
struct BFloat16 {
|
| 43 |
-
uint16_t val{0};
|
| 44 |
-
#if defined(__HIP__)
|
| 45 |
-
ORT_HOST_DEVICE BFloat16() = default;
|
| 46 |
-
#else
|
| 47 |
-
BFloat16() = default;
|
| 48 |
-
#endif
|
| 49 |
-
|
| 50 |
-
struct FromBitsT {};
|
| 51 |
-
static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); }
|
| 52 |
-
constexpr ORT_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) : val(bits) {}
|
| 53 |
-
|
| 54 |
-
inline ORT_HOST_DEVICE BFloat16(float v) {
|
| 55 |
-
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
| 56 |
-
val = __bfloat16_as_ushort(__float2bfloat16(v));
|
| 57 |
-
#elif defined(__HIP__)
|
| 58 |
-
// We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment.
|
| 59 |
-
if (v != v) { // isnan
|
| 60 |
-
val = UINT16_C(0x7FC0);
|
| 61 |
-
} else {
|
| 62 |
-
union {
|
| 63 |
-
uint32_t U32;
|
| 64 |
-
float F32;
|
| 65 |
-
};
|
| 66 |
-
|
| 67 |
-
F32 = v;
|
| 68 |
-
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
|
| 69 |
-
val = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
|
| 70 |
-
}
|
| 71 |
-
#else
|
| 72 |
-
if constexpr(endian::native == endian::little) {
|
| 73 |
-
std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
|
| 74 |
-
}
|
| 75 |
-
else {
|
| 76 |
-
std::memcpy(&val, &v, sizeof(uint16_t));
|
| 77 |
-
}
|
| 78 |
-
#endif
|
| 79 |
-
}
|
| 80 |
-
|
| 81 |
-
inline ORT_HOST_DEVICE float ToFloat() const {
|
| 82 |
-
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
| 83 |
-
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&val));
|
| 84 |
-
#elif defined(__HIP__)
|
| 85 |
-
// We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment.
|
| 86 |
-
float result = 0;
|
| 87 |
-
uint32_t tmp = val;
|
| 88 |
-
tmp <<= 16;
|
| 89 |
-
float* tempRes = reinterpret_cast<float*>(&tmp);
|
| 90 |
-
result = *tempRes;
|
| 91 |
-
return result;
|
| 92 |
-
#else
|
| 93 |
-
float result;
|
| 94 |
-
char* const first = reinterpret_cast<char*>(&result);
|
| 95 |
-
char* const second = first + sizeof(uint16_t);
|
| 96 |
-
if constexpr(endian::native == endian::little) {
|
| 97 |
-
std::memset(first, 0, sizeof(uint16_t));
|
| 98 |
-
std::memcpy(second, &val, sizeof(uint16_t));
|
| 99 |
-
}
|
| 100 |
-
else {
|
| 101 |
-
std::memcpy(first, &val, sizeof(uint16_t));
|
| 102 |
-
std::memset(second, 0, sizeof(uint16_t));
|
| 103 |
-
}
|
| 104 |
-
return result;
|
| 105 |
-
#endif
|
| 106 |
-
}
|
| 107 |
-
|
| 108 |
-
inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
|
| 109 |
-
|
| 110 |
-
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
| 111 |
-
ORT_HOST_DEVICE BFloat16(const __nv_bfloat16& value) { val = *reinterpret_cast<const unsigned short*>(&value); }
|
| 112 |
-
explicit ORT_HOST_DEVICE operator __nv_bfloat16() const { return *reinterpret_cast<const __nv_bfloat16*>(&val); }
|
| 113 |
-
#endif
|
| 114 |
-
};
|
| 115 |
-
|
| 116 |
-
inline ORT_HOST_DEVICE bool operator==(const BFloat16& left, const BFloat16& right) { return left.val == right.val; }
|
| 117 |
-
inline ORT_HOST_DEVICE bool operator!=(const BFloat16& left, const BFloat16& right) { return left.val != right.val; }
|
| 118 |
-
inline ORT_HOST_DEVICE bool operator<(const BFloat16& left, const BFloat16& right) { return left.val < right.val; }
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
// User defined suffixes to make it easier to declare
|
| 122 |
-
// initializers with MLFloat16 and BFloat16 from unsigned short
|
| 123 |
-
// E.g 10_f16 or 10_b16
|
| 124 |
-
#if !defined(__CUDACC__) && !defined(__HIPCC__)
|
| 125 |
-
inline MLFloat16 operator"" _f16(unsigned long long int v) {
|
| 126 |
-
return MLFloat16(narrow<uint16_t>(v));
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
inline MLFloat16 operator"" _fp16(long double v) {
|
| 130 |
-
return MLFloat16(static_cast<float>(v));
|
| 131 |
-
}
|
| 132 |
-
|
| 133 |
-
inline BFloat16 operator"" _b16(unsigned long long int v) {
|
| 134 |
-
return BFloat16(narrow<uint16_t>(v), BFloat16::FromBits());
|
| 135 |
-
}
|
| 136 |
-
|
| 137 |
-
inline BFloat16 operator"" _bfp16(long double v) {
|
| 138 |
-
return BFloat16(static_cast<float>(v));
|
| 139 |
-
}
|
| 140 |
-
|
| 141 |
-
#endif
|
| 142 |
-
|
| 143 |
-
inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) {
|
| 144 |
-
auto src = blf;
|
| 145 |
-
auto d = flt;
|
| 146 |
-
for (; size != 0; ++src, ++d, --size) {
|
| 147 |
-
*d = src->ToFloat();
|
| 148 |
-
}
|
| 149 |
-
}
|
| 150 |
-
|
| 151 |
-
inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) {
|
| 152 |
-
auto src = flt;
|
| 153 |
-
auto d = blf;
|
| 154 |
-
for (; size != 0; ++src, ++d, --size) {
|
| 155 |
-
new (d) BFloat16(*src);
|
| 156 |
-
}
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/framework_common.h
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <string>
|
| 7 |
-
#include <unordered_map>
|
| 8 |
-
#include <vector>
|
| 9 |
-
#include "run_options.h"
|
| 10 |
-
|
| 11 |
-
namespace onnxruntime { // forward declarations
|
| 12 |
-
class Model;
|
| 13 |
-
class GraphTransformer;
|
| 14 |
-
class NodeArg;
|
| 15 |
-
} // namespace onnxruntime
|
| 16 |
-
|
| 17 |
-
namespace onnxruntime {
|
| 18 |
-
using InputDefList = std::vector<const onnxruntime::NodeArg*>;
|
| 19 |
-
using OutputDefList = std::vector<const onnxruntime::NodeArg*>;
|
| 20 |
-
|
| 21 |
-
using NameMLValMap = std::unordered_map<std::string, OrtValue>;
|
| 22 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/func_api.h
DELETED
|
@@ -1,27 +0,0 @@
|
|
| 1 |
-
#pragma once
|
| 2 |
-
#include "core/common/status.h"
|
| 3 |
-
using onnxruntime::common::Status; // TODO: Needed by WinML, but shouldn't be put into the global namespace like this
|
| 4 |
-
|
| 5 |
-
namespace onnxruntime {
|
| 6 |
-
|
| 7 |
-
// AllocateFunc(void* handle, size_t alignment, size_t size)
|
| 8 |
-
using AllocateFunc = void* (*)(void*, size_t, size_t);
|
| 9 |
-
using DestroyFunc = void (*)(void*, void*);
|
| 10 |
-
using AllocatorHandle = void*;
|
| 11 |
-
|
| 12 |
-
typedef struct {
|
| 13 |
-
//right now we only include allocation for host memory
|
| 14 |
-
AllocateFunc allocate_func;
|
| 15 |
-
DestroyFunc release_func;
|
| 16 |
-
AllocatorHandle allocator_handle;
|
| 17 |
-
const char* node_name;
|
| 18 |
-
} ComputeContext;
|
| 19 |
-
|
| 20 |
-
using FunctionState = void*;
|
| 21 |
-
// take the ComputeContext, and create a function state.
|
| 22 |
-
using CreateFunctionStateC = int (*)(ComputeContext*, FunctionState*);
|
| 23 |
-
// pass in the function state and input/output tensors, perform compute and return status
|
| 24 |
-
using ComputeFuncC = common::Status (*)(FunctionState, const OrtApi*, OrtKernelContext*);
|
| 25 |
-
// release the function state.
|
| 26 |
-
using DestroyFunctionStateC = void (*)(FunctionState);
|
| 27 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/kernel_def_builder.h
DELETED
|
@@ -1,353 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <limits.h>
|
| 7 |
-
#include <memory>
|
| 8 |
-
#include <optional>
|
| 9 |
-
#include <string>
|
| 10 |
-
#include <unordered_map>
|
| 11 |
-
#include <vector>
|
| 12 |
-
|
| 13 |
-
#include "core/common/common.h"
|
| 14 |
-
#include "core/framework/allocator.h"
|
| 15 |
-
#include "core/framework/data_types.h"
|
| 16 |
-
#include "core/graph/basic_types.h"
|
| 17 |
-
|
| 18 |
-
namespace onnxruntime {
|
| 19 |
-
class KernelDefBuilder;
|
| 20 |
-
|
| 21 |
-
typedef std::map<size_t, OrtMemType> MemTypeMap;
|
| 22 |
-
|
| 23 |
-
class KernelDef {
|
| 24 |
-
private:
|
| 25 |
-
// note that input/output might be on CPU implicitly when the node is from CPU execution provider
|
| 26 |
-
constexpr static inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
|
| 27 |
-
return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
|
| 28 |
-
}
|
| 29 |
-
|
| 30 |
-
public:
|
| 31 |
-
explicit KernelDef() = default;
|
| 32 |
-
|
| 33 |
-
const std::string& OpName() const {
|
| 34 |
-
return op_name_;
|
| 35 |
-
}
|
| 36 |
-
|
| 37 |
-
const std::string& Domain() const {
|
| 38 |
-
return op_domain_;
|
| 39 |
-
}
|
| 40 |
-
|
| 41 |
-
void SinceVersion(/*out*/ int* start, /*out*/ int* end) const {
|
| 42 |
-
*start = op_since_version_start_;
|
| 43 |
-
*end = op_since_version_end_;
|
| 44 |
-
}
|
| 45 |
-
|
| 46 |
-
#ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
|
| 47 |
-
const std::pair<int, int> SinceVersion() const {
|
| 48 |
-
return std::pair<int, int>(op_since_version_start_, op_since_version_end_);
|
| 49 |
-
}
|
| 50 |
-
#endif
|
| 51 |
-
|
| 52 |
-
onnxruntime::ProviderType Provider() const {
|
| 53 |
-
return provider_type_;
|
| 54 |
-
}
|
| 55 |
-
|
| 56 |
-
// type constraints with types supported in this build
|
| 57 |
-
const std::unordered_map<std::string, std::vector<MLDataType>>& TypeConstraints() const {
|
| 58 |
-
return type_constraints_;
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
const std::vector<std::pair<int, int>>& MayInplace() const {
|
| 62 |
-
return inplace_map_;
|
| 63 |
-
}
|
| 64 |
-
|
| 65 |
-
const std::vector<std::pair<int, int>>& Alias() const {
|
| 66 |
-
return alias_map_;
|
| 67 |
-
}
|
| 68 |
-
|
| 69 |
-
const std::optional<std::pair<int, int>>& VariadicAlias() const {
|
| 70 |
-
return variadic_alias_offsets_;
|
| 71 |
-
}
|
| 72 |
-
|
| 73 |
-
OrtMemType InputMemoryType(size_t input_index) const {
|
| 74 |
-
auto it = input_memory_type_args_.find(input_index);
|
| 75 |
-
if (it == input_memory_type_args_.end())
|
| 76 |
-
return default_inputs_mem_type_;
|
| 77 |
-
return it->second;
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
-
bool IsInputOnCpu(size_t input_index) const { return MemTypeOnCpuExplicitly(InputMemoryType(input_index)); }
|
| 81 |
-
|
| 82 |
-
bool IsOutputOnCpu(size_t output_index) const { return MemTypeOnCpuExplicitly(OutputMemoryType(output_index)); }
|
| 83 |
-
|
| 84 |
-
bool AllocateInputsContiguously() const { return allocate_inputs_contiguously_; }
|
| 85 |
-
|
| 86 |
-
bool HasExternalOutputs() const { return external_outputs_; }
|
| 87 |
-
|
| 88 |
-
#ifdef ENABLE_STRIDED_TENSORS
|
| 89 |
-
const std::vector<int>& MayStridedInput() const { return may_strided_inputs_; }
|
| 90 |
-
const std::vector<std::pair<int, int>>& MayStridedOutput() const { return may_strided_output_map_; }
|
| 91 |
-
#endif
|
| 92 |
-
|
| 93 |
-
OrtMemType OutputMemoryType(size_t output_index) const {
|
| 94 |
-
auto it = output_memory_type_args_.find(output_index);
|
| 95 |
-
if (it == output_memory_type_args_.end())
|
| 96 |
-
return default_outputs_mem_type_;
|
| 97 |
-
return it->second;
|
| 98 |
-
}
|
| 99 |
-
|
| 100 |
-
int ExecQueueId() const {
|
| 101 |
-
return exec_queue_id_;
|
| 102 |
-
}
|
| 103 |
-
|
| 104 |
-
bool IsConflict(const KernelDef& other) const;
|
| 105 |
-
|
| 106 |
-
private:
|
| 107 |
-
friend class KernelDefBuilder;
|
| 108 |
-
|
| 109 |
-
// The operator name supported by <*this> kernel.
|
| 110 |
-
std::string op_name_;
|
| 111 |
-
|
| 112 |
-
// The operator since_version range supported by <*this> kernel.
|
| 113 |
-
// A kernel could support an operator definition between <op_since_version_start>
|
| 114 |
-
// and <op_since_version_end> (inclusive).
|
| 115 |
-
int op_since_version_start_ = 1;
|
| 116 |
-
int op_since_version_end_ = INT_MAX;
|
| 117 |
-
|
| 118 |
-
// The operator domain supported by <*this> kernel.
|
| 119 |
-
// Default to 'onnxruntime::kOnnxDomain'.
|
| 120 |
-
// Please note the behavior of std::string("") and std::string() are different
|
| 121 |
-
std::string op_domain_;
|
| 122 |
-
|
| 123 |
-
// The type of the execution provider.
|
| 124 |
-
std::string provider_type_;
|
| 125 |
-
|
| 126 |
-
// The data types that are supported in this build (enabled) for inputs/outputs.
|
| 127 |
-
// Key is input/output/type constraint name defined in op schema, Value is supported types.
|
| 128 |
-
std::unordered_map<std::string, std::vector<MLDataType>> type_constraints_;
|
| 129 |
-
|
| 130 |
-
// An element <i, j> means that output j reuses the memory of input i.
|
| 131 |
-
std::vector<std::pair<int, int>> inplace_map_;
|
| 132 |
-
|
| 133 |
-
// An element <i, j> means that output j is an alias of input i.
|
| 134 |
-
std::vector<std::pair<int, int>> alias_map_;
|
| 135 |
-
|
| 136 |
-
// This variable stores <input_offset, output_offset> for the variadic alias mapping
|
| 137 |
-
// output 'i + output_offset' is an alias of input 'i + input_offset' for all i >= 0
|
| 138 |
-
std::optional<std::pair<int, int>> variadic_alias_offsets_;
|
| 139 |
-
|
| 140 |
-
// Require input tensors to be allocated contiguously.
|
| 141 |
-
bool allocate_inputs_contiguously_ = false;
|
| 142 |
-
|
| 143 |
-
// Whether the outputs are from external.
|
| 144 |
-
bool external_outputs_ = false;
|
| 145 |
-
|
| 146 |
-
#ifdef ENABLE_STRIDED_TENSORS
|
| 147 |
-
// An element i means i-th input can be strided tensor.
|
| 148 |
-
std::vector<int> may_strided_inputs_;
|
| 149 |
-
|
| 150 |
-
// An element <i, j> means j-th output can be a strided tensor, which shares the data from i-th input.
|
| 151 |
-
std::vector<std::pair<int, int>> may_strided_output_map_;
|
| 152 |
-
#endif
|
| 153 |
-
|
| 154 |
-
// The memory types of inputs/outputs of this kernel
|
| 155 |
-
MemTypeMap input_memory_type_args_;
|
| 156 |
-
MemTypeMap output_memory_type_args_;
|
| 157 |
-
|
| 158 |
-
// execution command queue id, 0 for default queue in execution provider
|
| 159 |
-
int exec_queue_id_ = 0;
|
| 160 |
-
// Default memory type for all inputs
|
| 161 |
-
OrtMemType default_inputs_mem_type_{OrtMemTypeDefault};
|
| 162 |
-
// Default memory type for all outputs
|
| 163 |
-
OrtMemType default_outputs_mem_type_{OrtMemTypeDefault};
|
| 164 |
-
};
|
| 165 |
-
|
| 166 |
-
class KernelDefBuilder {
|
| 167 |
-
public:
|
| 168 |
-
static std::unique_ptr<KernelDefBuilder> Create() { return std::make_unique<KernelDefBuilder>(); }
|
| 169 |
-
|
| 170 |
-
explicit KernelDefBuilder()
|
| 171 |
-
: kernel_def_(std::make_unique<KernelDef>()) {}
|
| 172 |
-
|
| 173 |
-
KernelDefBuilder& SetName(const std::string& op_name);
|
| 174 |
-
KernelDefBuilder& SetName(const char* op_name);
|
| 175 |
-
|
| 176 |
-
KernelDefBuilder& SetDomain(const std::string& domain);
|
| 177 |
-
KernelDefBuilder& SetDomain(const char* domain);
|
| 178 |
-
|
| 179 |
-
/**
|
| 180 |
-
This kernel supports operator definition since <since_version> (to latest).
|
| 181 |
-
*/
|
| 182 |
-
KernelDefBuilder& SinceVersion(int since_version) {
|
| 183 |
-
kernel_def_->op_since_version_start_ = since_version;
|
| 184 |
-
return *this;
|
| 185 |
-
}
|
| 186 |
-
|
| 187 |
-
/**
|
| 188 |
-
The start and end version should be set accordingly per version range for
|
| 189 |
-
each domain registered in OpSchemaRegistry::DomainToVersionRange in
|
| 190 |
-
\onnxruntime\onnxruntime\core\graph\op.h as below.
|
| 191 |
-
Key: domain. Value: <lowest version, highest version> pair.
|
| 192 |
-
std::unordered_map<std::string, std::pair<int, int>> map_;
|
| 193 |
-
*/
|
| 194 |
-
KernelDefBuilder& SinceVersion(int since_version_start, int since_version_end) {
|
| 195 |
-
kernel_def_->op_since_version_start_ = since_version_start;
|
| 196 |
-
kernel_def_->op_since_version_end_ = since_version_end;
|
| 197 |
-
return *this;
|
| 198 |
-
}
|
| 199 |
-
|
| 200 |
-
/**
|
| 201 |
-
The execution provider type of the kernel.
|
| 202 |
-
*/
|
| 203 |
-
KernelDefBuilder& Provider(ProviderType provider_type);
|
| 204 |
-
KernelDefBuilder& Provider(const char* provider_type);
|
| 205 |
-
|
| 206 |
-
/**
|
| 207 |
-
Specify the set of types that this kernel supports. A further restriction
|
| 208 |
-
of the set of types specified in the op schema.
|
| 209 |
-
|
| 210 |
-
@param arg_name The arg name can be either op formal parameter name, say "X", or type
|
| 211 |
-
argument name specified in op schema, say "T".
|
| 212 |
-
@param types The types that are supported in this build.
|
| 213 |
-
*/
|
| 214 |
-
KernelDefBuilder& TypeConstraint(const std::string& arg_name, std::vector<MLDataType> types);
|
| 215 |
-
KernelDefBuilder& TypeConstraint(const char* arg_name, std::vector<MLDataType> types);
|
| 216 |
-
|
| 217 |
-
/**
|
| 218 |
-
Like TypeConstraint but supports just a single type.
|
| 219 |
-
*/
|
| 220 |
-
KernelDefBuilder& TypeConstraint(const std::string& arg_name, MLDataType type);
|
| 221 |
-
KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType type);
|
| 222 |
-
|
| 223 |
-
/**
|
| 224 |
-
Inplace mapping from inputs to outputs allowed.
|
| 225 |
-
It means that the upper-layer runtime could do memory in-place optimization
|
| 226 |
-
as it will not impact the correctness of this kernel.
|
| 227 |
-
*/
|
| 228 |
-
KernelDefBuilder& MayInplace(const std::vector<std::pair<int, int>>& inplaces);
|
| 229 |
-
KernelDefBuilder& MayInplace(int input_index, int output_index);
|
| 230 |
-
|
| 231 |
-
/**
|
| 232 |
-
Alias mapping from inputs to outputs. Differs from Inplace in that the
|
| 233 |
-
content of the tensor is not changed. This is to take care of operators
|
| 234 |
-
such as Identity and Reshape.
|
| 235 |
-
*/
|
| 236 |
-
KernelDefBuilder& Alias(const std::vector<std::pair<int, int>>& aliases);
|
| 237 |
-
KernelDefBuilder& Alias(int input_index, int output_index);
|
| 238 |
-
|
| 239 |
-
/**
|
| 240 |
-
Apply variadic number of alias mapping from inputs to outputs.
|
| 241 |
-
This is effectively applying Alias(i + input_offset, i + output_offset) for i >= 0
|
| 242 |
-
*/
|
| 243 |
-
KernelDefBuilder& VariadicAlias(int input_offset, int output_offset);
|
| 244 |
-
|
| 245 |
-
/**
|
| 246 |
-
Specify that this kernel requires input tensors to be allocated
|
| 247 |
-
contiguously. This allows kernels to execute as a single large
|
| 248 |
-
computation, rather than numerous smaller computations.
|
| 249 |
-
*/
|
| 250 |
-
KernelDefBuilder& AllocateInputsContiguously() {
|
| 251 |
-
kernel_def_->allocate_inputs_contiguously_ = true;
|
| 252 |
-
return *this;
|
| 253 |
-
}
|
| 254 |
-
|
| 255 |
-
/**
|
| 256 |
-
Specify that this kernel's output buffers are passed from external,
|
| 257 |
-
i.e. not created or managed by ORT's memory allocator.
|
| 258 |
-
*/
|
| 259 |
-
KernelDefBuilder& ExternalOutputs() {
|
| 260 |
-
kernel_def_->external_outputs_ = true;
|
| 261 |
-
return *this;
|
| 262 |
-
}
|
| 263 |
-
|
| 264 |
-
#ifdef ENABLE_STRIDED_TENSORS
|
| 265 |
-
/**
|
| 266 |
-
Specify that the input_index-th input can be strided tensor.
|
| 267 |
-
*/
|
| 268 |
-
KernelDefBuilder& MayStridedInput(int input_index);
|
| 269 |
-
|
| 270 |
-
/**
|
| 271 |
-
Specify that the output_index-th output can be strided tensor, and share the data
|
| 272 |
-
from input_index-th input.
|
| 273 |
-
*/
|
| 274 |
-
KernelDefBuilder& MayStridedOutput(int input_index, int output_index);
|
| 275 |
-
#endif
|
| 276 |
-
|
| 277 |
-
/**
|
| 278 |
-
Specify that this kernel requires an input arg
|
| 279 |
-
in certain memory type (instead of the default, device memory).
|
| 280 |
-
*/
|
| 281 |
-
KernelDefBuilder& InputMemoryType(OrtMemType type, int input_index) {
|
| 282 |
-
kernel_def_->input_memory_type_args_.insert(std::make_pair(input_index, type));
|
| 283 |
-
return *this;
|
| 284 |
-
}
|
| 285 |
-
|
| 286 |
-
/**
|
| 287 |
-
Specify that this kernel requires input arguments
|
| 288 |
-
in certain memory type (instead of the default, device memory).
|
| 289 |
-
*/
|
| 290 |
-
KernelDefBuilder& InputMemoryType(OrtMemType type, const std::vector<int>& input_indexes) {
|
| 291 |
-
for (auto input_index : input_indexes) {
|
| 292 |
-
kernel_def_->input_memory_type_args_.insert(std::make_pair(input_index, type));
|
| 293 |
-
}
|
| 294 |
-
return *this;
|
| 295 |
-
}
|
| 296 |
-
|
| 297 |
-
/**
|
| 298 |
-
Specify that this kernel provides an output arg
|
| 299 |
-
in certain memory type (instead of the default, device memory).
|
| 300 |
-
*/
|
| 301 |
-
KernelDefBuilder& OutputMemoryType(OrtMemType type, int output_index) {
|
| 302 |
-
kernel_def_->output_memory_type_args_.insert(std::make_pair(output_index, type));
|
| 303 |
-
return *this;
|
| 304 |
-
}
|
| 305 |
-
|
| 306 |
-
/**
|
| 307 |
-
Specify that this kernel provides output arguments
|
| 308 |
-
in certain memory type (instead of the default, device memory).
|
| 309 |
-
*/
|
| 310 |
-
KernelDefBuilder& OutputMemoryType(OrtMemType type, const std::vector<int>& output_indexes) {
|
| 311 |
-
for (auto output_index : output_indexes) {
|
| 312 |
-
kernel_def_->output_memory_type_args_.insert(std::make_pair(output_index, type));
|
| 313 |
-
}
|
| 314 |
-
return *this;
|
| 315 |
-
}
|
| 316 |
-
|
| 317 |
-
/**
|
| 318 |
-
Specify which execution queue in the provider this kernel runs on
|
| 319 |
-
*/
|
| 320 |
-
KernelDefBuilder& ExecQueueId(int queue_id) {
|
| 321 |
-
kernel_def_->exec_queue_id_ = queue_id;
|
| 322 |
-
return *this;
|
| 323 |
-
}
|
| 324 |
-
|
| 325 |
-
/**
|
| 326 |
-
Specify the default inputs memory type; if not specified, it is DefaultMemory
|
| 327 |
-
*/
|
| 328 |
-
KernelDefBuilder& SetDefaultInputsMemoryType(OrtMemType mem_type) {
|
| 329 |
-
kernel_def_->default_inputs_mem_type_ = mem_type;
|
| 330 |
-
return *this;
|
| 331 |
-
}
|
| 332 |
-
|
| 333 |
-
/**
|
| 334 |
-
Specify the default outputs memory type; if not specified, it is DefaultMemory
|
| 335 |
-
*/
|
| 336 |
-
KernelDefBuilder& SetDefaultOutputMemoryType(OrtMemType mem_type) {
|
| 337 |
-
kernel_def_->default_outputs_mem_type_ = mem_type;
|
| 338 |
-
return *this;
|
| 339 |
-
}
|
| 340 |
-
|
| 341 |
-
/**
|
| 342 |
-
Return the kernel definition, passing ownership of the KernelDef to the caller
|
| 343 |
-
*/
|
| 344 |
-
std::unique_ptr<KernelDef> Build() {
|
| 345 |
-
return std::move(kernel_def_);
|
| 346 |
-
}
|
| 347 |
-
|
| 348 |
-
private:
|
| 349 |
-
// we own the KernelDef until Build() is called.
|
| 350 |
-
std::unique_ptr<KernelDef> kernel_def_;
|
| 351 |
-
};
|
| 352 |
-
|
| 353 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/kernel_registry.h
DELETED
|
@@ -1,91 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <string_view>
|
| 7 |
-
|
| 8 |
-
#include "core/framework/op_kernel.h"
|
| 9 |
-
|
| 10 |
-
namespace onnxruntime {
|
| 11 |
-
|
| 12 |
-
using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
|
| 13 |
-
using KernelDefHashes = std::vector<std::pair<std::string, HashValue>>;
|
| 14 |
-
|
| 15 |
-
class IKernelTypeStrResolver;
|
| 16 |
-
|
| 17 |
-
/**
|
| 18 |
-
* Each provider has a KernelRegistry. Often, the KernelRegistry only belongs to that specific provider.
|
| 19 |
-
*/
|
| 20 |
-
class KernelRegistry {
|
| 21 |
-
public:
|
| 22 |
-
KernelRegistry() = default;
|
| 23 |
-
|
| 24 |
-
// Register a kernel with kernel definition and function to create the kernel.
|
| 25 |
-
Status Register(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator);
|
| 26 |
-
|
| 27 |
-
Status Register(KernelCreateInfo&& create_info);
|
| 28 |
-
|
| 29 |
-
// TODO(edgchen1) for TryFindKernel(), consider using `out` != nullptr as indicator of whether kernel was found and
|
| 30 |
-
// Status as an indication of failure
|
| 31 |
-
|
| 32 |
-
// Check if an execution provider can create kernel for a node and return the kernel if so
|
| 33 |
-
Status TryFindKernel(const Node& node, ProviderType exec_provider,
|
| 34 |
-
const IKernelTypeStrResolver& kernel_type_str_resolver,
|
| 35 |
-
const KernelCreateInfo** out) const;
|
| 36 |
-
|
| 37 |
-
static bool HasImplementationOf(const KernelRegistry& r, const Node& node,
|
| 38 |
-
ProviderType exec_provider,
|
| 39 |
-
const IKernelTypeStrResolver& kernel_type_str_resolver) {
|
| 40 |
-
const KernelCreateInfo* info;
|
| 41 |
-
Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info);
|
| 42 |
-
return st.IsOK();
|
| 43 |
-
}
|
| 44 |
-
|
| 45 |
-
#if !defined(ORT_MINIMAL_BUILD)
|
| 46 |
-
// Find KernelCreateInfo in instant mode
|
| 47 |
-
Status TryFindKernel(const std::string& op_name, const std::string& domain, const int& version,
|
| 48 |
-
const std::unordered_map<std::string, MLDataType>& type_constraints,
|
| 49 |
-
ProviderType exec_provider, const KernelCreateInfo** out) const;
|
| 50 |
-
#endif // !defined(ORT_MINIMAL_BUILD)
|
| 51 |
-
|
| 52 |
-
bool IsEmpty() const { return kernel_creator_fn_map_.empty(); }
|
| 53 |
-
|
| 54 |
-
#ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA
|
| 55 |
-
// This is used by the opkernel doc generator to enlist all registered operators for a given provider's opkernel
|
| 56 |
-
const KernelCreateMap& GetKernelCreateMap() const {
|
| 57 |
-
return kernel_creator_fn_map_;
|
| 58 |
-
}
|
| 59 |
-
#endif
|
| 60 |
-
|
| 61 |
-
private:
|
| 62 |
-
// Check whether the types of inputs/outputs of the given node match the extra
|
| 63 |
-
// type-constraints of the given kernel. This serves two purposes: first, to
|
| 64 |
-
// select the right kernel implementation based on the types of the arguments
|
| 65 |
-
// when we have multiple kernels, e.g., Clip<float> and Clip<int>; second, to
|
| 66 |
-
// accommodate (and check) mapping of ONNX (specification) type to the onnxruntime
|
| 67 |
-
// implementation type (e.g., if we want to implement ONNX's float16 as a regular
|
| 68 |
-
// float in onnxruntime). (The second, however, requires a globally uniform mapping.)
|
| 69 |
-
//
|
| 70 |
-
// Note that this is not intended for type-checking the node against the ONNX
|
| 71 |
-
// type specification of the corresponding op, which is done before this check.
|
| 72 |
-
static bool VerifyKernelDef(const Node& node,
|
| 73 |
-
const KernelDef& kernel_def,
|
| 74 |
-
const IKernelTypeStrResolver& kernel_type_str_resolver,
|
| 75 |
-
std::string& error_str);
|
| 76 |
-
|
| 77 |
-
static std::string GetMapKey(std::string_view op_name, std::string_view domain, std::string_view provider) {
|
| 78 |
-
std::string key(op_name);
|
| 79 |
-
// use the kOnnxDomainAlias of 'ai.onnx' instead of kOnnxDomain's empty string
|
| 80 |
-
key.append(1, ' ').append(domain.empty() ? kOnnxDomainAlias : domain).append(1, ' ').append(provider);
|
| 81 |
-
return key;
|
| 82 |
-
}
|
| 83 |
-
|
| 84 |
-
static std::string GetMapKey(const KernelDef& kernel_def) {
|
| 85 |
-
return GetMapKey(kernel_def.OpName(), kernel_def.Domain(), kernel_def.Provider());
|
| 86 |
-
}
|
| 87 |
-
// Kernel create function map from op name to kernel creation info.
|
| 88 |
-
// key is opname+domain_name+provider_name
|
| 89 |
-
KernelCreateMap kernel_creator_fn_map_;
|
| 90 |
-
};
|
| 91 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/op_kernel.h
DELETED
|
@@ -1,387 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include "boost/mp11.hpp"
|
| 7 |
-
|
| 8 |
-
// It is safe to include the below header even if SHARED_PROVIDER macro is enabled
|
| 9 |
-
// as it doesn't include any pb headers.
|
| 10 |
-
#include "core/framework/prepacked_weights_container.h"
|
| 11 |
-
|
| 12 |
-
#ifndef SHARED_PROVIDER
|
| 13 |
-
#include <functional>
|
| 14 |
-
#include "core/common/exceptions.h"
|
| 15 |
-
#include "core/common/logging/logging.h"
|
| 16 |
-
#include "core/common/status.h"
|
| 17 |
-
#include "core/framework/execution_provider.h"
|
| 18 |
-
#include "core/framework/kernel_def_builder.h"
|
| 19 |
-
#include "core/framework/ort_value.h"
|
| 20 |
-
#include "core/framework/op_kernel_info.h"
|
| 21 |
-
#include "core/framework/op_node_proto_helper.h"
|
| 22 |
-
#include "core/framework/tensor.h"
|
| 23 |
-
#include "core/framework/sparse_tensor.h"
|
| 24 |
-
#include "core/graph/constants.h"
|
| 25 |
-
#include "core/graph/graph_viewer.h"
|
| 26 |
-
#if !defined(ORT_MINIMAL_BUILD)
|
| 27 |
-
#include "onnx/defs/schema.h"
|
| 28 |
-
#else
|
| 29 |
-
#include "onnx/defs/data_type_utils.h"
|
| 30 |
-
#endif
|
| 31 |
-
#include "onnx/onnx_pb.h"
|
| 32 |
-
#include "onnx/onnx-operators_pb.h"
|
| 33 |
-
#include "core/common/gsl.h"
|
| 34 |
-
namespace onnxruntime {
|
| 35 |
-
class OpKernelContext;
|
| 36 |
-
}
|
| 37 |
-
#endif
|
| 38 |
-
|
| 39 |
-
namespace onnxruntime {
|
| 40 |
-
|
| 41 |
-
std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info);
|
| 42 |
-
|
| 43 |
-
class OpKernel {
|
| 44 |
-
public:
|
| 45 |
-
using DoneCallback = std::function<void()>;
|
| 46 |
-
|
| 47 |
-
explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(CopyOpKernelInfo(info)) {}
|
| 48 |
-
virtual ~OpKernel() = default;
|
| 49 |
-
|
| 50 |
-
const onnxruntime::Node& Node() const;
|
| 51 |
-
const onnxruntime::KernelDef& KernelDef() const;
|
| 52 |
-
|
| 53 |
-
[[nodiscard]] virtual Status Compute(_Inout_ OpKernelContext* context) const = 0;
|
| 54 |
-
|
| 55 |
-
[[nodiscard]] virtual bool IsAsync() const {
|
| 56 |
-
// by default all kernels are sync version.
|
| 57 |
-
return false;
|
| 58 |
-
}
|
| 59 |
-
|
| 60 |
-
[[nodiscard]] virtual Status ComputeAsync(_Inout_ OpKernelContext*, DoneCallback) const {
|
| 61 |
-
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
|
| 62 |
-
}
|
| 63 |
-
|
| 64 |
-
// Override this function to PrePack initialized constant tensor to the format as needed.
|
| 65 |
-
// For example, MatMul kernel can pack the input B if it is constant like code below.
|
| 66 |
-
// Status PrePack(const Tensor& tensor, int input_idx, /*out*/ bool& is_packed,
|
| 67 |
-
// /*out*/ PrePackedWeights* prepacked_weight_for_caching,
|
| 68 |
-
// AllocatorPtr alloc) override {
|
| 69 |
-
// is_packed = false;
|
| 70 |
-
// if (input_idx == 1) {
|
| 71 |
-
// is_packed = true;
|
| 72 |
-
// this.Pack(tensor, this.buffer_, alloc);
|
| 73 |
-
// if (prepacked_weight_for_caching) {
|
| 74 |
-
// // LOGIC TO CACHE `this.buffer_` SINCE THE KERNEL DOESN'T OWN THE PACKED WEIGHT
|
| 75 |
-
// }
|
| 76 |
-
// }
|
| 77 |
-
// return Status::OK();
|
| 78 |
-
// }
|
| 79 |
-
// Please refer to MatMulIntegerToFloatBase for a complete example
|
| 80 |
-
// @param tensor: The initialized constant tensor
|
| 81 |
-
// @param input_idx: The input index of the tensor in this kernel
|
| 82 |
-
// @param alloc: The kernel's PrePack() method MUST use this allocator for allocating the pre-packed
|
| 83 |
-
// weights' buffers. The alloc that the PrePack() method will receive will be either
|
| 84 |
-
// the allocator tied to the session if the kernel owns the pre-packed buffer or an
|
| 85 |
-
// allocator shared between sessions if the pre-packed buffer is to be shared across sessions
|
| 86 |
-
// (i.e.) the kernel does not own the buffer.
|
| 87 |
-
// @param is_packed: Set it to true if the kernel packed the tensor or to false
|
| 88 |
-
// The kernel is responsible for keeping the packed data and related metadata if is_packed is true,
|
| 89 |
-
// and the original initialized constant tensor will be released and not accessible anymore in
|
| 90 |
-
// the Compute function.
|
| 91 |
-
// @param prepacked_weights: A PrePackedWeights instance will be provided to the kernel IF the pre-packed weights
|
| 92 |
-
// are meant to be stored in a shared container.
|
| 93 |
-
|
| 94 |
-
virtual Status
|
| 95 |
-
PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/,
|
| 96 |
-
/*out*/ bool& is_packed, /*out*/ PrePackedWeights* /*prepacked_weights*/) {
|
| 97 |
-
is_packed = false;
|
| 98 |
-
return Status::OK();
|
| 99 |
-
}
|
| 100 |
-
|
| 101 |
-
// Override this function to use provided pre-packed weight.
|
| 102 |
-
// Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
|
| 103 |
-
// int input_idx,
|
| 104 |
-
// /*out*/ bool& used_shared_buffers) {
|
| 105 |
-
// used_shared_buffers = true;
|
| 106 |
-
// this.buffer_ = std::move(prepacked_buffers[0]);
|
| 107 |
-
// return Status::OK();
|
| 108 |
-
// }
|
| 109 |
-
// Please refer to MatMulIntegerToFloatBase for a complete example
|
| 110 |
-
// @param prepacked_buffers: The pre-packed buffers to be used by this kernel for the provided input index
|
| 111 |
-
// (Sometimes a single constant initializer may have multiple pre-packed buffers associated
|
| 112 |
-
// with it and it is up to the kernel developer to store it in any order of their choice in PrePack()
|
| 113 |
-
// and must use the same order for retrieval in UseSharedPrePackedBuffers().
|
| 114 |
-
// @param input_idx: The input index of the tensor in this kernel
|
| 115 |
-
// @param used_shared_buffers: Boolean flag set by the kernel implementation indicating
|
| 116 |
-
// that the provided weight has been used by the kernel.
|
| 117 |
-
virtual Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& /*prepacked_buffers*/,
|
| 118 |
-
int /*input_idx*/,
|
| 119 |
-
/*out*/ bool& used_shared_buffers) {
|
| 120 |
-
used_shared_buffers = false;
|
| 121 |
-
return Status::OK();
|
| 122 |
-
}
|
| 123 |
-
|
| 124 |
-
const OrtMemoryInfo& Allocator(OrtMemType mem_type) const;
|
| 125 |
-
const OpKernelInfo& Info() const {
|
| 126 |
-
return *op_kernel_info_;
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
private:
|
| 130 |
-
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel);
|
| 131 |
-
std::unique_ptr<OpKernelInfo> op_kernel_info_;
|
| 132 |
-
};
|
| 133 |
-
class FuncManager;
|
| 134 |
-
using KernelCreateFn = std::function<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>;
|
| 135 |
-
using KernelCreatePtrFn = std::add_pointer<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>::type;
|
| 136 |
-
|
| 137 |
-
struct KernelCreateInfo {
|
| 138 |
-
std::unique_ptr<KernelDef> kernel_def; // Owned and stored in the global kernel registry.
|
| 139 |
-
KernelCreateFn kernel_create_func;
|
| 140 |
-
Status status;
|
| 141 |
-
|
| 142 |
-
KernelCreateInfo(std::unique_ptr<KernelDef> definition,
|
| 143 |
-
KernelCreateFn create_func)
|
| 144 |
-
: kernel_def(std::move(definition)),
|
| 145 |
-
kernel_create_func(create_func) {}
|
| 146 |
-
|
| 147 |
-
KernelCreateInfo(KernelCreateInfo&& other) noexcept
|
| 148 |
-
: kernel_def(std::move(other.kernel_def)),
|
| 149 |
-
kernel_create_func(std::move(other.kernel_create_func)) {}
|
| 150 |
-
|
| 151 |
-
KernelCreateInfo() = default;
|
| 152 |
-
};
|
| 153 |
-
|
| 154 |
-
// Forward declarations for the non-specialized BuildKernelCreateInfo method.
|
| 155 |
-
template <typename T>
|
| 156 |
-
KernelCreateInfo BuildKernelCreateInfo();
|
| 157 |
-
|
| 158 |
-
namespace ml {
|
| 159 |
-
template <typename T>
|
| 160 |
-
KernelCreateInfo BuildKernelCreateInfo();
|
| 161 |
-
} // namespace ml
|
| 162 |
-
|
| 163 |
-
namespace contrib {
|
| 164 |
-
template <typename T>
|
| 165 |
-
KernelCreateInfo BuildKernelCreateInfo();
|
| 166 |
-
} // namespace contrib
|
| 167 |
-
|
| 168 |
-
namespace contrib {
|
| 169 |
-
namespace cuda {
|
| 170 |
-
template <typename T>
|
| 171 |
-
KernelCreateInfo BuildKernelCreateInfo();
|
| 172 |
-
} // namespace cuda
|
| 173 |
-
} // namespace contrib
|
| 174 |
-
|
| 175 |
-
namespace contrib {
|
| 176 |
-
namespace rocm {
|
| 177 |
-
template <typename T>
|
| 178 |
-
KernelCreateInfo BuildKernelCreateInfo();
|
| 179 |
-
} // namespace rocm
|
| 180 |
-
} // namespace contrib
|
| 181 |
-
|
| 182 |
-
namespace contrib {
|
| 183 |
-
namespace snpe {
|
| 184 |
-
template <typename T>
|
| 185 |
-
KernelCreateInfo BuildKernelCreateInfo();
|
| 186 |
-
} // namespace snpe
|
| 187 |
-
} // namespace contrib
|
| 188 |
-
|
| 189 |
-
using BuildKernelCreateInfoFn = KernelCreateInfo (*)();
|
| 190 |
-
|
| 191 |
-
// Naming convention for operator kernel classes
|
| 192 |
-
#define ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) \
|
| 193 |
-
provider##_##name##_##domain##_ver##ver
|
| 194 |
-
|
| 195 |
-
#define ONNX_CPU_OPERATOR_KERNEL(name, ver, builder, ...) \
|
| 196 |
-
ONNX_OPERATOR_KERNEL_EX(name, kOnnxDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
|
| 197 |
-
|
| 198 |
-
#define ONNX_CPU_OPERATOR_ML_KERNEL(name, ver, builder, ...) \
|
| 199 |
-
ONNX_OPERATOR_KERNEL_EX(name, kMLDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
|
| 200 |
-
|
| 201 |
-
#define ONNX_CPU_OPERATOR_MS_KERNEL(name, ver, builder, ...) \
|
| 202 |
-
ONNX_OPERATOR_KERNEL_EX(name, kMSDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)
|
| 203 |
-
|
| 204 |
-
#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \
|
| 205 |
-
class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \
|
| 206 |
-
template <> \
|
| 207 |
-
KernelCreateInfo \
|
| 208 |
-
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name)>() { \
|
| 209 |
-
return KernelCreateInfo( \
|
| 210 |
-
builder.SetName(#name) \
|
| 211 |
-
.SetDomain(domain) \
|
| 212 |
-
.SinceVersion(ver) \
|
| 213 |
-
.Provider(provider) \
|
| 214 |
-
.Build(), \
|
| 215 |
-
static_cast<KernelCreatePtrFn>( \
|
| 216 |
-
[](FuncManager&, \
|
| 217 |
-
const OpKernelInfo& info, \
|
| 218 |
-
std::unique_ptr<OpKernel>& out) -> Status { \
|
| 219 |
-
out = std::make_unique<__VA_ARGS__>(info); \
|
| 220 |
-
return Status::OK(); \
|
| 221 |
-
})); \
|
| 222 |
-
}
|
| 223 |
-
|
| 224 |
-
#define ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name) \
|
| 225 |
-
provider##_##name##_##domain##_ver##startver##_##endver
|
| 226 |
-
|
| 227 |
-
#define ONNX_CPU_OPERATOR_VERSIONED_KERNEL(name, startver, endver, builder, ...) \
|
| 228 |
-
ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kOnnxDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)
|
| 229 |
-
|
| 230 |
-
#define ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL(name, startver, endver, builder, ...) \
|
| 231 |
-
ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kMLDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)
|
| 232 |
-
|
| 233 |
-
#define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \
|
| 234 |
-
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \
|
| 235 |
-
template <> \
|
| 236 |
-
KernelCreateInfo \
|
| 237 |
-
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name)>() { \
|
| 238 |
-
return KernelCreateInfo( \
|
| 239 |
-
builder.SetName(#name) \
|
| 240 |
-
.SetDomain(domain) \
|
| 241 |
-
.SinceVersion(startver, endver) \
|
| 242 |
-
.Provider(provider) \
|
| 243 |
-
.Build(), \
|
| 244 |
-
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
|
| 245 |
-
}
|
| 246 |
-
|
| 247 |
-
#define ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name) \
|
| 248 |
-
provider##_##name##_##domain##_ver##ver##_##type
|
| 249 |
-
|
| 250 |
-
#define ONNX_CPU_OPERATOR_TYPED_KERNEL(name, ver, type, builder, ...) \
|
| 251 |
-
ONNX_OPERATOR_TYPED_KERNEL_EX(name, kOnnxDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
|
| 252 |
-
|
| 253 |
-
#define ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(name, ver, type, builder, ...) \
|
| 254 |
-
ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMLDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
|
| 255 |
-
|
| 256 |
-
#define ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(name, ver, type, builder, ...) \
|
| 257 |
-
ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)
|
| 258 |
-
|
| 259 |
-
#define ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) \
|
| 260 |
-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \
|
| 261 |
-
template <> \
|
| 262 |
-
KernelCreateInfo \
|
| 263 |
-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name)>() { \
|
| 264 |
-
return KernelCreateInfo( \
|
| 265 |
-
builder.SetName(#name) \
|
| 266 |
-
.SetDomain(domain) \
|
| 267 |
-
.SinceVersion(ver) \
|
| 268 |
-
.Provider(provider) \
|
| 269 |
-
.Build(), \
|
| 270 |
-
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
|
| 271 |
-
}
|
| 272 |
-
|
| 273 |
-
#define ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name) \
|
| 274 |
-
provider##_##name##_##domain##_ver##ver##_##type1##_##type2
|
| 275 |
-
|
| 276 |
-
#define ONNX_OPERATOR_TWO_TYPED_KERNEL_EX(name, domain, ver, type1, type2, provider, builder, ...) \
|
| 277 |
-
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name); \
|
| 278 |
-
template <> \
|
| 279 |
-
KernelCreateInfo \
|
| 280 |
-
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name)>() { \
|
| 281 |
-
return KernelCreateInfo( \
|
| 282 |
-
builder.SetName(#name) \
|
| 283 |
-
.SetDomain(domain) \
|
| 284 |
-
.SinceVersion(ver) \
|
| 285 |
-
.Provider(provider) \
|
| 286 |
-
.Build(), \
|
| 287 |
-
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
|
| 288 |
-
}
|
| 289 |
-
|
| 290 |
-
#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \
|
| 291 |
-
provider##_##name##_##domain##_ver##startver##_##endver##_##type
|
| 292 |
-
|
| 293 |
-
#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(name, startver, endver, type, builder, ...) \
|
| 294 |
-
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kOnnxDomain, startver, endver, type, kCpuExecutionProvider, builder, \
|
| 295 |
-
__VA_ARGS__)
|
| 296 |
-
|
| 297 |
-
#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL(name, startver, endver, type, builder, ...) \
|
| 298 |
-
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMLDomain, startver, endver, type, kCpuExecutionProvider, builder, \
|
| 299 |
-
__VA_ARGS__)
|
| 300 |
-
|
| 301 |
-
#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_MS_KERNEL(name, startver, endver, type, builder, ...) \
|
| 302 |
-
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMSDomain, startver, endver, type, kCpuExecutionProvider, builder, \
|
| 303 |
-
__VA_ARGS__)
|
| 304 |
-
|
| 305 |
-
#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \
|
| 306 |
-
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \
|
| 307 |
-
template <> \
|
| 308 |
-
KernelCreateInfo \
|
| 309 |
-
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
|
| 310 |
-
type, name)>() { \
|
| 311 |
-
return KernelCreateInfo( \
|
| 312 |
-
builder.SetName(#name) \
|
| 313 |
-
.SetDomain(domain) \
|
| 314 |
-
.SinceVersion(startver, endver) \
|
| 315 |
-
.Provider(provider) \
|
| 316 |
-
.Build(), \
|
| 317 |
-
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
|
| 318 |
-
}
|
| 319 |
-
|
| 320 |
-
#define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name) \
|
| 321 |
-
provider##_##name##_##domain##_ver##startver##_##endver##_##type1##_##type2
|
| 322 |
-
|
| 323 |
-
#define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX(name, domain, startver, endver, type1, type2, \
|
| 324 |
-
provider, builder, ...) \
|
| 325 |
-
class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name); \
|
| 326 |
-
template <> \
|
| 327 |
-
KernelCreateInfo \
|
| 328 |
-
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
|
| 329 |
-
type1, type2, name)>() { \
|
| 330 |
-
return KernelCreateInfo( \
|
| 331 |
-
builder.SetName(#name) \
|
| 332 |
-
.SetDomain(domain) \
|
| 333 |
-
.SinceVersion(startver, endver) \
|
| 334 |
-
.Provider(provider) \
|
| 335 |
-
.Build(), \
|
| 336 |
-
static_cast<KernelCreatePtrFn>([](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<__VA_ARGS__>(info); return Status::OK(); })); \
|
| 337 |
-
}
|
| 338 |
-
|
| 339 |
-
template <typename... Types>
|
| 340 |
-
struct BuildKernelDefConstraintsImpl {
|
| 341 |
-
std::vector<MLDataType> operator()() const {
|
| 342 |
-
return {DataTypeImpl::GetTensorType<Types>()...};
|
| 343 |
-
}
|
| 344 |
-
};
|
| 345 |
-
|
| 346 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 347 |
-
template <typename... Types>
|
| 348 |
-
struct BuildKernelDefSparseConstraintsImpl {
|
| 349 |
-
std::vector<MLDataType> operator()() const {
|
| 350 |
-
return {DataTypeImpl::GetSparseTensorType<Types>()...};
|
| 351 |
-
}
|
| 352 |
-
};
|
| 353 |
-
#endif
|
| 354 |
-
|
| 355 |
-
// Use within macro definitions to create a custom vector of constraints.
|
| 356 |
-
// Example: #define REG_KERNEL(OP, VERSION, KERNEL_CLASS, Type, ...)
|
| 357 |
-
// .TypeConstraint("T", BuildKernelDefConstraints<Type, __VA_ARGS_>())
|
| 358 |
-
template <typename... Types>
|
| 359 |
-
inline std::vector<MLDataType> BuildKernelDefConstraints() {
|
| 360 |
-
return BuildKernelDefConstraintsImpl<Types...>{}();
|
| 361 |
-
}
|
| 362 |
-
|
| 363 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 364 |
-
template <typename... Types>
|
| 365 |
-
inline std::vector<MLDataType> BuildKernelDefSparseConstraints() {
|
| 366 |
-
return BuildKernelDefSparseConstraintsImpl<Types...>{}();
|
| 367 |
-
}
|
| 368 |
-
#endif
|
| 369 |
-
|
| 370 |
-
// version of BuildKernelDefConstraints() which takes a type list
|
| 371 |
-
template <typename L>
|
| 372 |
-
inline std::vector<MLDataType> BuildKernelDefConstraintsFromTypeList() {
|
| 373 |
-
return boost::mp11::mp_apply<BuildKernelDefConstraintsImpl, L>{}();
|
| 374 |
-
}
|
| 375 |
-
|
| 376 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 377 |
-
template <typename L>
|
| 378 |
-
inline std::vector<MLDataType> BuildKernelDefSparseConstraintsFromTypeList() {
|
| 379 |
-
return boost::mp11::mp_apply<BuildKernelDefSparseConstraintsImpl, L>{}();
|
| 380 |
-
}
|
| 381 |
-
#endif
|
| 382 |
-
|
| 383 |
-
} // namespace onnxruntime
|
| 384 |
-
|
| 385 |
-
#ifndef SHARED_PROVIDER
|
| 386 |
-
#include "core/framework/op_kernel_context.h"
|
| 387 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/op_kernel_context.h
DELETED
|
@@ -1,237 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
namespace onnxruntime {
|
| 5 |
-
class IExecutionFrame;
|
| 6 |
-
class Stream;
|
| 7 |
-
namespace concurrency {
|
| 8 |
-
class ThreadPool;
|
| 9 |
-
}
|
| 10 |
-
|
| 11 |
-
class OpKernelContext {
|
| 12 |
-
public:
|
| 13 |
-
using ArgMap = std::unordered_map<std::string, size_t>;
|
| 14 |
-
|
| 15 |
-
OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel,
|
| 16 |
-
_In_ Stream* stream,
|
| 17 |
-
_In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger);
|
| 18 |
-
|
| 19 |
-
virtual ~OpKernelContext() = default;
|
| 20 |
-
|
| 21 |
-
/**
|
| 22 |
-
Return the number of inputs for a variadic argument.
|
| 23 |
-
@param arg_num The operator argument number.
|
| 24 |
-
@returns Number of inputs the argument has.
|
| 25 |
-
*/
|
| 26 |
-
virtual int NumVariadicInputs(size_t arg_num) const;
|
| 27 |
-
|
| 28 |
-
virtual MLDataType InputType(int index) const;
|
| 29 |
-
virtual MLDataType OutputType(int index) const;
|
| 30 |
-
|
| 31 |
-
const OrtValue* GetInputOrtValue(int index) const {
|
| 32 |
-
return GetInputMLValue(index);
|
| 33 |
-
}
|
| 34 |
-
|
| 35 |
-
template <typename T>
|
| 36 |
-
const T* Input(int index) const {
|
| 37 |
-
const OrtValue* p_ml_value = GetInputMLValue(index);
|
| 38 |
-
ORT_TRY {
|
| 39 |
-
return p_ml_value ? &(p_ml_value->Get<T>()) : nullptr;
|
| 40 |
-
}
|
| 41 |
-
ORT_CATCH(const std::exception& /*e*/) {
|
| 42 |
-
ORT_THROW("Missing Input: " + kernel_->Node().InputDefs()[index]->Name());
|
| 43 |
-
}
|
| 44 |
-
}
|
| 45 |
-
|
| 46 |
-
// Fetch a required input, enforcing that it is present.
|
| 47 |
-
template <typename T>
|
| 48 |
-
const T& RequiredInput(int index) const {
|
| 49 |
-
const T* input_ptr = Input<T>(index);
|
| 50 |
-
ORT_ENFORCE(input_ptr, "Required input at index ", index, " is not present.");
|
| 51 |
-
return *input_ptr;
|
| 52 |
-
}
|
| 53 |
-
|
| 54 |
-
// Fetch output (non-tensor) with specified index.
|
| 55 |
-
template <typename T>
|
| 56 |
-
T* Output(int index) {
|
| 57 |
-
if (index < 0 || index >= OutputCount())
|
| 58 |
-
return nullptr;
|
| 59 |
-
|
| 60 |
-
OrtValue* p_ml_value = GetOrCreateOutputMLValue(index);
|
| 61 |
-
return p_ml_value ? p_ml_value->GetMutable<T>() : nullptr;
|
| 62 |
-
}
|
| 63 |
-
|
| 64 |
-
// In the case that memory allocation has not been done for an output tensor,
|
| 65 |
-
// The memory allocation will be done on-the-fly with given tensor shape.
|
| 66 |
-
// Return nullptr if the output is an unused optional output.
|
| 67 |
-
Tensor* Output(int index, const TensorShape& shape);
|
| 68 |
-
Tensor* Output(int index, const std::vector<int64_t>& shape);
|
| 69 |
-
Tensor* Output(int index, const std::initializer_list<int64_t>& shape);
|
| 70 |
-
|
| 71 |
-
// Fetch a required tensor output, enforcing that it is present.
|
| 72 |
-
Tensor& RequiredOutput(int index, const TensorShape& shape) {
|
| 73 |
-
Tensor* output_ptr = Output(index, shape);
|
| 74 |
-
ORT_ENFORCE(output_ptr, "Required output at index ", index, " is not present.");
|
| 75 |
-
return *output_ptr;
|
| 76 |
-
}
|
| 77 |
-
|
| 78 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 79 |
-
// Fetch a sparse-tensor output corresponding to the specified index.
|
| 80 |
-
// shape must specify the shape of the underlying dense-tensor.
|
| 81 |
-
// Memory allocation for the output may happen when this method is invoked,
|
| 82 |
-
// unless static optimization pre-allocates it.
|
| 83 |
-
SparseTensor* OutputSparse(int index, const TensorShape& shape);
|
| 84 |
-
#endif
|
| 85 |
-
|
| 86 |
-
#if !defined(DISABLE_OPTIONAL_TYPE)
|
| 87 |
-
// Use this API to output a "None" of a specific type (e.g. Tensor) at specified index
|
| 88 |
-
template <typename T>
|
| 89 |
-
void OutputOptionalWithoutData(int index) {
|
| 90 |
-
auto* output_ort_value = GetOutputMLValue(index);
|
| 91 |
-
|
| 92 |
-
auto type = DataTypeImpl::GetType<T>();
|
| 93 |
-
|
| 94 |
-
output_ort_value->Init(nullptr, // This OrtValue is "None" and has no data
|
| 95 |
-
type,
|
| 96 |
-
type->GetDeleteFunc());
|
| 97 |
-
}
|
| 98 |
-
#endif
|
| 99 |
-
|
| 100 |
-
// Retrieve indexed shape obtained from memory planning before actual
|
| 101 |
-
// computation. If the indexed shape cannot be inferred, this function returns
|
| 102 |
-
// false.
|
| 103 |
-
virtual bool TryGetInferredInputShape(int index, TensorShape& shape) const;
|
| 104 |
-
|
| 105 |
-
// Retrieve indexed shape obtained from memory planning before actual
|
| 106 |
-
// computation. If the indexed shape cannot be inferred, this function returns
|
| 107 |
-
// false.
|
| 108 |
-
virtual bool TryGetInferredOutputShape(int index, TensorShape& shape) const;
|
| 109 |
-
|
| 110 |
-
const logging::Logger& Logger() const {
|
| 111 |
-
return *logger_;
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
// always >= 0
|
| 115 |
-
virtual int InputCount() const {
|
| 116 |
-
return static_cast<int>(kernel_->Node().InputDefs().size());
|
| 117 |
-
}
|
| 118 |
-
|
| 119 |
-
// always >= 0
|
| 120 |
-
virtual int ImplicitInputCount() const {
|
| 121 |
-
return static_cast<int>(kernel_->Node().ImplicitInputDefs().size());
|
| 122 |
-
}
|
| 123 |
-
|
| 124 |
-
// always >= 0
|
| 125 |
-
virtual int OutputCount() const {
|
| 126 |
-
return static_cast<int>(kernel_->Node().OutputDefs().size());
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
/**
|
| 130 |
-
Return an allocator on device 0, with memtype of OrtMemTypeDefault.
|
| 131 |
-
@remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
|
| 132 |
-
*/
|
| 133 |
-
[[nodiscard]] virtual Status GetTempSpaceAllocator(AllocatorPtr* output) const;
|
| 134 |
-
|
| 135 |
-
/**
|
| 136 |
-
Return the allocator associated with the CPU EP with memtype of OrtMemTypeDefault.
|
| 137 |
-
@remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc.
|
| 138 |
-
*/
|
| 139 |
-
[[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const;
|
| 140 |
-
|
| 141 |
-
/**
|
| 142 |
-
Return the device id that current kernel runs on.
|
| 143 |
-
*/
|
| 144 |
-
virtual int GetDeviceId() const {
|
| 145 |
-
return kernel_->Info().GetExecutionProvider()->GetDeviceId();
|
| 146 |
-
}
|
| 147 |
-
|
| 148 |
-
/**
|
| 149 |
-
Return the compute stream associated with the EP that the kernel is partitioned to.
|
| 150 |
-
For EPs that do not have a compute stream (e.g. CPU EP), a nullptr is returned.
|
| 151 |
-
*/
|
| 152 |
-
[[nodiscard]] virtual Stream* GetComputeStream() const {
|
| 153 |
-
return stream_;
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
/**
|
| 157 |
-
Returns the opset domain of the underlying kernel
|
| 158 |
-
**/
|
| 159 |
-
const std::string& GetOpDomain() const;
|
| 160 |
-
|
| 161 |
-
/**
|
| 162 |
-
Returns the optype of the underlying kernel
|
| 163 |
-
**/
|
| 164 |
-
const std::string& GetOpType() const;
|
| 165 |
-
|
| 166 |
-
/**
|
| 167 |
-
Returns the node name of the underlying kernel
|
| 168 |
-
**/
|
| 169 |
-
const std::string& GetNodeName() const;
|
| 170 |
-
|
| 171 |
-
/**
|
| 172 |
-
Returns the intra-op threadpool, if available.
|
| 173 |
-
*/
|
| 174 |
-
_Ret_maybenull_ onnxruntime::concurrency::ThreadPool* GetOperatorThreadPool() const { return threadpool_; }
|
| 175 |
-
|
| 176 |
-
/**
|
| 177 |
-
Returns whether deterministic computation is preferred.
|
| 178 |
-
*/
|
| 179 |
-
virtual bool GetUseDeterministicCompute() const {
|
| 180 |
-
return true;
|
| 181 |
-
}
|
| 182 |
-
|
| 183 |
-
protected:
|
| 184 |
-
OpKernelContext(concurrency::ThreadPool* threadpool, const logging::Logger& logger, Stream* stream);
|
| 185 |
-
|
| 186 |
-
onnxruntime::NodeIndex GetNodeIndex() const;
|
| 187 |
-
|
| 188 |
-
virtual const OrtValue* GetInputMLValue(int index) const;
|
| 189 |
-
const OrtValue* GetImplicitInputMLValue(int index) const;
|
| 190 |
-
OrtValue* GetOutputMLValue(int index);
|
| 191 |
-
|
| 192 |
-
#ifdef ENABLE_ATEN
|
| 193 |
-
Status SetOutputMLValue(int index, const OrtValue& ort_value);
|
| 194 |
-
#endif
|
| 195 |
-
|
| 196 |
-
// Creates the OrtValue* based on the shape, if it does not exist
|
| 197 |
-
virtual OrtValue* OutputMLValue(int index, const TensorShape& shape);
|
| 198 |
-
|
| 199 |
-
virtual OrtValue* GetOrCreateOutputMLValue(int index);
|
| 200 |
-
|
| 201 |
-
private:
|
| 202 |
-
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OpKernelContext);
|
| 203 |
-
int GetInputArgIndex(int index) const;
|
| 204 |
-
int GetImplicitInputArgIndex(int index) const;
|
| 205 |
-
int GetOutputArgIndex(int index) const;
|
| 206 |
-
|
| 207 |
-
IExecutionFrame* const execution_frame_{};
|
| 208 |
-
const OpKernel* const kernel_{};
|
| 209 |
-
concurrency::ThreadPool* const threadpool_{};
|
| 210 |
-
const logging::Logger* const logger_{};
|
| 211 |
-
|
| 212 |
-
// The argument starting index in ExecutionFrame.
|
| 213 |
-
int node_input_start_index_{-1};
|
| 214 |
-
int node_implicit_input_start_index_{-1};
|
| 215 |
-
int node_output_start_index_{-1};
|
| 216 |
-
|
| 217 |
-
Stream* stream_;
|
| 218 |
-
};
|
| 219 |
-
|
| 220 |
-
// Fetching output tensor without shape is not allowed except when it already exists
|
| 221 |
-
template <>
|
| 222 |
-
inline Tensor* OpKernelContext::Output<Tensor>(int index) {
|
| 223 |
-
OrtValue* p_ml_value = GetOutputMLValue(index);
|
| 224 |
-
ORT_ENFORCE(p_ml_value, "Please fetch output tensor with specified shape.");
|
| 225 |
-
return p_ml_value->GetMutable<Tensor>();
|
| 226 |
-
}
|
| 227 |
-
|
| 228 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 229 |
-
template <>
|
| 230 |
-
inline SparseTensor* OpKernelContext::Output<SparseTensor>(int index) {
|
| 231 |
-
OrtValue* p_ml_value = GetOutputMLValue(index);
|
| 232 |
-
ORT_ENFORCE(p_ml_value, "Please fetch output sparse tensor with specified shape.");
|
| 233 |
-
return p_ml_value->GetMutable<SparseTensor>();
|
| 234 |
-
}
|
| 235 |
-
#endif
|
| 236 |
-
|
| 237 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/op_kernel_info.h
DELETED
|
@@ -1,63 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include "core/framework/execution_provider.h"
|
| 7 |
-
#include "core/framework/kernel_def_builder.h"
|
| 8 |
-
#include "core/framework/ort_value.h"
|
| 9 |
-
#include "core/framework/op_node_proto_helper.h"
|
| 10 |
-
#include "core/graph/graph_viewer.h"
|
| 11 |
-
#include "core/common/gsl.h"
|
| 12 |
-
|
| 13 |
-
namespace onnxruntime {
|
| 14 |
-
|
| 15 |
-
class OrtValueNameIdxMap;
|
| 16 |
-
class FuncManager;
|
| 17 |
-
class DataTransferManager;
|
| 18 |
-
struct AllocPlanPerValue;
|
| 19 |
-
|
| 20 |
-
// A very light-weight class, which works as an aggregated
|
| 21 |
-
// view of all data needed for constructing a Kernel instance.
|
| 22 |
-
// NOTE: it does not own/hold any objects.
|
| 23 |
-
class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
|
| 24 |
-
public:
|
| 25 |
-
explicit OpKernelInfo(const onnxruntime::Node& node,
|
| 26 |
-
const KernelDef& kernel_def,
|
| 27 |
-
const IExecutionProvider& execution_provider,
|
| 28 |
-
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
|
| 29 |
-
const OrtValueNameIdxMap& mlvalue_name_idx_map,
|
| 30 |
-
const DataTransferManager& data_transfer_mgr);
|
| 31 |
-
|
| 32 |
-
OpKernelInfo(const OpKernelInfo& other);
|
| 33 |
-
|
| 34 |
-
const OrtMemoryInfo& GetMemoryInfo(OrtMemType mem_type) const;
|
| 35 |
-
|
| 36 |
-
AllocatorPtr GetAllocator(OrtMemType mem_type) const;
|
| 37 |
-
|
| 38 |
-
const KernelDef& GetKernelDef() const;
|
| 39 |
-
|
| 40 |
-
const IExecutionProvider* GetExecutionProvider() const noexcept;
|
| 41 |
-
|
| 42 |
-
const DataTransferManager& GetDataTransferManager() const noexcept;
|
| 43 |
-
|
| 44 |
-
const onnxruntime::Node& node() const noexcept;
|
| 45 |
-
|
| 46 |
-
bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const;
|
| 47 |
-
|
| 48 |
-
private:
|
| 49 |
-
ORT_DISALLOW_MOVE(OpKernelInfo);
|
| 50 |
-
ORT_DISALLOW_ASSIGNMENT(OpKernelInfo);
|
| 51 |
-
|
| 52 |
-
const onnxruntime::Node& node_;
|
| 53 |
-
const KernelDef& kernel_def_;
|
| 54 |
-
// For non cpu/cuda case, this pointer should be set so that function kernel
|
| 55 |
-
// will delegate kernel compute call to <execution_provider> compute call.
|
| 56 |
-
gsl::not_null<const ::onnxruntime::IExecutionProvider*> execution_provider_;
|
| 57 |
-
const std::unordered_map<int, OrtValue>& constant_initialized_tensors_;
|
| 58 |
-
const OrtValueNameIdxMap& ort_value_name_idx_map_;
|
| 59 |
-
const DataTransferManager& data_transfer_mgr_;
|
| 60 |
-
ProtoHelperNodeContext proto_helper_context_;
|
| 61 |
-
};
|
| 62 |
-
|
| 63 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/op_node_proto_helper.h
DELETED
|
@@ -1,167 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#ifndef SHARED_PROVIDER
|
| 7 |
-
#include "core/common/status.h"
|
| 8 |
-
#include "core/framework/tensor_shape.h"
|
| 9 |
-
#include "core/graph/graph_viewer.h"
|
| 10 |
-
#include "core/common/gsl.h"
|
| 11 |
-
#endif
|
| 12 |
-
|
| 13 |
-
#ifdef __has_attribute
|
| 14 |
-
#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x)
|
| 15 |
-
#else
|
| 16 |
-
#define ORT_HAVE_ATTRIBUTE(x) 0
|
| 17 |
-
#endif
|
| 18 |
-
|
| 19 |
-
#if ORT_HAVE_ATTRIBUTE(nodiscard)
|
| 20 |
-
#define MUST_USE_RESULT [[nodiscard]]
|
| 21 |
-
#elif defined(__clang__) && ORT_HAVE_ATTRIBUTE(warn_unused_result)
|
| 22 |
-
#define MUST_USE_RESULT __attribute__((warn_unused_result))
|
| 23 |
-
#else
|
| 24 |
-
#define MUST_USE_RESULT
|
| 25 |
-
#endif
|
| 26 |
-
|
| 27 |
-
class IMLOpKernel;
|
| 28 |
-
|
| 29 |
-
namespace onnxruntime {
|
| 30 |
-
|
| 31 |
-
/**
|
| 32 |
-
A set of wrappers with common signatures for use with both OpKernelInfo
|
| 33 |
-
(as its base class) and InferenceContext. Used by ABI kernels for both
|
| 34 |
-
shape / type inference and kernel construction
|
| 35 |
-
*/
|
| 36 |
-
template <class Impl_t>
|
| 37 |
-
class OpNodeProtoHelper {
|
| 38 |
-
public:
|
| 39 |
-
explicit OpNodeProtoHelper(const Impl_t* impl) : impl_(impl) {}
|
| 40 |
-
|
| 41 |
-
/**
|
| 42 |
-
Get a single attribute
|
| 43 |
-
Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema
|
| 44 |
-
*/
|
| 45 |
-
template <typename T>
|
| 46 |
-
MUST_USE_RESULT Status GetAttr(const std::string& name, T* value) const;
|
| 47 |
-
|
| 48 |
-
/**
|
| 49 |
-
Get a single attribute
|
| 50 |
-
Call this function only when a default value for an optional attribute isn't specified in the op schema
|
| 51 |
-
*/
|
| 52 |
-
template <typename T>
|
| 53 |
-
T GetAttrOrDefault(const std::string& name, const T& default_value) const {
|
| 54 |
-
T tmp;
|
| 55 |
-
return GetAttr<T>(name, &tmp).IsOK() ? tmp : default_value;
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
/**
|
| 59 |
-
Get a single attribute
|
| 60 |
-
Call this function only when a default value for an optional attribute isn't specified in the op schema
|
| 61 |
-
*/
|
| 62 |
-
template <typename T>
|
| 63 |
-
void GetAttrOrDefault(const std::string& name, T* value, const T& default_value) const {
|
| 64 |
-
if (!GetAttr<T>(name, value).IsOK())
|
| 65 |
-
*value = default_value;
|
| 66 |
-
}
|
| 67 |
-
|
| 68 |
-
/**
|
| 69 |
-
Get repeated attributes
|
| 70 |
-
Call this function only when a default value for an optional attribute isn't specified in the op schema
|
| 71 |
-
*/
|
| 72 |
-
template <typename T>
|
| 73 |
-
MUST_USE_RESULT std::vector<T> GetAttrsOrDefault(const std::string& name, const std::vector<T>& default_value = std::vector<T>{}) const {
|
| 74 |
-
std::vector<T> tmp;
|
| 75 |
-
return GetAttrs<T>(name, tmp).IsOK() ? tmp : default_value;
|
| 76 |
-
}
|
| 77 |
-
|
| 78 |
-
/// <summary>
|
| 79 |
-
/// Return a gsl::span that points to an array of primitive types held by AttributeProto
|
| 80 |
-
/// This function allows to avoid copying big attributes locally into a kernel and operate on
|
| 81 |
-
/// AttributeProto data directly.
|
| 82 |
-
///
|
| 83 |
-
/// Does not apply to strings, Tensors and Sparse Tensors that require special treatment.
|
| 84 |
-
/// </summary>
|
| 85 |
-
/// <typeparam name="T">Primitive type contained in the array</typeparam>
|
| 86 |
-
/// <param name="name">Attribute name</param>
|
| 87 |
-
/// <param name="values">Attribute data in a span, out parameter</param>
|
| 88 |
-
/// <returns>Status</returns>
|
| 89 |
-
template <typename T>
|
| 90 |
-
MUST_USE_RESULT Status GetAttrsAsSpan(const std::string& name, gsl::span<const T>& values) const;
|
| 91 |
-
|
| 92 |
-
MUST_USE_RESULT Status GetAttrs(const std::string& name, TensorShapeVector& out) const;
|
| 93 |
-
|
| 94 |
-
MUST_USE_RESULT TensorShapeVector GetAttrsOrDefault(const std::string& name, const TensorShapeVector& default_value = TensorShapeVector{}) const {
|
| 95 |
-
TensorShapeVector tmp;
|
| 96 |
-
return GetAttrs(name, tmp).IsOK() ? tmp : default_value;
|
| 97 |
-
}
|
| 98 |
-
|
| 99 |
-
/**
|
| 100 |
-
Get repeated attributes
|
| 101 |
-
*/
|
| 102 |
-
template <typename T>
|
| 103 |
-
MUST_USE_RESULT Status GetAttrs(const std::string& name, std::vector<T>& values) const;
|
| 104 |
-
|
| 105 |
-
template <typename T>
|
| 106 |
-
MUST_USE_RESULT Status GetAttrs(const std::string& name, gsl::span<T> values) const;
|
| 107 |
-
|
| 108 |
-
MUST_USE_RESULT Status GetAttrsStringRefs(const std::string& name,
|
| 109 |
-
std::vector<std::reference_wrapper<const std::string>>& refs) const;
|
| 110 |
-
|
| 111 |
-
uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type,
|
| 112 |
-
const std::string& name) const noexcept;
|
| 113 |
-
|
| 114 |
-
bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type,
|
| 115 |
-
const std::string& name) const noexcept;
|
| 116 |
-
|
| 117 |
-
uint32_t GetInputCount() const {
|
| 118 |
-
return gsl::narrow_cast<uint32_t>(impl_->getNumInputs());
|
| 119 |
-
}
|
| 120 |
-
|
| 121 |
-
uint32_t GetOutputCount() const {
|
| 122 |
-
return gsl::narrow_cast<uint32_t>(impl_->getNumOutputs());
|
| 123 |
-
}
|
| 124 |
-
|
| 125 |
-
const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const {
|
| 126 |
-
return impl_->getInputType(index);
|
| 127 |
-
}
|
| 128 |
-
|
| 129 |
-
const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const {
|
| 130 |
-
// Work around lack of a const method from the onnx InferenceContext interface
|
| 131 |
-
return const_cast<Impl_t*>(impl_)->getOutputType(index);
|
| 132 |
-
}
|
| 133 |
-
|
| 134 |
-
// Try to query an attribute, returning nullptr if it doesn't exist
|
| 135 |
-
const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const {
|
| 136 |
-
return impl_->getAttribute(name);
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
-
const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const {
|
| 140 |
-
const ONNX_NAMESPACE::AttributeProto* attr = TryGetAttribute(name);
|
| 141 |
-
ORT_ENFORCE(attr != nullptr);
|
| 142 |
-
return attr;
|
| 143 |
-
}
|
| 144 |
-
|
| 145 |
-
private:
|
| 146 |
-
OpNodeProtoHelper() = delete;
|
| 147 |
-
const Impl_t* impl_ = nullptr;
|
| 148 |
-
};
|
| 149 |
-
|
| 150 |
-
// The methods on the following class are called by OpNodeProtoHelper, implementing
|
| 151 |
-
// the same signatures as InferenceContext other than const-ness.
|
| 152 |
-
class ProtoHelperNodeContext {
|
| 153 |
-
public:
|
| 154 |
-
explicit ProtoHelperNodeContext(const onnxruntime::Node& node) : node_(node) {}
|
| 155 |
-
ProtoHelperNodeContext() = delete;
|
| 156 |
-
|
| 157 |
-
const ONNX_NAMESPACE::AttributeProto* getAttribute(const std::string& name) const;
|
| 158 |
-
size_t getNumInputs() const;
|
| 159 |
-
const ONNX_NAMESPACE::TypeProto* getInputType(size_t index) const;
|
| 160 |
-
size_t getNumOutputs() const;
|
| 161 |
-
const ONNX_NAMESPACE::TypeProto* getOutputType(size_t index) const;
|
| 162 |
-
|
| 163 |
-
private:
|
| 164 |
-
const onnxruntime::Node& node_;
|
| 165 |
-
};
|
| 166 |
-
|
| 167 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/ort_value.h
DELETED
|
@@ -1,123 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <string>
|
| 7 |
-
#ifndef SHARED_PROVIDER
|
| 8 |
-
#include "core/common/common.h"
|
| 9 |
-
#include "core/common/exceptions.h"
|
| 10 |
-
#include "core/framework/allocator.h"
|
| 11 |
-
#include "core/framework/data_types.h"
|
| 12 |
-
#include "core/framework/tensor.h"
|
| 13 |
-
|
| 14 |
-
namespace onnxruntime {
|
| 15 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 16 |
-
class SparseTensor;
|
| 17 |
-
#endif
|
| 18 |
-
class TensorSeq;
|
| 19 |
-
} // namespace onnxruntime
|
| 20 |
-
|
| 21 |
-
#endif
|
| 22 |
-
|
| 23 |
-
/**
|
| 24 |
-
Represents both tensors and non-tensors.
|
| 25 |
-
*/
|
| 26 |
-
struct OrtValue {
|
| 27 |
-
public:
|
| 28 |
-
OrtValue() : data_(nullptr) {}
|
| 29 |
-
~OrtValue() = default;
|
| 30 |
-
|
| 31 |
-
OrtValue(void* pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter) {
|
| 32 |
-
Init(pData, type, deleter);
|
| 33 |
-
}
|
| 34 |
-
|
| 35 |
-
void Init(void* pData, onnxruntime::MLDataType type, onnxruntime::DeleteFunc deleter) {
|
| 36 |
-
data_.reset(pData, deleter);
|
| 37 |
-
type_ = type;
|
| 38 |
-
}
|
| 39 |
-
|
| 40 |
-
void Init(void* pData, onnxruntime::MLDataType type, const std::function<void(void*)>& deleter) {
|
| 41 |
-
data_.reset(pData, deleter);
|
| 42 |
-
type_ = type;
|
| 43 |
-
}
|
| 44 |
-
|
| 45 |
-
bool IsAllocated() const {
|
| 46 |
-
return data_ && type_;
|
| 47 |
-
}
|
| 48 |
-
|
| 49 |
-
template <typename T>
|
| 50 |
-
const T& Get() const {
|
| 51 |
-
ORT_ENFORCE(onnxruntime::DataTypeImpl::GetType<T>() == type_, onnxruntime::DataTypeImpl::GetType<T>(), " != ", type_);
|
| 52 |
-
return *static_cast<T*>(data_.get());
|
| 53 |
-
}
|
| 54 |
-
|
| 55 |
-
// May return nullptr, if this OrtValue is an optional type and it is "None".
|
| 56 |
-
template <typename T>
|
| 57 |
-
T* GetMutable() {
|
| 58 |
-
ORT_ENFORCE(onnxruntime::DataTypeImpl::GetType<T>() == type_, onnxruntime::DataTypeImpl::GetType<T>(), " != ", type_);
|
| 59 |
-
return static_cast<T*>(data_.get());
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
-
bool IsTensor() const noexcept {
|
| 63 |
-
return (type_ != nullptr && type_->IsTensorType());
|
| 64 |
-
}
|
| 65 |
-
|
| 66 |
-
bool IsTensorSequence() const noexcept {
|
| 67 |
-
return (type_ != nullptr && type_->IsTensorSequenceType());
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
bool IsSparseTensor() const {
|
| 71 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 72 |
-
return (type_ != nullptr && type_->IsSparseTensorType());
|
| 73 |
-
#else
|
| 74 |
-
ORT_THROW("Sparse tensor is not supported in this build.");
|
| 75 |
-
#endif
|
| 76 |
-
}
|
| 77 |
-
|
| 78 |
-
onnxruntime::MLDataType Type() const {
|
| 79 |
-
return type_;
|
| 80 |
-
}
|
| 81 |
-
|
| 82 |
-
private:
|
| 83 |
-
std::shared_ptr<void> data_;
|
| 84 |
-
onnxruntime::MLDataType type_{nullptr};
|
| 85 |
-
};
|
| 86 |
-
|
| 87 |
-
template <>
|
| 88 |
-
inline const onnxruntime::Tensor& OrtValue::Get<onnxruntime::Tensor>() const {
|
| 89 |
-
ORT_ENFORCE(IsTensor(), "Trying to get a Tensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
|
| 90 |
-
return *static_cast<onnxruntime::Tensor*>(data_.get());
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
template <>
|
| 94 |
-
inline onnxruntime::Tensor* OrtValue::GetMutable<onnxruntime::Tensor>() {
|
| 95 |
-
ORT_ENFORCE(IsTensor(), "Trying to get a Tensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
|
| 96 |
-
return static_cast<onnxruntime::Tensor*>(data_.get());
|
| 97 |
-
}
|
| 98 |
-
|
| 99 |
-
template <>
|
| 100 |
-
inline const onnxruntime::TensorSeq& OrtValue::Get<onnxruntime::TensorSeq>() const {
|
| 101 |
-
ORT_ENFORCE(IsTensorSequence(), "Trying to get a TensorSeq, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
|
| 102 |
-
return *static_cast<onnxruntime::TensorSeq*>(data_.get());
|
| 103 |
-
}
|
| 104 |
-
|
| 105 |
-
template <>
|
| 106 |
-
inline onnxruntime::TensorSeq* OrtValue::GetMutable<onnxruntime::TensorSeq>() {
|
| 107 |
-
ORT_ENFORCE(IsTensorSequence(), "Trying to get a TensorSeq, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
|
| 108 |
-
return static_cast<onnxruntime::TensorSeq*>(data_.get());
|
| 109 |
-
}
|
| 110 |
-
|
| 111 |
-
#if !defined(DISABLE_SPARSE_TENSORS)
|
| 112 |
-
template <>
|
| 113 |
-
inline const onnxruntime::SparseTensor& OrtValue::Get<onnxruntime::SparseTensor>() const {
|
| 114 |
-
ORT_ENFORCE(IsSparseTensor(), "Trying to get a SparseTensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
|
| 115 |
-
return *static_cast<onnxruntime::SparseTensor*>(data_.get());
|
| 116 |
-
}
|
| 117 |
-
|
| 118 |
-
template <>
|
| 119 |
-
inline onnxruntime::SparseTensor* OrtValue::GetMutable<onnxruntime::SparseTensor>() {
|
| 120 |
-
ORT_ENFORCE(IsSparseTensor(), "Trying to get a SparseTensor, but got: ", onnxruntime::DataTypeImpl::ToString(type_));
|
| 121 |
-
return static_cast<onnxruntime::SparseTensor*>(data_.get());
|
| 122 |
-
}
|
| 123 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/ortdevice.h
DELETED
|
@@ -1,74 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <sstream>
|
| 7 |
-
|
| 8 |
-
// Struct to represent a physical device.
|
| 9 |
-
struct OrtDevice {
|
| 10 |
-
using DeviceType = int8_t;
|
| 11 |
-
using MemoryType = int8_t;
|
| 12 |
-
using DeviceId = int16_t;
|
| 13 |
-
|
| 14 |
-
// Pre-defined device types.
|
| 15 |
-
static const DeviceType CPU = 0;
|
| 16 |
-
static const DeviceType GPU = 1; // Nvidia or AMD
|
| 17 |
-
static const DeviceType FPGA = 2;
|
| 18 |
-
static const DeviceType NPU = 3; // Ascend
|
| 19 |
-
|
| 20 |
-
struct MemType {
|
| 21 |
-
// Pre-defined memory types.
|
| 22 |
-
static const MemoryType DEFAULT = 0;
|
| 23 |
-
static const MemoryType CUDA_PINNED = 1;
|
| 24 |
-
static const MemoryType HIP_PINNED = 2;
|
| 25 |
-
static const MemoryType CANN_PINNED = 3;
|
| 26 |
-
};
|
| 27 |
-
|
| 28 |
-
constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)
|
| 29 |
-
: device_type(device_type_),
|
| 30 |
-
memory_type(memory_type_),
|
| 31 |
-
device_id(device_id_) {}
|
| 32 |
-
|
| 33 |
-
constexpr OrtDevice() : OrtDevice(CPU, MemType::DEFAULT, 0) {}
|
| 34 |
-
|
| 35 |
-
DeviceType Type() const {
|
| 36 |
-
return device_type;
|
| 37 |
-
}
|
| 38 |
-
|
| 39 |
-
MemoryType MemType() const {
|
| 40 |
-
return memory_type;
|
| 41 |
-
}
|
| 42 |
-
|
| 43 |
-
DeviceId Id() const {
|
| 44 |
-
return device_id;
|
| 45 |
-
}
|
| 46 |
-
|
| 47 |
-
std::string ToString() const {
|
| 48 |
-
std::ostringstream ostr;
|
| 49 |
-
ostr << "Device:["
|
| 50 |
-
<< "DeviceType:" << static_cast<int>(device_type)
|
| 51 |
-
<< " MemoryType:" << static_cast<int>(memory_type)
|
| 52 |
-
<< " DeviceId:" << device_id
|
| 53 |
-
<< "]";
|
| 54 |
-
return ostr.str();
|
| 55 |
-
}
|
| 56 |
-
|
| 57 |
-
private:
|
| 58 |
-
// Device type.
|
| 59 |
-
DeviceType device_type;
|
| 60 |
-
|
| 61 |
-
// Memory type.
|
| 62 |
-
MemoryType memory_type;
|
| 63 |
-
|
| 64 |
-
// Device index.
|
| 65 |
-
DeviceId device_id;
|
| 66 |
-
};
|
| 67 |
-
|
| 68 |
-
inline bool operator==(const OrtDevice& left, const OrtDevice& other) {
|
| 69 |
-
return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type();
|
| 70 |
-
}
|
| 71 |
-
|
| 72 |
-
inline bool operator!=(const OrtDevice& left, const OrtDevice& other) {
|
| 73 |
-
return !(left == other);
|
| 74 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/ortmemoryinfo.h
DELETED
|
@@ -1,87 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <string_view>
|
| 7 |
-
|
| 8 |
-
#include "core/common/hash_combine.h"
|
| 9 |
-
|
| 10 |
-
struct OrtMemoryInfo {
|
| 11 |
-
OrtMemoryInfo() = default; // to allow default construction of Tensor
|
| 12 |
-
|
| 13 |
-
// use string for name, so we could have customized allocator in execution provider.
|
| 14 |
-
const char* name = nullptr;
|
| 15 |
-
int id = -1;
|
| 16 |
-
OrtMemType mem_type = OrtMemTypeDefault;
|
| 17 |
-
OrtAllocatorType alloc_type = OrtInvalidAllocator;
|
| 18 |
-
OrtDevice device;
|
| 19 |
-
|
| 20 |
-
constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0,
|
| 21 |
-
OrtMemType mem_type_ = OrtMemTypeDefault)
|
| 22 |
-
#if ((defined(__GNUC__) && __GNUC__ > 4) || defined(__clang__))
|
| 23 |
-
// this causes a spurious error in CentOS gcc 4.8 build so disable if GCC version < 5
|
| 24 |
-
__attribute__((nonnull))
|
| 25 |
-
#endif
|
| 26 |
-
: name(name_),
|
| 27 |
-
id(id_),
|
| 28 |
-
mem_type(mem_type_),
|
| 29 |
-
alloc_type(type_),
|
| 30 |
-
device(device_) {
|
| 31 |
-
}
|
| 32 |
-
|
| 33 |
-
// To make OrtMemoryInfo become a valid key in std map
|
| 34 |
-
bool operator<(const OrtMemoryInfo& other) const {
|
| 35 |
-
if (alloc_type != other.alloc_type)
|
| 36 |
-
return alloc_type < other.alloc_type;
|
| 37 |
-
if (mem_type != other.mem_type)
|
| 38 |
-
return mem_type < other.mem_type;
|
| 39 |
-
if (id != other.id)
|
| 40 |
-
return id < other.id;
|
| 41 |
-
|
| 42 |
-
return strcmp(name, other.name) < 0;
|
| 43 |
-
}
|
| 44 |
-
|
| 45 |
-
// This is to make OrtMemoryInfo a valid key in hash tables
|
| 46 |
-
// we ignore device id
|
| 47 |
-
size_t Hash() const {
|
| 48 |
-
auto h = std::hash<int>()(alloc_type);
|
| 49 |
-
onnxruntime::HashCombine(mem_type, h);
|
| 50 |
-
onnxruntime::HashCombine(id, h);
|
| 51 |
-
onnxruntime::HashCombine<std::string_view>(name, h);
|
| 52 |
-
return h;
|
| 53 |
-
}
|
| 54 |
-
|
| 55 |
-
std::string ToString() const {
|
| 56 |
-
std::ostringstream ostr;
|
| 57 |
-
ostr << "OrtMemoryInfo:["
|
| 58 |
-
<< "name:" << name
|
| 59 |
-
<< " id:" << id
|
| 60 |
-
<< " OrtMemType:" << mem_type
|
| 61 |
-
<< " OrtAllocatorType:" << alloc_type
|
| 62 |
-
<< " " << device.ToString()
|
| 63 |
-
<< "]";
|
| 64 |
-
return ostr.str();
|
| 65 |
-
}
|
| 66 |
-
};
|
| 67 |
-
|
| 68 |
-
// Required by hash tables
|
| 69 |
-
inline bool operator==(const OrtMemoryInfo& left, const OrtMemoryInfo& other) {
|
| 70 |
-
return left.mem_type == other.mem_type &&
|
| 71 |
-
left.alloc_type == other.alloc_type &&
|
| 72 |
-
left.id == other.id &&
|
| 73 |
-
strcmp(left.name, other.name) == 0;
|
| 74 |
-
}
|
| 75 |
-
|
| 76 |
-
inline bool operator!=(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) { return !(lhs == rhs); }
|
| 77 |
-
|
| 78 |
-
std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info);
|
| 79 |
-
|
| 80 |
-
namespace std {
|
| 81 |
-
template<>
|
| 82 |
-
struct hash<OrtMemoryInfo> {
|
| 83 |
-
size_t operator()(const OrtMemoryInfo& i) const {
|
| 84 |
-
return i.Hash();
|
| 85 |
-
}
|
| 86 |
-
};
|
| 87 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/provider_options.h
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <string>
|
| 7 |
-
#include <unordered_map>
|
| 8 |
-
#include <vector>
|
| 9 |
-
|
| 10 |
-
namespace onnxruntime {
|
| 11 |
-
|
| 12 |
-
// data types for execution provider options
|
| 13 |
-
|
| 14 |
-
using ProviderOptions = std::unordered_map<std::string, std::string>;
|
| 15 |
-
using ProviderOptionsVector = std::vector<ProviderOptions>;
|
| 16 |
-
using ProviderOptionsMap = std::unordered_map<std::string, ProviderOptions>;
|
| 17 |
-
|
| 18 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/provider_options_utils.h
DELETED
|
@@ -1,164 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <algorithm>
|
| 7 |
-
#include <functional>
|
| 8 |
-
#include <type_traits>
|
| 9 |
-
#include <unordered_map>
|
| 10 |
-
#include <vector>
|
| 11 |
-
|
| 12 |
-
#include "core/common/common.h"
|
| 13 |
-
#include "core/common/parse_string.h"
|
| 14 |
-
#include "core/framework/provider_options.h"
|
| 15 |
-
|
| 16 |
-
namespace onnxruntime {
|
| 17 |
-
|
| 18 |
-
template <typename TEnum>
|
| 19 |
-
using EnumNameMapping = std::vector<std::pair<TEnum, std::string>>;
|
| 20 |
-
|
| 21 |
-
/**
|
| 22 |
-
* Given a mapping and an enumeration value, gets the corresponding name.
|
| 23 |
-
*/
|
| 24 |
-
template <typename TEnum>
|
| 25 |
-
Status EnumToName(const EnumNameMapping<TEnum>& mapping, TEnum value, std::string& name) {
|
| 26 |
-
const auto it = std::find_if(
|
| 27 |
-
mapping.begin(), mapping.end(),
|
| 28 |
-
[&value](const std::pair<TEnum, std::string>& entry) {
|
| 29 |
-
return entry.first == value;
|
| 30 |
-
});
|
| 31 |
-
ORT_RETURN_IF(
|
| 32 |
-
it == mapping.end(),
|
| 33 |
-
"Failed to map enum value to name: ", static_cast<typename std::underlying_type<TEnum>::type>(value));
|
| 34 |
-
name = it->second;
|
| 35 |
-
return Status::OK();
|
| 36 |
-
}
|
| 37 |
-
|
| 38 |
-
template <typename TEnum>
|
| 39 |
-
std::string EnumToName(const EnumNameMapping<TEnum>& mapping, TEnum value) {
|
| 40 |
-
std::string name;
|
| 41 |
-
ORT_THROW_IF_ERROR(EnumToName(mapping, value, name));
|
| 42 |
-
return name;
|
| 43 |
-
}
|
| 44 |
-
|
| 45 |
-
/**
|
| 46 |
-
* Given a mapping and a name, gets the corresponding enumeration value.
|
| 47 |
-
*/
|
| 48 |
-
template <typename TEnum>
|
| 49 |
-
Status NameToEnum(
|
| 50 |
-
const EnumNameMapping<TEnum>& mapping, const std::string& name, TEnum& value) {
|
| 51 |
-
const auto it = std::find_if(
|
| 52 |
-
mapping.begin(), mapping.end(),
|
| 53 |
-
[&name](const std::pair<TEnum, std::string>& entry) {
|
| 54 |
-
return entry.second == name;
|
| 55 |
-
});
|
| 56 |
-
ORT_RETURN_IF(
|
| 57 |
-
it == mapping.end(),
|
| 58 |
-
"Failed to map enum name to value: ", name);
|
| 59 |
-
value = it->first;
|
| 60 |
-
return Status::OK();
|
| 61 |
-
}
|
| 62 |
-
|
| 63 |
-
template <typename TEnum>
|
| 64 |
-
TEnum NameToEnum(const EnumNameMapping<TEnum>& mapping, const std::string& name) {
|
| 65 |
-
TEnum value;
|
| 66 |
-
ORT_THROW_IF_ERROR(NameToEnum(mapping, name, value));
|
| 67 |
-
return value;
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
class ProviderOptionsParser {
|
| 71 |
-
public:
|
| 72 |
-
/**
|
| 73 |
-
* Adds a parser for a particular provider option value.
|
| 74 |
-
*
|
| 75 |
-
* @param name The provider option name.
|
| 76 |
-
* @param value_parser An object that parses the option value.
|
| 77 |
-
* It should be callable with the following signature and return
|
| 78 |
-
* whether the parsing was successful:
|
| 79 |
-
* Status value_parser(const std::string&)
|
| 80 |
-
*
|
| 81 |
-
* @return The current ProviderOptionsParser instance.
|
| 82 |
-
*/
|
| 83 |
-
template <typename ValueParserType>
|
| 84 |
-
ProviderOptionsParser& AddValueParser(
|
| 85 |
-
const std::string& name, ValueParserType value_parser) {
|
| 86 |
-
ORT_ENFORCE(
|
| 87 |
-
value_parsers_.emplace(name, ValueParser{value_parser}).second,
|
| 88 |
-
"Provider option \"", name, "\" already has a value parser.");
|
| 89 |
-
return *this;
|
| 90 |
-
}
|
| 91 |
-
|
| 92 |
-
/**
|
| 93 |
-
* Adds a parser for a particular provider option value which converts a
|
| 94 |
-
* value to the right type and assigns it to the given reference.
|
| 95 |
-
*
|
| 96 |
-
* IMPORTANT: This function stores a reference to the destination variable.
|
| 97 |
-
* The caller must ensure that the reference is valid when Parse() is called!
|
| 98 |
-
*
|
| 99 |
-
* @param name The provider option name.
|
| 100 |
-
* @param dest The destination variable reference.
|
| 101 |
-
*
|
| 102 |
-
* @return The current ProviderOptionsParser instance.
|
| 103 |
-
*/
|
| 104 |
-
template <typename ValueType>
|
| 105 |
-
ProviderOptionsParser& AddAssignmentToReference(
|
| 106 |
-
const std::string& name, ValueType& dest) {
|
| 107 |
-
return AddValueParser(
|
| 108 |
-
name,
|
| 109 |
-
[&dest](const std::string& value_str) -> Status {
|
| 110 |
-
return ParseStringWithClassicLocale(value_str, dest);
|
| 111 |
-
});
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
/**
|
| 115 |
-
* Adds a parser for a particular provider option value which maps an
|
| 116 |
-
* enumeration name to a value and assigns it to the given reference.
|
| 117 |
-
*
|
| 118 |
-
* IMPORTANT: This function stores references to the mapping and destination
|
| 119 |
-
* variables. The caller must ensure that the references are valid when
|
| 120 |
-
* Parse() is called!
|
| 121 |
-
*
|
| 122 |
-
* @param name The provider option name.
|
| 123 |
-
* @param mapping The enumeration value to name mapping.
|
| 124 |
-
* @param dest The destination variable reference.
|
| 125 |
-
*
|
| 126 |
-
* @return The current ProviderOptionsParser instance.
|
| 127 |
-
*/
|
| 128 |
-
template <typename EnumType>
|
| 129 |
-
ProviderOptionsParser& AddAssignmentToEnumReference(
|
| 130 |
-
const std::string& name, const EnumNameMapping<EnumType>& mapping, EnumType& dest) {
|
| 131 |
-
return AddValueParser(
|
| 132 |
-
name,
|
| 133 |
-
[&mapping, &dest](const std::string& value_str) -> Status {
|
| 134 |
-
return NameToEnum(mapping, value_str, dest);
|
| 135 |
-
});
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
-
/**
|
| 139 |
-
* Parses the given provider options.
|
| 140 |
-
*/
|
| 141 |
-
Status Parse(const ProviderOptions& options) const {
|
| 142 |
-
for (const auto& option : options) {
|
| 143 |
-
const auto& name = option.first;
|
| 144 |
-
const auto& value_str = option.second;
|
| 145 |
-
const auto value_parser_it = value_parsers_.find(name);
|
| 146 |
-
ORT_RETURN_IF(
|
| 147 |
-
value_parser_it == value_parsers_.end(),
|
| 148 |
-
"Unknown provider option: \"", name, "\".");
|
| 149 |
-
|
| 150 |
-
const auto parse_status = value_parser_it->second(value_str);
|
| 151 |
-
ORT_RETURN_IF_NOT(
|
| 152 |
-
parse_status.IsOK(),
|
| 153 |
-
"Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage());
|
| 154 |
-
}
|
| 155 |
-
|
| 156 |
-
return Status::OK();
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
private:
|
| 160 |
-
using ValueParser = std::function<Status(const std::string&)>;
|
| 161 |
-
std::unordered_map<std::string, ValueParser> value_parsers_;
|
| 162 |
-
};
|
| 163 |
-
|
| 164 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/provider_shutdown.h
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
namespace onnxruntime {
|
| 7 |
-
void UnloadSharedProviders();
|
| 8 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
arm64-v8a/include/onnxruntime/core/framework/run_options.h
DELETED
|
@@ -1,49 +0,0 @@
|
|
| 1 |
-
// Copyright (c) Microsoft Corporation. All rights reserved.
|
| 2 |
-
// Licensed under the MIT License.
|
| 3 |
-
|
| 4 |
-
#pragma once
|
| 5 |
-
|
| 6 |
-
#include <string>
|
| 7 |
-
#include <atomic>
|
| 8 |
-
#include "core/session/onnxruntime_c_api.h"
|
| 9 |
-
#include "core/framework/config_options.h"
|
| 10 |
-
|
| 11 |
-
/**
|
| 12 |
-
* Configuration information for a Run call.
|
| 13 |
-
*/
|
| 14 |
-
struct OrtRunOptions {
|
| 15 |
-
/// Log severity. See https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/common/logging/severity.h
|
| 16 |
-
/// Default = -1 (use the log severity from the InferenceSession that the Run is for).
|
| 17 |
-
int run_log_severity_level = -1;
|
| 18 |
-
int run_log_verbosity_level = 0; ///< VLOG level if debug build and run_log_severity_level is 0 (VERBOSE).
|
| 19 |
-
std::string run_tag; ///< A tag for the Run() calls using this.
|
| 20 |
-
|
| 21 |
-
// Set to 'true' to ensure the termination of all the outstanding Run() calls
|
| 22 |
-
// that use this OrtRunOptions instance. Some of the outstanding Run() calls may
|
| 23 |
-
// be forced to terminate with an error status.
|
| 24 |
-
bool terminate = false;
|
| 25 |
-
|
| 26 |
-
// Set to 'true' to run only the nodes from feeds to required fetches.
|
| 27 |
-
// So it is possible that only some of the nodes are executed.
|
| 28 |
-
bool only_execute_path_to_fetches = false;
|
| 29 |
-
|
| 30 |
-
#ifdef ENABLE_TRAINING
|
| 31 |
-
// Used by onnxruntime::training::TrainingSession. This class is now deprecated.
|
| 32 |
-
// Delete training_mode when TrainingSession is deleted.
|
| 33 |
-
// Set to 'true' to run in training mode.
|
| 34 |
-
bool training_mode = true;
|
| 35 |
-
#endif
|
| 36 |
-
|
| 37 |
-
// Stores the configurations for this run
|
| 38 |
-
// To add an configuration to this specific run, call OrtApis::AddRunConfigEntry
|
| 39 |
-
// The configuration keys and value formats are defined in
|
| 40 |
-
// /include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
|
| 41 |
-
onnxruntime::ConfigOptions config_options;
|
| 42 |
-
|
| 43 |
-
OrtRunOptions() = default;
|
| 44 |
-
~OrtRunOptions() = default;
|
| 45 |
-
};
|
| 46 |
-
|
| 47 |
-
namespace onnxruntime {
|
| 48 |
-
using RunOptions = OrtRunOptions;
|
| 49 |
-
} // namespace onnxruntime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|