| |
| |
|
|
| #pragma once |
|
|
| #include "boost/mp11.hpp" |
|
|
| |
| |
| #include "core/framework/prepacked_weights_container.h" |
|
|
| #ifndef SHARED_PROVIDER |
| #include <functional> |
| #include "core/common/exceptions.h" |
| #include "core/common/logging/logging.h" |
| #include "core/common/status.h" |
| #include "core/framework/execution_provider.h" |
| #include "core/framework/kernel_def_builder.h" |
| #include "core/framework/ort_value.h" |
| #include "core/framework/op_kernel_info.h" |
| #include "core/framework/op_node_proto_helper.h" |
| #include "core/framework/tensor.h" |
| #include "core/framework/sparse_tensor.h" |
| #include "core/graph/constants.h" |
| #include "core/graph/graph_viewer.h" |
| #if !defined(ORT_MINIMAL_BUILD) |
| #include "onnx/defs/schema.h" |
| #else |
| #include "onnx/defs/data_type_utils.h" |
| #endif |
| #include "onnx/onnx_pb.h" |
| #include "onnx/onnx-operators_pb.h" |
| #include "core/common/gsl.h" |
| namespace onnxruntime { |
| class OpKernelContext; |
| } |
| #endif |
|
|
| namespace onnxruntime { |
|
|
| std::unique_ptr<OpKernelInfo> CopyOpKernelInfo(const OpKernelInfo& info); |
|
|
// Base class for operator kernels.
// Each instance owns the OpKernelInfo copy it was constructed with (op_kernel_info_).
// NOTE: the declaration order of the virtual functions below fixes the vtable
// layout; keep it stable when editing this class.
class OpKernel {
 public:
  // Callback type passed to ComputeAsync.
  using DoneCallback = std::function<void()>;

  // Stores a copy of `info` produced by CopyOpKernelInfo (presumably so the kernel
  // does not depend on the lifetime of the caller's `info` — confirm).
  explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(CopyOpKernelInfo(info)) {}
  virtual ~OpKernel() = default;

  // Defined out of line. Presumably these return the graph node and the kernel
  // definition this kernel was created for — confirm in the implementation file.
  const onnxruntime::Node& Node() const;
  const onnxruntime::KernelDef& KernelDef() const;

  // Synchronous entry point. It is pure virtual, so every concrete kernel must
  // implement it. [[nodiscard]]: callers must not ignore the returned Status.
  [[nodiscard]] virtual Status Compute(_Inout_ OpKernelContext* context) const = 0;

  // Reports whether the kernel is meant to be driven through ComputeAsync.
  // The base implementation returns false.
  [[nodiscard]] virtual bool IsAsync() const {
    return false;
  }

  // Asynchronous entry point. The DoneCallback is presumably invoked when the work
  // completes — confirm with async-capable overrides.
  // The base implementation does not support async execution: it invokes
  // ORT_NOT_IMPLEMENTED, which is expected to throw, so no return statement follows.
  [[nodiscard]] virtual Status ComputeAsync(_Inout_ OpKernelContext*, DoneCallback) const {
    ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
  }

  // Hook that lets a kernel pre-pack one of its constant inputs.
  // The unnamed parameters are, in order:
  //   - the tensor to pack;
  //   - an int (presumably the input index — confirm);
  //   - an AllocatorPtr for any packed buffers;
  //   - a PrePackedWeights* (presumably an out-param for sharing the packed
  //     buffers — confirm).
  // `is_packed` reports whether the kernel packed the tensor.
  // The base implementation packs nothing: it sets is_packed to false and returns OK.
  virtual Status
  PrePack(const Tensor& , int , AllocatorPtr ,
          bool& is_packed, PrePackedWeights* ) {
    is_packed = false;
    return Status::OK();
  }

  // Hook for handing the kernel previously packed buffers (presumably from a
  // shared pre-packed-weights cache — confirm).
  // The unnamed parameters are the buffer list and an int (presumably the input
  // index — confirm). `used_shared_buffers` reports whether the kernel adopted them.
  // The base implementation adopts nothing: it sets used_shared_buffers to false
  // and returns OK.
  virtual Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& ,
                                           int ,
                                           bool& used_shared_buffers) {
    used_shared_buffers = false;
    return Status::OK();
  }

  // Defined out of line. Returns the OrtMemoryInfo associated with `mem_type`
  // (presumably of the allocator available to this kernel — confirm).
  const OrtMemoryInfo& Allocator(OrtMemType mem_type) const;
  // Returns the kernel's own stored copy of its OpKernelInfo.
  const OpKernelInfo& Info() const {
    return *op_kernel_info_;
  }

 private:
  // Copy, assignment and move are disabled (per the macro's name).
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel);
  // Owned copy of the construction-time OpKernelInfo; returned by Info().
  std::unique_ptr<OpKernelInfo> op_kernel_info_;
};
| class FuncManager; |
| using KernelCreateFn = std::function<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>; |
| using KernelCreatePtrFn = std::add_pointer<Status(FuncManager& func_mgr, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out)>::type; |
|
|
| struct KernelCreateInfo { |
| std::unique_ptr<KernelDef> kernel_def; |
| KernelCreateFn kernel_create_func; |
| Status status; |
|
|
| KernelCreateInfo(std::unique_ptr<KernelDef> definition, |
| KernelCreateFn create_func) |
| : kernel_def(std::move(definition)), |
| kernel_create_func(create_func) {} |
|
|
| KernelCreateInfo(KernelCreateInfo&& other) noexcept |
| : kernel_def(std::move(other.kernel_def)), |
| kernel_create_func(std::move(other.kernel_create_func)) {} |
|
|
| KernelCreateInfo() = default; |
| }; |
|
|
| |
| template <typename T> |
| KernelCreateInfo BuildKernelCreateInfo(); |
|
|
| namespace ml { |
| template <typename T> |
| KernelCreateInfo BuildKernelCreateInfo(); |
| } |
|
|
| namespace contrib { |
| template <typename T> |
| KernelCreateInfo BuildKernelCreateInfo(); |
| } |
|
|
| namespace contrib { |
| namespace cuda { |
| template <typename T> |
| KernelCreateInfo BuildKernelCreateInfo(); |
| } |
| } |
|
|
| namespace contrib { |
| namespace rocm { |
| template <typename T> |
| KernelCreateInfo BuildKernelCreateInfo(); |
| } |
| } |
|
|
| namespace contrib { |
| namespace snpe { |
| template <typename T> |
| KernelCreateInfo BuildKernelCreateInfo(); |
| } |
| } |
|
|
| using BuildKernelCreateInfoFn = KernelCreateInfo (*)(); |
|
|
| |
// Token-pastes the unique placeholder class name of a single-version kernel:
// <provider>_<name>_<domain>_ver<ver>.
#define ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) \
  provider##_##name##_##domain##_ver##ver

