|
|
{% if add_header %} |
|
|
|
|
|
|
|
|
import ctypes |
|
|
from collections import OrderedDict |
|
|
from pathlib import Path |
|
|
from typing import List |
|
|
|
|
|
import numpy as np |
|
|
import tensorrt as trt |
|
|
|
|
|
from tensorrt_llm._common import default_trtnet |
|
|
from tensorrt_llm._utils import str_dtype_to_trt |
|
|
from tensorrt_llm.functional import Tensor, _create_tensor |
|
|
from tensorrt_llm.module import Module |
|
|
|
|
|
TRT_LLM_PLUGIN_NAMESPACE = 'tensorrt_llm' |
|
|
|
|
|
def _load_triton_plugin_lib(): |
|
|
triton_plugin_dir = Path(__file__).parent.absolute() |
|
|
plugin_lib = "[[ plugin_lib_path ]]" |
|
|
handle = ctypes.CDLL(plugin_lib, mode=ctypes.RTLD_GLOBAL) |
|
|
if handle is None: |
|
|
raise ImportError('TensorRT-LLM Triton Plugin is unavailable') |
|
|
handle.initLibNvInferPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] |
|
|
handle.initLibNvInferPlugins.restype = ctypes.c_bool |
|
|
assert handle.initLibNvInferPlugins( |
|
|
None, TRT_LLM_PLUGIN_NAMESPACE.encode('utf-8')) |
|
|
|
|
|
_load_triton_plugin_lib() |
|
|
|
|
|
|
|
|
{% endif %} |
|
|
|
|
|
def [[ kernel_name ]]([[ arg_list ]]):
    '''
    Add the `[[ plugin_name ]]` plugin as a layer of the default TensorRT
    network and return its outputs.

    Inputs:
    {% for arg in metadata.get_params() -%}
    - [[arg.name]]: [[arg.dtype.dtype.to('np')]]
    {% endfor %}
    {% for arg in metadata.get_inputs() -%}
    - [[arg.name]]: [[ ('tensor<' ~ arg.dtype.dtype.to('np') ~ '>') if arg.is_tensor else arg.dtype.dtype.to('np') ]]
    {% endfor %}
    Outputs:
    {% for arg in metadata.get_outputs() -%}
    - [[arg.name]]: [[ ('tensor<' ~ arg.dtype.dtype.to('np') ~ '>') if arg.is_tensor else arg.dtype.dtype.to('np') ]]
    {% endfor -%}
    '''
    # Look the creator up by (name, version, namespace) in the TensorRT
    # plugin registry.
    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        '[[ plugin_name ]]', '[[ kernel_version ]]', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None, (
        "plugin creator '[[ plugin_name ]]' (version '[[ kernel_version ]]') "
        "was not found in the TensorRT plugin registry")

    # One PluginField per plugin parameter; each scalar value is wrapped in a
    # one-element numpy array of the parameter's dtype.
    pfc = trt.PluginFieldCollection([
        {% for arg in params -%}
        trt.PluginField("[[arg.name]]", np.array([ [[ arg.name ]] ], np.[[ arg.dtype.dtype.to('np') ]]),
                        trt.PluginFieldType.[[ arg.dtype.dtype.to('trt_plugin_py') ]]),
        {% endfor %}
    ])

    plugin = plg_creator.create_plugin("[[ plugin_name ]]", pfc)

    plug_inputs = [ [[ input_list ]] ]
    layer = default_trtnet().add_plugin_v2(plug_inputs, plugin)

    # Wrap every layer output (num_outputs of them) via _create_tensor.
    return [
        {% for output_idx in range(num_outputs) %}
        _create_tensor(layer.get_output([[ output_idx ]]), layer),
        {% endfor %}
    ]
|
|
|