File size: 13,389 Bytes
c206440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#ifndef SHARED_PROVIDER
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "core/common/logging/logging.h"
#include "core/common/status.h"
#include "core/framework/data_transfer.h"
#include "core/framework/tensor.h"

namespace onnxruntime {
class GraphViewer;
struct ComputeCapability;
class KernelRegistry;
struct KernelCreateInfo;
class Node;
}  // namespace onnxruntime
#else
#include <memory>
#endif

#include "core/common/basic_types.h"
#include "core/common/profiler_common.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/func_api.h"
#include "core/framework/provider_options.h"
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"

namespace onnxruntime {

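/**
   Logical device representation.
   This describes IExecutionProvider, declared further below; the callable aliases and
   helper types between here and that class are used by its interface.
*/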

// Callable types held by NodeComputeInfo for a compiled (fused) node.
// Even when a fused function is exported to a DLL, it still lives in the same binary as onnxruntime.
// std::function is used so that an execution provider gets a chance to capture state in these callables.

// Type of NodeComputeInfo::create_state_func.
using CreateFunctionStateFunc = std::function<int(ComputeContext* /*context*/, FunctionState* /*state*/)>;
// Type of NodeComputeInfo::compute_func.
using ComputeFunc = std::function<Status(FunctionState /*state*/, const OrtApi* /*api*/, OrtKernelContext* /*context*/)>;
// Type of NodeComputeInfo::release_state_func.
using DestroyFunctionStateFunc = std::function<void(FunctionState /*state*/)>;

/**
   The callables an execution provider hands back from Compile() for one fused node:
   one to create its FunctionState, one to run it, and one to release the state.
*/
struct NodeComputeInfo {
  CreateFunctionStateFunc create_state_func{};    // creates the node's FunctionState
  ComputeFunc compute_func{};                     // executes the node given its FunctionState
  DestroyFunctionStateFunc release_state_func{};  // releases the node's FunctionState
};

/**
   Tensor data layouts an execution provider can report through
   IExecutionProvider::GetPreferredLayout(). NCHW is the ONNX standard layout.
*/
enum class DataLayout : int {
  NCHW,  // ONNX standard layout; the default preference
  NHWC,
  NCHWC,
};

class IExecutionProvider {
 protected:
  /**
     @param type Provider type string returned by Type().
     @param use_metadef_id_creator When true, creates the ModelMetadefIdGenerator instance
            (presumably required by GenerateMetaDefId(); confirm in the implementation).
  */
  IExecutionProvider(const std::string& type, bool use_metadef_id_creator = false)
      : type_{type} {
    if (use_metadef_id_creator) {
      metadef_id_generator_ = std::make_unique<ModelMetadefIdGenerator>();
    }
  }

 public:
  virtual ~IExecutionProvider() = default;

  /**
     Get all IAllocators for <*this> execution provider.
  */
  const std::vector<AllocatorPtr>& GetAllocators() const {
    return allocator_list_;
  }

  /**
   * Get the allocator for the specified OrtMemType. Return nullptr if it doesn't exist.
   * The lookup key combines this provider's device id with mem_type (see allocators_).
   */
  virtual AllocatorPtr GetAllocator(OrtMemType mem_type) const;

  /**
   * Returns a data transfer object that implements methods to copy to and
   * from this device.
   * If no copy is required for the successful operation of this provider,
   * return a nullptr.
   */
  virtual std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const {
    return nullptr;
  }

  /**
   * Interface for performing kernel lookup within kernel registries.
   * Abstracts away lower-level details about kernel registries and kernel matching.
   */
  class IKernelLookup {
   public:
    /**
     * Given `node`, try to find a matching kernel for this EP.
     * The return value is non-null if and only if a matching kernel was found.
     */
    virtual const KernelCreateInfo* LookUpKernel(const Node& node) const = 0;

   protected:
    // Protected and non-virtual: instances cannot be deleted through an IKernelLookup pointer.
    ~IKernelLookup() = default;
  };

  /**
     Get execution provider's capability for the specified <graph>.
     Return a collection of sub-graphs that <*this> execution provider can run if
     the sub-graph contains only one node, or can fuse and run if the sub-graph
     contains more than one node. The node indexes contained in sub-graphs may
     overlap, and it's ONNXRuntime's responsibility to do the partitioning
     and decide whether a node will be assigned to <*this> execution provider.
     For kernels registered in a kernel registry, `kernel_lookup` must be used
     to find a matching kernel for this EP.
  */
  virtual std::vector<std::unique_ptr<ComputeCapability>>
  GetCapability(const onnxruntime::GraphViewer& graph_viewer,
                const IKernelLookup& kernel_lookup) const;

  /**
     Get kernel registry per execution provider type.
     The KernelRegistry shared pointer returned is shared across sessions.

     NOTE: this approach was taken to achieve the following goals:
     1. The execution provider type based kernel registry should be shared
     across sessions.
     Only one copy of this kind of kernel registry exists in ONNXRuntime
     with multiple sessions/models.
     2. Adding an execution provider into ONNXRuntime does not need to touch ONNXRuntime
     framework/session code.
     3. onnxruntime (framework/session) does not depend on any specific
     execution provider lib.
  */
  virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const { return nullptr; }

  /**
     Get the device id of current execution provider
  */
  virtual int GetDeviceId() const { return 0; }

  /**
     Get execution provider's configuration options.
   */
  virtual ProviderOptions GetProviderOptions() const { return {}; }

  /**
     Returns an opaque handle whose exact type varies based on the provider
     and is interpreted accordingly by the corresponding kernel implementation.
     For Direct3D operator kernels, this may return an IUnknown supporting
     QueryInterface to ID3D12GraphicsCommandList1.
  */
  virtual const void* GetExecutionHandle() const noexcept {
    return nullptr;
  }

  /**
     @return type of the execution provider; should match that set in the node
     through the SetExecutionProvider API. Example valid return values are:
     kCpuExecutionProvider, kCudaExecutionProvider
  */
  const std::string& Type() const { return type_; }

  /**
     Blocks until the device has completed all preceding requested tasks.
     Currently this is primarily used by the IOBinding object to ensure that all
     inputs have been copied to the device before execution begins.
  */
  virtual common::Status Sync() const { return Status::OK(); }

  /**
     Called when InferenceSession::Run started.
     NOTE that due to async execution in the provider, the actual work of the previous
     Run may not be finished on the device. This function should be regarded as the
     point after which a new Run starts submitting commands from the CPU.
  */
  virtual common::Status OnRunStart() { return Status::OK(); }

  /**
     Called when InferenceSession::Run ended.
     NOTE that due to async execution in the provider, the actual work of this Run
     may not be finished on the device. This function should be regarded as the point
     at which all commands of the current Run have been submitted by the CPU.
  */
  virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }

  /**
     Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
     the provider. Currently only CUDA execution provider supports it.
   */
  virtual bool IsGraphCaptureEnabled() const { return false; }

  /**
     Indicate whether the graph has been captured and instantiated. Currently
     only CUDA execution provider supports it.
   */
  virtual bool IsGraphCaptured() const { return false; }

  /**
     Run the instantiated graph. Currently only CUDA execution provider supports
     it.
   */
  virtual common::Status ReplayGraph() { return Status::OK(); }

  /**
     Called when session creation is complete.
     This provides an opportunity for execution providers to optionally synchronize and
     clean up their temporary resources to reduce memory and ensure the first run is fast.
  */
  virtual common::Status OnSessionInitializationEnd() { return Status::OK(); }

  void InsertAllocator(AllocatorPtr allocator);
  void ReplaceAllocator(AllocatorPtr allocator);

  struct FusedNodeAndGraph {
    const std::reference_wrapper<onnxruntime::Node> fused_node;
    // GraphViewer that filters the full graph to the nodes that are covered by 'node'
    const std::reference_wrapper<GraphViewer> filtered_graph;
  };

  // Fusion approach that is supported
  // !!! The "Function" FusionStyle is deprecated.
  // !!! If your EP is using this fusion style, please migrate it to "FilteredGraphViewer" style.
  enum class FusionStyle {
    // The node fusion will create an onnxruntime::Function based Node that contains a completely new Graph instance
    // in the Node body. The original nodes and initializers are copied to the new Graph instance in Function::Body().
    // A GraphProto can be produced from the Node body.
    Function,

    // The node fusion will create a new Node that defines the inputs and outputs using the IndexedSubGraph
    // that GetCapability returned. The Node will not be onnxruntime::Function based so will have no Body().
    // Instead a GraphViewer that filters the full Graph to the fused Nodes will be created.
    // This is significantly cheaper as it doesn't incur the cost of creating a new Graph instance,
    // and can be supported in a minimal build.
    FilteredGraphViewer
  };

  virtual FusionStyle GetFusionStyle() const {
    // All built-in ORT EPs have migrated to the FilteredGraphViewer style.
    // New EPs should avoid using the Function style as it is deprecated.
    return FusionStyle::FilteredGraphViewer;
  }

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
  /**
  Given a collection of fused Nodes and the respective GraphViewer instance for the nodes that were fused,
  return create_state/compute/release_state func for each node.
  @remarks This is now the default interface when execution provider wants to compile nodes
           for both minimal build and complete ort build.

           Do NOT cache the GraphViewer in FusedNodeAndGraph.filtered_graph in any of the NodeComputeInfo functions
           as it is only valid for the duration of the call to Compile.
  */
  virtual common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
                                 std::vector<NodeComputeInfo>& node_compute_funcs);

#endif

  void SetLogger(const logging::Logger* logger) {
    logger_ = logger;
  }

  const logging::Logger* GetLogger() const {
    return logger_;
  }

  /** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance.
   The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models.
   @param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or nested graph.
   @param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model.
                          This is created using the model path if available,
                          or the model input names and the output names from all nodes in the main graph.
   @remarks e.g. the TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches
            compiled kernels, so the name must be unique and deterministic across models and sessions.
            NOTE: Ideally this would be a protected method, but to work across the EP bridge it has to be public and
                  virtual, and ModelMetadefIdGenerator must be defined in the header as well.
   */
  virtual int GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const;

  /**
     Register allocators for EP, potentially re-using existing allocators for a device from allocator_manager.
     If the EP implements this it should generally delay creating any allocators until this is called.
  */
  virtual void RegisterAllocator(AllocatorManager& /*allocator_manager*/);

  virtual std::unique_ptr<profiling::EpProfiler> GetProfiler() {
    return {};
  }

  virtual DataLayout GetPreferredLayout() const {
    // NCHW is the default ONNX standard data layout. So default to it.
    // EPs which prefer a different layout should override to return their preferred layout.
    return DataLayout::NCHW;
  }

  virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/) const {}

  /** Does the EP support concurrent calls to InferenceSession::Run to execute the model.
   */
  virtual bool ConcurrentRunSupported() const { return true; }

  /**
   * Return the tuning context which holds all TunableOp state.
   */
  virtual ITuningContext* GetTuningContext() const {
    return nullptr;
  }

 private:
  const std::string type_;

  // allocator lookup is done by combining the device id and OrtMemType.
  // there's also an implicit connection to the underlying OrtDevice involved that is dependent on the EP.
  // e.g. for a CPU based EP, 'default' memory is a CPU device, and for a GPU based EP 'default' memory is a
  // GPU device.
  using AllocatorMap = std::unordered_map<int, AllocatorPtr>;
  AllocatorMap allocators_;

  // It will be set when this object is registered to a session
  const logging::Logger* logger_ = nullptr;
  // convenience list of the allocators so GetAllocators doesn't have to build a new vector each time
  // contains the same instances as allocators_
  std::vector<AllocatorPtr> allocator_list_;

  // helper to generate ids that are unique to model and deterministic, even if the execution provider is shared across
  // multiple sessions.
  class ModelMetadefIdGenerator {
   public:
    int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash);

   private:
    std::unordered_map<HashValue, HashValue> main_graph_hash_;  // map graph instance hash to model contents hash
    std::unordered_map<HashValue, int> model_metadef_id_;       // current unique id for model
  };

  std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
};
}  // namespace onnxruntime