// Source path: G0-VLA/g0plus_dockerfile/docker-assets/data/TensorRT-10.13.0.35/python/include/impl/NvInferPythonPlugin.h
/*
 * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
//!
//! \file NvInferPythonPlugin.h
//!
//! This file contains definitions for supporting the `tensorrt.plugin` Python module.
//!
//! \warning None of the definitions here are part of the TensorRT C++ API and may not follow semantic versioning rules.
//! TensorRT clients must not utilize them directly.
//!
| namespace nvinfer1 | |
| { | |
//!
//! \enum PluginArgType
//!
//! \brief Numeric category of an extra kernel input argument in an AOT Python plugin.
//!
//! \see PluginArgDataType for the concrete integer type of such an argument.
//!
enum class PluginArgType : int32_t
{
    kINT = 0, //!< The argument is an integer.
};
//!
//! \enum PluginArgDataType
//!
//! \brief Concrete integer type of an extra kernel input argument in an AOT Python plugin.
//!
enum class PluginArgDataType : int32_t
{
    kINT8 = 0,  //!< 8-bit signed integer.
    kINT16 = 1, //!< 16-bit signed integer.
    kINT32 = 2, //!< 32-bit signed integer.
};
| //! \class ISymExpr | |
| //! \brief Generic interface for a scalar symbolic expression implementable by a Python plugin / TensorRT Python backend | |
| class ISymExpr | |
| { | |
| public: | |
| //! \brief Get the type of the symbolic expression | |
| virtual PluginArgType getType() const noexcept = 0; | |
| //! \brief Get the data type of the symbolic expression | |
| virtual PluginArgDataType getDataType() const noexcept = 0; | |
| //! \brief Underlying symbolic expression | |
| virtual void* getExpr() noexcept = 0; | |
| }; | |
| //! Impl class for ISymExprs | |
| class ISymExprsImpl | |
| { | |
| public: | |
| virtual ISymExpr* getSymExpr(int32_t index) const noexcept = 0; | |
| virtual bool setSymExpr(int32_t index, ISymExpr* symExpr) noexcept = 0; | |
| virtual int32_t getNbSymExprs() const noexcept = 0; | |
| virtual bool setNbSymExprs(int32_t count) noexcept = 0; | |
| virtual ~ISymExprsImpl() noexcept = default; | |
| }; | |
| //! \class ISymExprs | |
| //! \brief Allows for a sequence of symbolic expressions to be communicated to the TensorRT backend | |
| //! \note Clients must not implement this class. | |
| //! \see ISymExpr | |
| class ISymExprs | |
| { | |
| public: | |
| //! \brief Get the symbolic expression at the given index | |
| //! \return A pointer to the symbolic expression or nullptr if the index is out of range | |
| ISymExpr* getSymExpr(int32_t index) const noexcept | |
| { | |
| return mImpl->getSymExpr(index); | |
| } | |
| //! \brief Set the symbolic expression at the given index | |
| //! \return true if the index is in range and the symbolic expression was set successfully, false otherwise | |
| bool setSymExpr(int32_t index, ISymExpr* symExpr) noexcept | |
| { | |
| return mImpl->setSymExpr(index, symExpr); | |
| } | |
| //! \brief Get the number of symbolic expressions | |
| int32_t getNbSymExprs() const noexcept | |
| { | |
| return mImpl->getNbSymExprs(); | |
| } | |
| //! \brief Set the number of symbolic expressions | |
| //! \return true if the number of symbolic expressions was set successfully, false otherwise | |
| bool setNbSymExprs(int32_t count) noexcept | |
| { | |
| return mImpl->setNbSymExprs(count); | |
| } | |
| protected: | |
| ISymExprsImpl* mImpl{nullptr}; | |
| virtual ~ISymExprs() noexcept = default; | |
| }; | |
//!
//! \enum QuickPluginCreationRequest
//!
//! \brief Preference communicated when a quickly deployable plugin is added to the network: which kind of
//! implementation (JIT or AOT) to use, and whether that choice is a preference or a hard requirement.
//!
enum class QuickPluginCreationRequest : int32_t
{
    kUNKNOWN = 0,    //!< No preference specified.
    kPREFER_JIT = 1, //!< A JIT plugin is preferred.
    kPREFER_AOT = 2, //!< An AOT plugin is preferred.
    kSTRICT_JIT = 3, //!< A JIT plugin must be used. TensorRT should fail if a JIT implementation cannot be found.
    kSTRICT_AOT = 4, //!< An AOT plugin must be used. TensorRT should fail if an AOT implementation cannot be found.
};
| //! Impl class for IKernelLaunchParams | |
| class IKernelLaunchParamsImpl | |
| { | |
| public: | |
| virtual ISymExpr* getGridX() noexcept = 0; | |
| virtual bool setGridX(ISymExpr* gridX) noexcept = 0; | |
| virtual ISymExpr* getGridY() noexcept = 0; | |
| virtual bool setGridY(ISymExpr* gridY) noexcept = 0; | |
| virtual ISymExpr* getGridZ() noexcept = 0; | |
| virtual bool setGridZ(ISymExpr* gridZ) noexcept = 0; | |
| virtual ISymExpr* getBlockX() noexcept = 0; | |
| virtual bool setBlockX(ISymExpr* blockX) noexcept = 0; | |
| virtual ISymExpr* getBlockY() noexcept = 0; | |
| virtual bool setBlockY(ISymExpr* blockY) noexcept = 0; | |
| virtual ISymExpr* getBlockZ() noexcept = 0; | |
| virtual bool setBlockZ(ISymExpr* blockZ) noexcept = 0; | |
| virtual ISymExpr* getSharedMem() noexcept = 0; | |
| virtual bool setSharedMem(ISymExpr* sharedMem) noexcept = 0; | |
| virtual ~IKernelLaunchParamsImpl() noexcept = default; | |
| }; | |
| //! \class IKernelLaunchParams | |
| //! \brief Allows for kernel launch parameters to be communicated to the TensorRT backend | |
| //! \note Clients must not implement this class. | |
| class IKernelLaunchParams | |
| { | |
| public: | |
| //! Get the X dimension of the grid | |
| ISymExpr* getGridX() noexcept | |
| { | |
| return mImpl->getGridX(); | |
| } | |
| //! \brief Set the X dimension of the grid | |
| //! \return true if the grid's X dimension was set successfully, false otherwise | |
| bool setGridX(ISymExpr* gridX) noexcept | |
| { | |
| return mImpl->setGridX(gridX); | |
| } | |
| //! Get the Y dimension of the grid | |
| ISymExpr* getGridY() noexcept | |
| { | |
| return mImpl->getGridY(); | |
| } | |
| //! \brief Set the Y dimension of the grid | |
| //! \return true if the grid's Y dimension was set successfully, false otherwise | |
| bool setGridY(ISymExpr* gridY) noexcept | |
| { | |
| return mImpl->setGridY(gridY); | |
| } | |
| //! Get the Z dimension of the grid | |
| ISymExpr* getGridZ() noexcept | |
| { | |
| return mImpl->getGridZ(); | |
| } | |
| //! \brief Set the Z dimension of the grid | |
| //! \return true if the grid's Z dimension was set successfully, false otherwise | |
| bool setGridZ(ISymExpr* gridZ) noexcept | |
| { | |
| return mImpl->setGridZ(gridZ); | |
| } | |
| //! \brief Get the X dimension of each thread block | |
| ISymExpr* getBlockX() noexcept | |
| { | |
| return mImpl->getBlockX(); | |
| } | |
| //! \brief Set the X dimension of each thread block | |
| //! \return true if each thread block's X dimension was set successfully, false otherwise | |
| bool setBlockX(ISymExpr* blockX) noexcept | |
| { | |
| return mImpl->setBlockX(blockX); | |
| } | |
| //! \brief Get the Y dimension of each thread block | |
| ISymExpr* getBlockY() noexcept | |
| { | |
| return mImpl->getBlockY(); | |
| } | |
| //! \brief Set the Y dimension of each thread block | |
| //! \return true if each thread block's Y dimension was set successfully, false otherwise | |
| bool setBlockY(ISymExpr* blockY) noexcept | |
| { | |
| return mImpl->setBlockY(blockY); | |
| } | |
| //! \brief Get the Z dimension of each thread block | |
| ISymExpr* getBlockZ() noexcept | |
| { | |
| return mImpl->getBlockZ(); | |
| } | |
| //! \brief Set the Z dimension of each thread block | |
| //! \return true if each thread block's Z dimension was set successfully, false otherwise | |
| bool setBlockZ(ISymExpr* blockZ) noexcept | |
| { | |
| return mImpl->setBlockZ(blockZ); | |
| } | |
| //! \brief Get the dynamic shared-memory per thread block in bytes | |
| ISymExpr* getSharedMem() noexcept | |
| { | |
| return mImpl->getSharedMem(); | |
| } | |
| //! \brief Set the dynamic shared-memory per thread block in bytes | |
| //! \return true if the dynamic shared-memory per thread block was set successfully, false otherwise | |
| bool setSharedMem(ISymExpr* sharedMem) noexcept | |
| { | |
| return mImpl->setSharedMem(sharedMem); | |
| } | |
| protected: | |
| IKernelLaunchParamsImpl* mImpl{nullptr}; | |
| virtual ~IKernelLaunchParams() noexcept = default; | |
| }; | |
| namespace v_1_0 | |
| { | |
| class IPluginV3QuickCore : public IPluginCapability | |
| { | |
| public: | |
| InterfaceInfo getInterfaceInfo() const noexcept override | |
| { | |
| return InterfaceInfo{"PLUGIN_V3QUICK_CORE", 1, 0}; | |
| } | |
| virtual AsciiChar const* getPluginName() const noexcept = 0; | |
| virtual AsciiChar const* getPluginVersion() const noexcept = 0; | |
| virtual AsciiChar const* getPluginNamespace() const noexcept = 0; | |
| }; | |
| class IPluginV3QuickBuild : public IPluginCapability | |
| { | |
| public: | |
| InterfaceInfo getInterfaceInfo() const noexcept override | |
| { | |
| return InterfaceInfo{"PLUGIN_V3QUICK_BUILD", 1, 0}; | |
| } | |
| //! | |
| //! \brief Provide the data types of the plugin outputs if the input tensors have the data types provided. | |
| //! | |
| //! \param outputTypes Pre-allocated array to which the output data types should be written. | |
| //! \param nbOutputs The number of output tensors. This matches the value returned from getNbOutputs(). | |
| //! \param inputTypes The input data types. | |
| //! \param inputRanks Ranks of the input tensors | |
| //! \param nbInputs The number of input tensors. | |
| //! | |
| //! \return 0 for success, else non-zero | |
| //! | |
| virtual int32_t getOutputDataTypes(DataType* outputTypes, int32_t nbOutputs, DataType const* inputTypes, | |
| int32_t const* inputRanks, int32_t nbInputs) const noexcept = 0; | |
| //! | |
| //! \brief Provide expressions for computing dimensions of the output tensors from dimensions of the input tensors. | |
| //! | |
| //! \param inputs Expressions for dimensions of the input tensors | |
| //! \param nbInputs The number of input tensors | |
| //! \param shapeInputs Expressions for values of the shape tensor inputs | |
| //! \param nbShapeInputs The number of shape tensor inputs | |
| //! \param outputs Pre-allocated array to which the output dimensions must be written | |
| //! \param exprBuilder Object for generating new dimension expressions | |
| //! | |
| //! \return 0 for success, else non-zero | |
| //! | |
| virtual int32_t getOutputShapes(DimsExprs const* inputs, int32_t nbInputs, DimsExprs const* shapeInputs, | |
| int32_t nbShapeInputs, DimsExprs* outputs, int32_t nbOutputs, IExprBuilder& exprBuilder) noexcept = 0; | |
| //! | |
| //! \brief Configure the plugin. Behaves similarly to `IPluginV3OneBuild::configurePlugin()` | |
| //! | |
| //! \return 0 for success, else non-zero | |
| //! | |
| virtual int32_t configurePlugin(DynamicPluginTensorDesc const* in, int32_t nbInputs, | |
| DynamicPluginTensorDesc const* out, int32_t nbOutputs) noexcept = 0; | |
| //! | |
| //! \brief Get number of format combinations supported by the plugin for the I/O characteristics indicated by | |
| //! `inOut`. | |
| //! | |
| virtual int32_t getNbSupportedFormatCombinations( | |
| DynamicPluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) noexcept = 0; | |
| //! | |
| //! \brief Write all format combinations supported by the plugin for the I/O characteristics indicated by `inOut` to | |
| //! `supportedCombinations`. It is guaranteed to have sufficient memory allocated for (nbInputs + nbOutputs) * | |
| //! getNbSupportedFormatCombinations() `PluginTensorDesc`s. | |
| //! | |
| //! \return 0 for success, else non-zero | |
| //! | |
| virtual int32_t getSupportedFormatCombinations(DynamicPluginTensorDesc const* inOut, int32_t nbInputs, | |
| int32_t nbOutputs, PluginTensorDesc* supportedCombinations, int32_t nbFormatCombinations) noexcept = 0; | |
| //! | |
| //! \brief Get the number of outputs from the plugin. | |
| //! | |
| virtual int32_t getNbOutputs() const noexcept = 0; | |
| //! | |
| //! \brief Communicates to TensorRT that the output at the specified output index is aliased to the input at the | |
| //! returned index. Behaves similary to `v_2_0::IPluginV3OneBuild.getAliasedInput()`. | |
| //! | |
| virtual int32_t getAliasedInput(int32_t outputIndex) noexcept | |
| { | |
| return -1; | |
| } | |
| //! | |
| //! \brief Query for any custom tactics that the plugin intends to use specific to the I/O characteristics indicated | |
| //! by the immediately preceding call to `configurePlugin()`. | |
| //! | |
| //! \return 0 for success, else non-zero | |
| //! | |
| virtual int32_t getValidTactics(int32_t* tactics, int32_t nbTactics) noexcept | |
| { | |
| return 0; | |
| } | |
| //! | |
| //! \brief Query for number of custom tactics related to the `getValidTactics()` call. | |
| //! | |
| virtual int32_t getNbTactics() noexcept | |
| { | |
| return 0; | |
| } | |
| //! | |
| //! \brief Called to query the suffix to use for the timing cache ID. May be called anytime after plugin creation. | |
| //! | |
| virtual char const* getTimingCacheID() noexcept | |
| { | |
| return nullptr; | |
| } | |
| //! | |
| //! \brief Query for a string representing the configuration of the plugin. May be called anytime after | |
| //! plugin creation. | |
| //! | |
| virtual char const* getMetadataString() noexcept | |
| { | |
| return nullptr; | |
| } | |
| }; | |
| class IPluginV3QuickAOTBuild : public IPluginV3QuickBuild | |
| { | |
| public: | |
| InterfaceInfo getInterfaceInfo() const noexcept override | |
| { | |
| return InterfaceInfo{"PLUGIN_V3QUICKAOT_BUILD", 1, 0}; | |
| } | |
| //! \brief Get the launch parameters for the kernel to be used for the specified input and output types/formats and | |
| //! any corresponding custom tactics. | |
| //! If custom tactics are being advertised by the plugin, the corresponding tactic is the one specified by | |
| //! the immediately preceding call to setTactic(). | |
| //! | |
| //! \param inputs Expressions for dimensions of the input tensors | |
| //! \param inOut The input and output tensors' attributes | |
| //! \param nbInputs The number of input tensors | |
| //! \param nbOutputs The number of output tensors | |
| //! \param launchParams Interface which allows the specification of kernel launch parameters as symbolic expressions | |
| //! of the input dimensions | |
| //! \param extraArgs Interface which allows the specification of any scalar arguments to be | |
| //! passed to the kernel, as symbolic expressions of the input dimensions | |
| //! \param exprBuilder Object for generating new symbolic expressions | |
| //! | |
| //! \return 0 for success, else non-zero | |
| //! | |
| virtual int32_t getLaunchParams(DimsExprs const* inputs, DynamicPluginTensorDesc const* inOut, int32_t nbInputs, | |
| int32_t nbOutputs, IKernelLaunchParams* launchParams, ISymExprs* extraArgs, | |
| IExprBuilder& exprBuilder) noexcept = 0; | |
| //! | |
| //! \brief Get the compiled form for the kernel to be used for the specified input and output types/formats and any | |
| //! corresponding custom tactics. | |
| //! If custom tactics are being advertised by the plugin, the corresponding tactic is the one specified by | |
| //! the immediately preceding call to setTactic(). | |
| //! | |
| //! \param in The input tensors' attributes that are used for configuration. | |
| //! \param nbInputs Number of input tensors. | |
| //! \param out The output tensors' attributes that are used for configuration. | |
| //! \param nbOutputs Number of output tensors. | |
| //! \param kernelName The name for the kernel. | |
| //! \param compiledKernel Compiled form of the kernel. | |
| //! \param compiledKernelSize The size of the compiled kernel. | |
| //! | |
| //! \return 0 for success, else non-zero | |
| //! | |
| virtual int32_t getKernel(PluginTensorDesc const* in, int32_t nbInputs, PluginTensorDesc const* out, | |
| int32_t nbOutputs, const char** kernelName, char** compiledKernel, int32_t* compiledKernelSize) noexcept = 0; | |
| //! | |
| //! \brief Set the tactic to be used in the subsequent call to enqueue(). Behaves similar to | |
| //! IPluginV3OneRuntime::setTactic() | |
| //! | |
| //! \return 0 for success, else non-zero | |
| //! | |
| virtual int32_t setTactic(int32_t tactic) noexcept | |
| { | |
| return 0; | |
| } | |
| }; | |
| class IPluginV3QuickRuntime : public IPluginCapability | |
| { | |
| public: | |
| InterfaceInfo getInterfaceInfo() const noexcept override | |
| { | |
| return InterfaceInfo{"PLUGIN_V3QUICK_RUNTIME", 1, 0}; | |
| } | |
| //! | |
| //! \brief Set the tactic to be used in the subsequent call to enqueue(). Behaves similar to | |
| //! `IPluginV3OneRuntime::setTactic()`. | |
| //! | |
| //! \return 0 for success, else non-zero | |
| //! | |
| virtual int32_t setTactic(int32_t tactic) noexcept | |
| { | |
| return 0; | |
| } | |
| //! | |
| //! \brief Execute the plugin. | |
| //! | |
| //! \param inputDesc how to interpret the memory for the input tensors. | |
| //! \param outputDesc how to interpret the memory for the output tensors. | |
| //! \param inputs The memory for the input tensors. | |
| //! \param inputStrides Strides for input tensors. | |
| //! \param outputStrides Strides for output tensors. | |
| //! \param outputs The memory for the output tensors. | |
| //! \param nbInputs Number of input tensors. | |
| //! \param nbOutputs Number of output tensors. | |
| //! \param stream The stream in which to execute the kernels. | |
| //! | |
| //! \return 0 for success, else non-zero | |
| //! | |
| virtual int32_t enqueue(PluginTensorDesc const* inputDesc, PluginTensorDesc const* outputDesc, | |
| void const* const* inputs, void* const* outputs, Dims const* inputStrides, Dims const* outputStrides, | |
| int32_t nbInputs, int32_t nbOutputs, cudaStream_t stream) noexcept = 0; | |
| //! | |
| //! \brief Get the plugin fields which should be serialized. | |
| //! | |
| virtual PluginFieldCollection const* getFieldsToSerialize() noexcept = 0; | |
| }; | |
| class IPluginCreatorV3Quick : public IPluginCreatorInterface | |
| { | |
| public: | |
| InterfaceInfo getInterfaceInfo() const noexcept override | |
| { | |
| return InterfaceInfo{"PLUGIN CREATOR_V3QUICK", 1, 0}; | |
| } | |
| //! | |
| //! \brief Return a plugin object. Return nullptr in case of error. | |
| //! | |
| //! \param name A NULL-terminated name string of length 1024 or less, including the NULL terminator. | |
| //! \param namespace A NULL-terminated name string of length 1024 or less, including the NULL terminator. | |
| //! \param fc A pointer to a collection of fields needed for constructing the plugin. | |
| //! \param phase The TensorRT phase in which the plugin is being created | |
| //! \param quickPluginCreationRequest Whether a JIT or AOT plugin should be created | |
| //! | |
| virtual IPluginV3* createPlugin(AsciiChar const* name, AsciiChar const* nspace, PluginFieldCollection const* fc, | |
| TensorRTPhase phase, QuickPluginCreationRequest quickPluginCreationRequest) noexcept = 0; | |
| //! | |
| //! \brief Return a list of fields that need to be passed to createPlugin() when creating a plugin for use in the | |
| //! TensorRT build phase. | |
| //! | |
| virtual PluginFieldCollection const* getFieldNames() noexcept = 0; | |
| virtual AsciiChar const* getPluginName() const noexcept = 0; | |
| virtual AsciiChar const* getPluginVersion() const noexcept = 0; | |
| virtual AsciiChar const* getPluginNamespace() const noexcept = 0; | |
| IPluginCreatorV3Quick() = default; | |
| virtual ~IPluginCreatorV3Quick() = default; | |
| protected: | |
| IPluginCreatorV3Quick(IPluginCreatorV3Quick const&) = default; | |
| IPluginCreatorV3Quick(IPluginCreatorV3Quick&&) = default; | |
| IPluginCreatorV3Quick& operator=(IPluginCreatorV3Quick const&) & = default; | |
| IPluginCreatorV3Quick& operator=(IPluginCreatorV3Quick&&) & = default; | |
| }; | |
| } // namespace v_1_0 | |
| //! | |
| //! \class IPluginV3QuickCore | |
| //! | |
| //! \brief Provides core capability (`IPluginCapability::kCORE`) for quickly-deployable TRT plugins | |
| //! | |
| //! \warning This class is strictly for the purpose of supporting quickly-deployable TRT Python plugins and is not part | |
| //! of the public TensorRT C++ API. Users must not inherit from this class. | |
| //! | |
| using IPluginV3QuickCore = v_1_0::IPluginV3QuickCore; | |
| //! | |
| //! \class IPluginV3QuickBuild | |
| //! | |
| //! \brief Provides build capability (`IPluginCapability::kBUILD`) for quickly-deployable TRT plugins | |
| //! | |
| //! \warning This class is strictly for the purpose of supporting quickly-deployable TRT Python plugins and is not part | |
| //! of the public TensorRT C++ API. Users must not inherit from this class. | |
| //! | |
| using IPluginV3QuickBuild = v_1_0::IPluginV3QuickBuild; | |
| //! | |
| //! \class IPluginV3QuickAOTBuild | |
| //! | |
| //! \brief Provides additional build capabilities for AOT quickly-deployable TRT plugins. Descends from | |
| //! IPluginV3QuickBuild. | |
| //! | |
| //! \warning This class is strictly for the purpose of supporting quickly-deployable TRT Python plugins and is not part | |
| //! of the public TensorRT C++ API. Users must not inherit from this class. | |
| //! | |
| using IPluginV3QuickAOTBuild = v_1_0::IPluginV3QuickAOTBuild; | |
| //! | |
| //! \class IPluginV3QuickRuntime | |
| //! | |
| //! \brief Provides runtime capability (`IPluginCapability::kRUNTIME`) for JIT quickly-deployable TRT plugins | |
| //! | |
| //! \warning This class is strictly for the purpose of supporting quickly-deployable TRT Python plugins and is not part | |
| //! of the public TensorRT C++ API. Users must not inherit from this class. | |
| //! | |
| using IPluginV3QuickRuntime = v_1_0::IPluginV3QuickRuntime; | |
| //! | |
| //! \class IPluginCreatorV3Quick | |
| //! | |
| //! \warning This class is strictly for the purpose of supporting quickly-deployable TRT Python plugins and is not part | |
| //! of the public TensorRT C++ API. Users must not inherit from this class. | |
| //! | |
| using IPluginCreatorV3Quick = v_1_0::IPluginCreatorV3Quick; | |
| } // namespace nvinfer1 | |