File size: 2,328 Bytes
5000658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
{% if add_header %}
# //header begin//

import ctypes
from collections import OrderedDict
from pathlib import Path
from typing import List

import numpy as np
import tensorrt as trt

from tensorrt_llm._common import default_trtnet
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.functional import Tensor, _create_tensor
from tensorrt_llm.module import Module

TRT_LLM_PLUGIN_NAMESPACE = 'tensorrt_llm'

def _load_triton_plugin_lib():
    triton_plugin_dir = Path(__file__).parent.absolute()
    plugin_lib = "[[ plugin_lib_path ]]"
    handle = ctypes.CDLL(plugin_lib, mode=ctypes.RTLD_GLOBAL)
    if handle is None:
        raise ImportError('TensorRT-LLM Triton Plugin is unavailable')
    handle.initLibNvInferPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    handle.initLibNvInferPlugins.restype = ctypes.c_bool
    assert handle.initLibNvInferPlugins(
        None, TRT_LLM_PLUGIN_NAMESPACE.encode('utf-8'))

_load_triton_plugin_lib()

# //header end//
{% endif %}

def [[ kernel_name ]]([[ arg_list ]]):
    '''
    Inputs:
    {% for arg in metadata.get_params() -%}
    - [[arg.name]]: [[arg.dtype.dtype.to('np')]]
    {% endfor %}
    {% for arg in metadata.get_inputs() -%}
    - [[arg.name]]: {% if arg.is_tensor %}tensor<{%endif%}[[arg.dtype.dtype.to('np')]]>
    {% endfor %}
    Outputs:
    {% for arg in metadata.get_outputs() -%}
    - [[arg.name]]: {% if arg.is_tensor %}tensor<{%endif%}[[arg.dtype.dtype.to('np')]]>
    {% endfor -%}
    '''
    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        '[[ plugin_name ]]', '[[ kernel_version ]]', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    pfc = trt.PluginFieldCollection([
        {% for arg in params -%}
        {# input is a dict[name, np_type, trt_type ] #}
        trt.PluginField("[[arg.name]]", np.array([ [[ arg.name ]] ], np.[[ arg.dtype.dtype.to('np') ]]),
                        trt.PluginFieldType.[[ arg.dtype.dtype.to('trt_plugin_py') ]]),
        {% endfor %}
    ])

    plugin = plg_creator.create_plugin("[[ plugin_name ]]", pfc)

    plug_inputs = [ [[ input_list ]] ]
    layer = default_trtnet().add_plugin_v2(plug_inputs, plugin)

    return [
        {% for id in range(num_outputs) %}
        _create_tensor(layer.get_output([[ id ]]), layer),
        {% endfor %}
    ]