// Registers a CPU kernel in kOnnxDomain.
#define ONNX_CPU_OPERATOR_KERNEL(name, ver, builder, ...) \
  ONNX_OPERATOR_KERNEL_EX(name, kOnnxDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)

// Registers a CPU kernel in kMLDomain.
#define ONNX_CPU_OPERATOR_ML_KERNEL(name, ver, builder, ...) \
  ONNX_OPERATOR_KERNEL_EX(name, kMLDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)

// Registers a CPU kernel in kMSDomain.
#define ONNX_CPU_OPERATOR_MS_KERNEL(name, ver, builder, ...) \
  ONNX_OPERATOR_KERNEL_EX(name, kMSDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__)

// Forward-declares the kernel's placeholder class and defines the matching
// BuildKernelCreateInfo<> specialization. That specialization returns:
//   - the KernelDef produced by applying name/domain/since-version/provider to
//     `builder`;
//   - a factory that constructs the kernel type given in __VA_ARGS__ from the
//     OpKernelInfo. Taking the type as __VA_ARGS__ lets a type containing commas
//     (e.g. a multi-argument template-id) be passed.
// `builder` is parenthesized so that a builder expression using lower-precedence
// operators (e.g. a dereference `*p`) is evaluated as a whole before
// `.SetName(...)` is applied.
#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \
  class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name)>() { \
    return KernelCreateInfo( \
        (builder).SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(ver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>( \
            [](FuncManager&, \
               const OpKernelInfo& info, \
               std::unique_ptr<OpKernel>& out) -> Status { \
              out = std::make_unique<__VA_ARGS__>(info); \
              return Status::OK(); \
            })); \
  }
|
|
// Token-pastes the unique placeholder class name of a version-ranged kernel:
// <provider>_<name>_<domain>_ver<startver>_<endver>.
#define ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name) \
  provider##_##name##_##domain##_ver##startver##_##endver

// Registers a version-ranged CPU kernel in kOnnxDomain.
#define ONNX_CPU_OPERATOR_VERSIONED_KERNEL(name, startver, endver, builder, ...) \
  ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kOnnxDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)

