| // Copyright (c) Microsoft Corporation. All rights reserved. | |
| // Licensed under the MIT License. | |
| namespace onnxruntime { | |
| // Forward declarations. The rest of this header refers to these types only through | |
| // references, raw pointers, std::reference_wrapper or smart pointers, so the full | |
| // definitions are not required here. Class-keys match the real definitions. | |
| class GraphViewer; | |
| class KernelRegistry; | |
| class Node; | |
| struct ComputeCapability; | |
| struct KernelCreateInfo; | |
| } // namespace onnxruntime | |
| namespace onnxruntime { | |
| // Even if a fused function is exported to a DLL, it still lives in the same binary as onnxruntime. | |
| // std::function is used for the callbacks below so that an execution provider can capture state in them. | |
| // Creates the per-node FunctionState from a ComputeContext. The FunctionState* argument is presumably | |
| // the output slot and the int result a status code (0 on success?) -- TODO confirm against callers. | |
| using CreateFunctionStateFunc = std::function<int(ComputeContext*, FunctionState*)>; | |
| // Executes a compiled node given its FunctionState, the OrtApi table and the OrtKernelContext. | |
| using ComputeFunc = std::function<Status(FunctionState, const OrtApi*, OrtKernelContext*)>; | |
| // Releases a FunctionState (presumably the one produced by the matching create_state_func). | |
| using DestroyFunctionStateFunc = std::function<void(FunctionState)>; | |
| /** | |
| The create_state/compute/release_state callbacks that IExecutionProvider::Compile() | |
| returns for each fused node. | |
| */ | |
| struct NodeComputeInfo { | |
| CreateFunctionStateFunc create_state_func; | |
| ComputeFunc compute_func; | |
| DestroyFunctionStateFunc release_state_func; | |
| }; | |
| /** | |
| Data layouts an execution provider can report via IExecutionProvider::GetPreferredLayout(). | |
| NCHW is the default ONNX standard data layout. | |
| */ | |
| enum class DataLayout { | |
| NCHW, | |
| NHWC, | |
| NCHWC, | |
| }; | |
| /** | |
| Logical device representation. | |
| Base class for execution providers (EPs). It holds the EP's type string, its allocators, | |
| the logger set via SetLogger() and an optional ModelMetadefIdGenerator. It also exposes | |
| the virtual hooks ONNXRuntime uses for partitioning (GetCapability), compilation (Compile) | |
| and around each InferenceSession::Run (OnRunStart/OnRunEnd). | |
| */ | |
| class IExecutionProvider { | |
| protected: | |
| /** | |
| @param type EP type string; returned by Type(). | |
| @param use_metadef_id_creator If true, create the ModelMetadefIdGenerator | |
| (presumably consumed by GenerateMetaDefId()). | |
| explicit: the constructor is callable with a single std::string, which must not convert implicitly. | |
| */ | |
| explicit IExecutionProvider(const std::string& type, bool use_metadef_id_creator = false) | |
| : type_{type} { | |
| if (use_metadef_id_creator) { | |
| metadef_id_generator_ = std::make_unique<ModelMetadefIdGenerator>(); | |
| } | |
| } | |
| public: | |
| virtual ~IExecutionProvider() = default; | |
| /** | |
| Get all IAllocators for <*this> execution provider. | |
| */ | |
| const std::vector<AllocatorPtr>& GetAllocators() const { | |
| return allocator_list_; | |
| } | |
| /** | |
| * Get the allocator for the specified OrtMemType. Return nullptr if it doesn't exist. | |
| * The device id is not a parameter; it is implied by this execution provider. | |
| */ | |
| virtual AllocatorPtr GetAllocator(OrtMemType mem_type) const; | |
| /** | |
| * Returns a data transfer object that implements methods to copy to and | |
| * from this device. | |
| * If no copy is required for the successful operation of this provider, | |
| * return a nullptr. | |
| */ | |
| virtual std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const { | |
| return nullptr; | |
| } | |
| /** | |
| * Interface for performing kernel lookup within kernel registries. | |
| * Abstracts away lower-level details about kernel registries and kernel matching. | |
| */ | |
| class IKernelLookup { | |
| public: | |
| /** | |
| * Given `node`, try to find a matching kernel for this EP. | |
| * The return value is non-null if and only if a matching kernel was found. | |
| */ | |
| virtual const KernelCreateInfo* LookUpKernel(const Node& node) const = 0; | |
| protected: | |
| // Protected and non-virtual: instances are not meant to be deleted through an IKernelLookup pointer. | |
| ~IKernelLookup() = default; | |
| }; | |
| /** | |
| Get execution provider's capability for the specified <graph>. | |
| Return the IndexedSubGraphs that <*this> execution provider can run (if a | |
| sub-graph contains only one node) or can fuse and run (if a sub-graph | |
| contains more than one node). The node indexes contained in sub-graphs may | |
| overlap; it's ONNXRuntime's responsibility to do the partitioning | |
| and decide whether a node will be assigned to <*this> execution provider. | |
| For kernels registered in a kernel registry, `kernel_lookup` must be used | |
| to find a matching kernel for this EP. | |
| */ | |
| virtual std::vector<std::unique_ptr<ComputeCapability>> | |
| GetCapability(const onnxruntime::GraphViewer& graph_viewer, | |
| const IKernelLookup& kernel_lookup) const; | |
| /** | |
| Get kernel registry per execution provider type. | |
| The KernelRegistry shared pointer returned is shared across sessions. | |
| NOTE: this approach was taken to achieve the following goals: | |
| 1. The execution provider type based kernel registry should be shared | |
| across sessions. | |
| Only one copy of this kind of kernel registry exists in ONNXRuntime | |
| with multiple sessions/models. | |
| 2. Adding an execution provider into ONNXRuntime does not need to touch ONNXRuntime | |
| framework/session code. | |
| 3. onnxruntime (framework/session) does not depend on any specific | |
| execution provider lib. | |
| */ | |
| virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const { return nullptr; } | |
| /** | |
| Get the device id of the current execution provider. The default implementation returns 0. | |
| */ | |
| virtual int GetDeviceId() const { return 0; } | |
| /** | |
| Get execution provider's configuration options. | |
| */ | |
| virtual ProviderOptions GetProviderOptions() const { return {}; } | |
| /** | |
| Returns an opaque handle whose exact type varies based on the provider | |
| and is interpreted accordingly by the corresponding kernel implementation. | |
| For Direct3D operator kernels, this may return an IUnknown supporting | |
| QueryInterface to ID3D12GraphicsCommandList1. | |
| */ | |
| virtual const void* GetExecutionHandle() const noexcept { | |
| return nullptr; | |
| } | |
| /** | |
| @return type of the execution provider; should match that set in the node | |
| through the SetExecutionProvider API. Example valid return values are: | |
| kCpuExecutionProvider, kCudaExecutionProvider | |
| */ | |
| const std::string& Type() const { return type_; } | |
| /** | |
| Blocks until the device has completed all preceding requested tasks. | |
| Currently this is primarily used by the IOBinding object to ensure that all | |
| inputs have been copied to the device before execution begins. | |
| */ | |
| virtual common::Status Sync() const { return Status::OK(); } | |
| /** | |
| Called when InferenceSession::Run starts. | |
| NOTE: due to async execution in the provider, the actual work of the previous | |
| Run may not be finished on the device. This function should be regarded as the | |
| point after which a new Run starts to submit commands from the CPU. | |
| */ | |
| virtual common::Status OnRunStart() { return Status::OK(); } | |
| /** | |
| Called when InferenceSession::Run ends. | |
| NOTE: due to async execution in the provider, the actual work of this Run | |
| may not be finished on the device. This function should be regarded as the point | |
| at which all commands of the current Run have been submitted by the CPU. | |
| */ | |
| virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); } | |
| /** | |
| Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for | |
| the provider. Currently only CUDA execution provider supports it. | |
| */ | |
| virtual bool IsGraphCaptureEnabled() const { return false; } | |
| /** | |
| Indicate whether the graph has been captured and instantiated. Currently | |
| only CUDA execution provider supports it. | |
| */ | |
| virtual bool IsGraphCaptured() const { return false; } | |
| /** | |
| Run the instantiated graph. Currently only CUDA execution provider supports | |
| it. | |
| */ | |
| virtual common::Status ReplayGraph() { return Status::OK(); } | |
| /** | |
| Called when session creation is complete. | |
| This provides an opportunity for execution providers to optionally synchronize and | |
| clean up their temporary resources to reduce memory and ensure the first run is fast. | |
| */ | |
| virtual common::Status OnSessionInitializationEnd() { return Status::OK(); } | |
| // Add / replace an allocator for this EP. Implementations are not in this header; they presumably | |
| // update both allocators_ and allocator_list_ -- TODO confirm. | |
| void InsertAllocator(AllocatorPtr allocator); | |
| void ReplaceAllocator(AllocatorPtr allocator); | |
| struct FusedNodeAndGraph { | |
| const std::reference_wrapper<onnxruntime::Node> fused_node; | |
| // GraphViewer that filters the full graph to the nodes that are covered by 'node' | |
| const std::reference_wrapper<GraphViewer> filtered_graph; | |
| }; | |
| // Fusion approaches that are supported. | |
| // !!! The "Function" FusionStyle is deprecated. | |
| // !!! If your EP is using this fusion style, please migrate it to "FilteredGraphViewer" style. | |
| enum class FusionStyle { | |
| // The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance | |
| // in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body(). | |
| // A GraphProto can be produced from the Node body. | |
| Function, | |
| // The node fusion will create a new Node that defines the inputs and outputs using the IndexedSubGraph | |
| // that GetCapability returned. The Node will not be onnxruntime::Function based so will have no Body(). | |
| // Instead a GraphViewer that filters the full Graph to the fused Nodes will be created. | |
| // This is significantly cheaper as it doesn't incur the cost of creating a new Graph instance, | |
| // and can be supported in a minimal build. | |
| FilteredGraphViewer | |
| }; | |
| virtual FusionStyle GetFusionStyle() const { | |
| // All the ORT built-in EPs have migrated to the FilteredGraphViewer style. | |
| // New EPs should avoid the Function style, as it is deprecated. | |
| return FusionStyle::FilteredGraphViewer; | |
| } | |
| /** | |
| Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused, | |
| return create_state/compute/release_state func for each node. | |
| @remarks This is now the default interface when execution provider wants to compile nodes | |
| for both minimal build and complete ort build. | |
| Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions | |
| as it is only valid for the duration of the call to Compile. | |
| */ | |
| virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs, | |
| std::vector<NodeComputeInfo>& node_compute_funcs); | |
| // The logger is set when this object is registered to a session; it is nullptr before that. | |
| void SetLogger(const logging::Logger* logger) { | |
| logger_ = logger; | |
| } | |
| const logging::Logger* GetLogger() const { | |
| return logger_; | |
| } | |
| /** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance. | |
| The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models. | |
| @param[in] graph_viewer Graph viewer that GetCapability was called with. Can be for the main graph or nested graph. | |
| @param[out] model_hash Returns the hash for the main (i.e. top level) graph in the model. | |
| This is created using the model path if available, | |
| or the model input names and the output names from all nodes in the main graph. | |
| @remarks e.g. the TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches | |
| compiled kernels, so the name must be unique and deterministic across models and sessions. | |
| NOTE: Ideally this would be a protected method, but to work across the EP bridge it has to be public and | |
| virtual, and ModelMetadefIdGenerator must be defined in the header as well. | |
| */ | |
| virtual int GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const; | |
| /** | |
| Register allocators for EP, potentially re-using existing allocators for a device from allocator_manager. | |
| If the EP implements this it should generally delay creating any allocators until this is called. | |
| */ | |
| virtual void RegisterAllocator(AllocatorManager& /*allocator_manager*/); | |
| virtual std::unique_ptr<profiling::EpProfiler> GetProfiler() { | |
| return {}; | |
| } | |
| virtual DataLayout GetPreferredLayout() const { | |
| // NCHW is the default ONNX standard data layout. So default to it. | |
| // EPs which prefer a different layout should override to return their preferred layout. | |
| return DataLayout::NCHW; | |
| } | |
| virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/) const {} | |
| /** Does the EP support concurrent calls to InferenceSession::Run to execute the model. | |
| */ | |
| virtual bool ConcurrentRunSupported() const { return true; } | |
| /** | |
| * Return the tuning context which holds all TunableOp state. | |
| */ | |
| virtual ITuningContext* GetTuningContext() const { | |
| return nullptr; | |
| } | |
| private: | |
| const std::string type_; | |
| // allocator lookup is done by combining the device id and OrtMemType. | |
| // there's also an implicit connection to the underlying OrtDevice involved that is dependent on the EP. | |
| // e.g. for a CPU based EP, 'default' memory is a CPU device, and for a GPU based EP 'default' memory is a | |
| // GPU device. | |
| using AllocatorMap = std::unordered_map<int, AllocatorPtr>; | |
| AllocatorMap allocators_; | |
| // It will be set when this object is registered to a session | |
| const logging::Logger* logger_ = nullptr; | |
| // convenience list of the allocators so GetAllocators doesn't have to build a new vector each time | |
| // contains the same instances as allocators_ | |
| std::vector<AllocatorPtr> allocator_list_; | |
| // helper to generate ids that are unique to model and deterministic, even if the execution provider is shared across | |
| // multiple sessions. | |
| class ModelMetadefIdGenerator { | |
| public: | |
| int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash); | |
| private: | |
| std::unordered_map<HashValue, HashValue> main_graph_hash_; // map graph instance hash to model contents hash | |
| std::unordered_map<HashValue, int> model_metadef_id_; // current unique id for model | |
| }; | |
| std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_; | |
| }; | |
| } // namespace onnxruntime | |