| |
| |
|
|
| #pragma once |
|
|
| #include <string_view> |
|
|
| #include "core/framework/op_kernel.h" |
|
|
| namespace onnxruntime { |
| // KernelCreateMap: multimap from std::string keys to KernelCreateInfo values; it is the type of KernelRegistry::kernel_creator_fn_map_. |
| // KernelDefHashes: vector of (std::string, HashValue) pairs; it is not referenced anywhere else in this header. |
| using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>; |
| using KernelDefHashes = std::vector<std::pair<std::string, HashValue>>; |
|
| // Forward declaration: this header only uses IKernelTypeStrResolver by const reference, so the full definition is not needed here. |
| class IKernelTypeStrResolver; |
|
|
| |
| |
| |
| // Registry of kernel creation entries. The entries are stored in `kernel_creator_fn_map_`, a |
| // KernelCreateMap (multimap from std::string to KernelCreateInfo). |
| // NOTE(review): the map keys are presumably the strings built by GetMapKey() - confirm in the .cc file. |
| // Only the inline members are defined in this header; all other members are defined out of line. |
| class KernelRegistry { |
| public: |
| KernelRegistry() = default; |
|
|
| // Registers a kernel given its definition builder and the function used to create the kernel. |
| // Declared here, defined out of line; the returned Status reports the outcome. |
| Status Register(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator); |
|
|
| // Registers an already assembled KernelCreateInfo, which is taken by rvalue reference. |
| Status Register(KernelCreateInfo&& create_info); |
|
|
| // Looks up the kernel for `node` on `exec_provider`; the result is handed back through `out`. |
| // HasImplementationOf() below treats an OK status as "a kernel was found". |
| // NOTE(review): what `*out` holds when the status is not OK is not visible in this header - |
| // confirm against the implementation before reading it on failure. |
| Status TryFindKernel(const Node& node, ProviderType exec_provider, |
| const IKernelTypeStrResolver& kernel_type_str_resolver, |
| const KernelCreateInfo** out) const; |
|
|
| // Returns true when the node-based TryFindKernel() on registry `r` returns an OK status for |
| // `node` and `exec_provider`. The KernelCreateInfo pointer it produces is discarded. |
| static bool HasImplementationOf(const KernelRegistry& r, const Node& node, |
| ProviderType exec_provider, |
| const IKernelTypeStrResolver& kernel_type_str_resolver) { |
| // Start from nullptr so the pointer never holds an indeterminate value, whatever |
| // TryFindKernel() does with its out-parameter on failure. |
| const KernelCreateInfo* info = nullptr; |
| Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info); |
| return st.IsOK(); |
| } |
|
|
| #if !defined(ORT_MINIMAL_BUILD) |
| // Overload that looks a kernel up from an operator name, domain, version and type constraints |
| // instead of a Node. Only declared when ORT_MINIMAL_BUILD is not defined. |
| Status TryFindKernel(const std::string& op_name, const std::string& domain, const int& version, |
| const std::unordered_map<std::string, MLDataType>& type_constraints, |
| ProviderType exec_provider, const KernelCreateInfo** out) const; |
| #endif |
|
|
| // True when `kernel_creator_fn_map_` holds no entries. |
| bool IsEmpty() const { return kernel_creator_fn_map_.empty(); } |
|
|
| #ifdef onnxruntime_PYBIND_EXPORT_OPSCHEMA |
| // Read-only access to the underlying map; only compiled when onnxruntime_PYBIND_EXPORT_OPSCHEMA is defined. |
| const KernelCreateMap& GetKernelCreateMap() const { |
| return kernel_creator_fn_map_; |
| } |
| #endif |
|
|
| private: |
| // Checks `kernel_def` against `node`, using `kernel_type_str_resolver`. |
| // NOTE(review): defined out of line; presumably returns false and describes the mismatch in |
| // `error_str` when the kernel definition does not fit the node - confirm in the .cc file. |
| static bool VerifyKernelDef(const Node& node, |
| const KernelDef& kernel_def, |
| const IKernelTypeStrResolver& kernel_type_str_resolver, |
| std::string& error_str); |
|
|
| // Builds the key string "<op_name> <domain> <provider>", with a single space between the parts. |
| static std::string GetMapKey(std::string_view op_name, std::string_view domain, std::string_view provider) { |
| std::string key(op_name); |
| // An empty domain is replaced by kOnnxDomainAlias (declared elsewhere), so the middle part of the key is that alias instead. |
| key.append(1, ' ').append(domain.empty() ? kOnnxDomainAlias : domain).append(1, ' ').append(provider); |
| return key; |
| } |
|
|
| // Convenience overload: builds the key from the op name, domain and provider of `kernel_def`. |
| static std::string GetMapKey(const KernelDef& kernel_def) { |
| return GetMapKey(kernel_def.OpName(), kernel_def.Domain(), kernel_def.Provider()); |
| } |
| // Storage for the registered kernel creation entries; a multimap, so one key can map to several entries. |
| KernelCreateMap kernel_creator_fn_map_; |
| }; |
| } |
|
|