// Registers a version-ranged CPU kernel in kMLDomain.
#define ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL(name, startver, endver, builder, ...) \
  ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kMLDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__)

// Same as ONNX_OPERATOR_KERNEL_EX, except that the version range is passed to
// SinceVersion(startver, endver); range semantics are defined by the builder.
// `builder` is parenthesized so that a builder expression using lower-precedence
// operators (e.g. a dereference `*p`) is evaluated as a whole before
// `.SetName(...)` is applied.
#define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \
  class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name)>() { \
    return KernelCreateInfo( \
        (builder).SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(startver, endver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>( \
            [](FuncManager&, \
               const OpKernelInfo& info, \
               std::unique_ptr<OpKernel>& out) -> Status { \
              out = std::make_unique<__VA_ARGS__>(info); \
              return Status::OK(); \
            })); \
  }
|
|
// Token-pastes the unique placeholder class name of a single-version kernel that is
// specialized for one type: <provider>_<name>_<domain>_ver<ver>_<type>.
#define ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name) \
  provider##_##name##_##domain##_ver##ver##_##type

// Registers a typed CPU kernel in kOnnxDomain.
#define ONNX_CPU_OPERATOR_TYPED_KERNEL(name, ver, type, builder, ...) \
  ONNX_OPERATOR_TYPED_KERNEL_EX(name, kOnnxDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)

// Registers a typed CPU kernel in kMLDomain.
#define ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(name, ver, type, builder, ...) \
  ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMLDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)

// Registers a typed CPU kernel in kMSDomain.
#define ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(name, ver, type, builder, ...) \
  ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__)

// Same as ONNX_OPERATOR_KERNEL_EX, but `type` becomes part of the placeholder
// class name, so several type specializations of one op can coexist.
// `builder` is parenthesized so that a builder expression using lower-precedence
// operators (e.g. a dereference `*p`) is evaluated as a whole before
// `.SetName(...)` is applied.
#define ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) \
  class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name)>() { \
    return KernelCreateInfo( \
        (builder).SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(ver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>( \
            [](FuncManager&, \
               const OpKernelInfo& info, \
               std::unique_ptr<OpKernel>& out) -> Status { \
              out = std::make_unique<__VA_ARGS__>(info); \
              return Status::OK(); \
            })); \
  }
|
|
// Token-pastes the unique placeholder class name of a single-version kernel that is
// specialized for two types: <provider>_<name>_<domain>_ver<ver>_<type1>_<type2>.
#define ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name) \
  provider##_##name##_##domain##_ver##ver##_##type1##_##type2

// Same as ONNX_OPERATOR_KERNEL_EX, but both `type1` and `type2` become part of the
// placeholder class name.
// `builder` is parenthesized so that a builder expression using lower-precedence
// operators (e.g. a dereference `*p`) is evaluated as a whole before
// `.SetName(...)` is applied.
#define ONNX_OPERATOR_TWO_TYPED_KERNEL_EX(name, domain, ver, type1, type2, provider, builder, ...) \
  class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type1, type2, name)>() { \
    return KernelCreateInfo( \
        (builder).SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(ver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>( \
            [](FuncManager&, \
               const OpKernelInfo& info, \
               std::unique_ptr<OpKernel>& out) -> Status { \
              out = std::make_unique<__VA_ARGS__>(info); \
              return Status::OK(); \
            })); \
  }
|
|
// Token-pastes the unique placeholder class name of a version-ranged kernel that is
// specialized for one type: <provider>_<name>_<domain>_ver<startver>_<endver>_<type>.
#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \
  provider##_##name##_##domain##_ver##startver##_##endver##_##type

// Registers a version-ranged typed CPU kernel in kOnnxDomain.
#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(name, startver, endver, type, builder, ...) \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kOnnxDomain, startver, endver, type, kCpuExecutionProvider, builder, \
                                          __VA_ARGS__)

