|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef TRT_[[ plugin_name ]]_PLUGIN_H |
|
|
#define TRT_[[ plugin_name ]]_PLUGIN_H |
|
|
#include "NvInferPlugin.h" |
|
|
|
|
|
extern "C" |
|
|
{ |
|
|
|
|
|
#include "[[ kernel_name ]]/[[ triton_aot_dir ]]/launcher.h" |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
#include <cassert> |
|
|
#include <set> |
|
|
#include <string> |
|
|
#include <vector> |
|
|
|
|
|
[[ plugin_common_header ]] |
|
|
|
|
|
|
|
|
namespace nvinfer1 |
|
|
{ |
|
|
namespace plugin |
|
|
{ |
|
|
|
|
|
|
|
|
/// TensorRT dynamic-shape plugin generated from this template.
/// NOTE(review): it presumably launches the Triton AOT kernel whose launcher header
/// is included near the top of this file; the method bodies live in the matching
/// .cpp, which is not visible here — confirm there.
class [[ plugin_name ]] : public IPluginV2DynamicExt {

public:
    /// Builds the plugin from explicit kernel parameters (the generated argument list).
    [[ plugin_name ]]( [[ construct_arg_list ]]);

    /// Rebuilds the plugin from a serialized blob of `length` bytes starting at `data`.
    /// Presumably the blob is the one written by serialize() — confirm in the .cpp.
    [[ plugin_name ]](const void* data, size_t length);

    ~[[ plugin_name ]]() override = default;

    // ---- IPluginV2DynamicExt methods ----

    /// Returns a new heap-allocated copy of this plugin (ownership passes to TensorRT).
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;

    /// Computes the symbolic dimensions of output `outputIndex` from the input
    /// dimension expressions, using `exprBuilder` to build new expressions.
    nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
        nvinfer1::IExprBuilder& exprBuilder) noexcept override;

    /// Reports whether tensor `pos` (inputs first, then outputs, in `inOut`) supports
    /// the type/format combination described there.
    bool supportsFormatCombination(
        int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override;

    /// Called with the (possibly dynamic) input/output tensor descriptions before execution.
    void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
        const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override;

    /// Number of bytes of scratch memory needed by enqueue() for the given tensor descriptions.
    size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
        const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override;

    /// Runs the plugin on `stream` over the given device buffers.
    /// Per TensorRT convention, returns 0 on success and non-zero on failure.
    int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

    // ---- IPluginV2Ext methods ----

    /// Data type of output `index` given the input data types.
    nvinfer1::DataType getOutputDataType(
        int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override;

    // ---- IPluginV2 methods ----

    const char* getPluginType() const noexcept override;

    const char* getPluginVersion() const noexcept override;

    int getNbOutputs() const noexcept override;

    int initialize() noexcept override;

    void terminate() noexcept override;

    /// Size in bytes of the buffer serialize() writes.
    size_t getSerializationSize() const noexcept override;

    /// Writes the plugin state into `buffer` (which must hold getSerializationSize() bytes).
    void serialize(void* buffer) const noexcept override;

    /// Releases this plugin object. NOTE(review): usually implemented as `delete this` — confirm in the .cpp.
    void destroy() noexcept override;

    void setPluginNamespace(const char* pluginNamespace) noexcept override;

    const char* getPluginNamespace() const noexcept override;

private:
    // Namespace string for this plugin instance; presumably written by
    // setPluginNamespace() and returned by getPluginNamespace().
    std::string mNamespace;

    // One data member per template parameter in `params`; each member's C type is
    // derived from that parameter's dtype. There is no in-class initializer, so the
    // constructors (defined out of view) must set every member.
{% for arg in params -%}
    [[arg.dtype.dtype.to('c')]] [[arg.name]];
{% endfor %}
};
|
|
|
|
|
|
|
|
|
|
|
/// Factory for the plugin class declared in this header, implementing TensorRT's
/// IPluginCreator interface. It builds instances either from a field collection
/// (createPlugin) or from a serialized blob (deserializePlugin).
class [[ plugin_name ]]Creator : public IPluginCreator
{

public:
    [[ plugin_name ]]Creator();

    /// Name under which the plugin is looked up.
    const char* getPluginName() const noexcept override;

    /// Version string paired with getPluginName().
    const char* getPluginVersion() const noexcept override;

    /// Describes the attribute fields accepted by createPlugin().
    const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;

    /// Creates a new plugin instance named `name` from the attributes in `fc`.
    nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override;

    /// Recreates a plugin instance from `serialLength` bytes at `serialData`,
    /// presumably produced by the plugin's serialize() — confirm in the .cpp.
    nvinfer1::IPluginV2* deserializePlugin(
        const char* name, const void* serialData, size_t serialLength) noexcept override;

    void setPluginNamespace(const char* pluginNamespace) noexcept override;

    const char* getPluginNamespace() const noexcept override;

private:
    // Static, so shared by every creator instance. Non-inline static data members
    // must be defined out-of-line in exactly one translation unit (presumably the
    // plugin .cpp).
    static PluginFieldCollection mFC;

    // NOTE(review): presumably the backing storage for the fields exposed through
    // mFC — confirm in the .cpp.
    static std::vector<PluginField> mPluginAttributes;

    // Per-instance namespace string used by set/getPluginNamespace().
    std::string mNamespace;
};
|
|
|
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
#endif |
|
|
|