// Registers a version-ranged typed CPU kernel in kMLDomain.
#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL(name, startver, endver, type, builder, ...) \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMLDomain, startver, endver, type, kCpuExecutionProvider, builder, \
                                          __VA_ARGS__)

// Registers a version-ranged typed CPU kernel in kMSDomain.
#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_MS_KERNEL(name, startver, endver, type, builder, ...) \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMSDomain, startver, endver, type, kCpuExecutionProvider, builder, \
                                          __VA_ARGS__)

// Combines the versioned and typed variants: SinceVersion(startver, endver) is
// applied, and `type` becomes part of the placeholder class name.
// `builder` is parenthesized so that a builder expression using lower-precedence
// operators (e.g. a dereference `*p`) is evaluated as a whole before
// `.SetName(...)` is applied.
#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \
  class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
                                                                        type, name)>() { \
    return KernelCreateInfo( \
        (builder).SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(startver, endver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>( \
            [](FuncManager&, \
               const OpKernelInfo& info, \
               std::unique_ptr<OpKernel>& out) -> Status { \
              out = std::make_unique<__VA_ARGS__>(info); \
              return Status::OK(); \
            })); \
  }
|
|
// Token-pastes the unique placeholder class name of a version-ranged kernel that is
// specialized for two types:
// <provider>_<name>_<domain>_ver<startver>_<endver>_<type1>_<type2>.
#define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name) \
  provider##_##name##_##domain##_ver##startver##_##endver##_##type1##_##type2

// Version-ranged kernel whose placeholder class name includes both `type1` and
// `type2`; SinceVersion(startver, endver) is applied.
// `builder` is parenthesized so that a builder expression using lower-precedence
// operators (e.g. a dereference `*p`) is evaluated as a whole before
// `.SetName(...)` is applied.
#define ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_EX(name, domain, startver, endver, type1, type2, \
                                                    provider, builder, ...) \
  class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type1, type2, name); \
  template <> \
  KernelCreateInfo \
  BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, \
                                                                            type1, type2, name)>() { \
    return KernelCreateInfo( \
        (builder).SetName(#name) \
            .SetDomain(domain) \
            .SinceVersion(startver, endver) \
            .Provider(provider) \
            .Build(), \
        static_cast<KernelCreatePtrFn>( \
            [](FuncManager&, \
               const OpKernelInfo& info, \
               std::unique_ptr<OpKernel>& out) -> Status { \
              out = std::make_unique<__VA_ARGS__>(info); \
              return Status::OK(); \
            })); \
  }
|
|
| template <typename... Types> |
| struct BuildKernelDefConstraintsImpl { |
| std::vector<MLDataType> operator()() const { |
| return {DataTypeImpl::GetTensorType<Types>()...}; |
| } |
| }; |
|
|
| #if !defined(DISABLE_SPARSE_TENSORS) |
| template <typename... Types> |
| struct BuildKernelDefSparseConstraintsImpl { |
| std::vector<MLDataType> operator()() const { |
| return {DataTypeImpl::GetSparseTensorType<Types>()...}; |
| } |
| }; |
| #endif |
|
|
| |
| |
| |
| template <typename... Types> |
| inline std::vector<MLDataType> BuildKernelDefConstraints() { |
| return BuildKernelDefConstraintsImpl<Types...>{}(); |
| } |
|
|
| #if !defined(DISABLE_SPARSE_TENSORS) |
| template <typename... Types> |
| inline std::vector<MLDataType> BuildKernelDefSparseConstraints() { |
| return BuildKernelDefSparseConstraintsImpl<Types...>{}(); |
| } |
| #endif |
|
|
| |
| template <typename L> |
| inline std::vector<MLDataType> BuildKernelDefConstraintsFromTypeList() { |
| return boost::mp11::mp_apply<BuildKernelDefConstraintsImpl, L>{}(); |
| } |
|
|
| #if !defined(DISABLE_SPARSE_TENSORS) |
| template <typename L> |
| inline std::vector<MLDataType> BuildKernelDefSparseConstraintsFromTypeList() { |
| return boost::mp11::mp_apply<BuildKernelDefSparseConstraintsImpl, L>{}(); |
| } |
| #endif |
|
|
| } |
|
|
| #ifndef SHARED_PROVIDER |
| #include "core/framework/op_kernel_context.h" |
| #endif |
|
|