import contextlib
import threading

try:
    from types import NoneType
except ImportError:
    NoneType = type(None)
from typing import ByteString, Iterable, MutableMapping

import tensorrt as trt
import torch

from tensorrt_llm._utils import get_extra_attr, np_dtype_to_trt, set_extra_attr
from tensorrt_llm.logger import logger
from tensorrt_llm.network import PluginInfo, get_plugin_info

LAYER_TYPE_2_CLASS = {
    trt.LayerType.ACTIVATION: trt.IActivationLayer,
    trt.LayerType.CONCATENATION: trt.IConcatenationLayer,
    trt.LayerType.CONSTANT: trt.IConstantLayer,
    trt.LayerType.ELEMENTWISE: trt.IElementWiseLayer,
    trt.LayerType.FILL: trt.IFillLayer,
    trt.LayerType.GATHER: trt.IGatherLayer,
    trt.LayerType.MATRIX_MULTIPLY: trt.IMatrixMultiplyLayer,
    trt.LayerType.REDUCE: trt.IReduceLayer,
    trt.LayerType.SELECT: trt.ISelectLayer,
    trt.LayerType.SHUFFLE: trt.IShuffleLayer,
    trt.LayerType.SLICE: trt.ISliceLayer,
    trt.LayerType.SOFTMAX: trt.ISoftMaxLayer,
    trt.LayerType.UNARY: trt.IUnaryLayer,
    trt.LayerType.SHAPE: trt.IShapeLayer,
    trt.LayerType.ASSERTION: trt.IAssertionLayer,
    trt.LayerType.CAST: trt.ICastLayer,
    trt.LayerType.NORMALIZATION: trt.INormalizationLayer,
    trt.LayerType.IDENTITY: trt.IIdentityLayer,
    trt.LayerType.PLUGIN_V2: trt.IPluginV2Layer,
}


def to_subclass_layer(trt_layer):
    """Downcast ``trt_layer`` in place to its concrete class (e.g. IShuffleLayer)
    so that type-specific attributes become accessible."""
    trt_layer.__class__ = LAYER_TYPE_2_CLASS[trt_layer.type]


def to_base_class_layer(trt_layer):
    """Undo ``to_subclass_layer`` by restoring the generic ILayer class."""
    trt_layer.__class__ = trt.ILayer

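# Example (illustrative sketch; `layer` is a hypothetical trt.ILayer handle):
# downcast to read subclass-only attributes, then restore the base class.
# This is the pattern get_cache_key below relies on.
#
#   to_subclass_layer(layer)
#   if layer.type == trt.LayerType.SHUFFLE:
#       reshape_dims = layer.reshape_dims
#   to_base_class_layer(layer)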

def to_trt_weights(ndarray):
    weight = trt.Weights(
        np_dtype_to_trt(ndarray.dtype),
        ndarray.ctypes.data,
        ndarray.size,
    )
    # Prevent numpy array from going out of weight's lifetime scope
    set_extra_attr(weight, "numpy", ndarray)
    return weight

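# Example (illustrative sketch): wrap a numpy array as trt.Weights; the array is
# stashed on the Weights object so it stays alive as long as the weights do.
#
#   import numpy as np
#   weights = to_trt_weights(np.zeros((16, 16), dtype=np.float32))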

@contextlib.contextmanager
def silent_trt_logger():
    """Temporarily raise the TRT logger severity to ERROR to silence noisy messages."""
    min_severity = logger.trt_logger.min_severity
    logger.trt_logger.min_severity = trt.Logger.ERROR
    try:
        yield
    finally:
        # Restore the previous severity even if the body raises.
        logger.trt_logger.min_severity = min_severity

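# Example (illustrative sketch): suppress non-error TRT messages while running a
# noisy block such as network construction.
#
#   with silent_trt_logger():
#       builder = trt.Builder(logger.trt_logger)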

def compare_tensor(trt_tensor, new_trt_tensor):
    """Assert that two ITensor objects carry identical metadata."""
    assert trt_tensor.name == new_trt_tensor.name
    assert trt_tensor.dtype == new_trt_tensor.dtype
    assert tuple(trt_tensor.shape) == tuple(new_trt_tensor.shape)
    assert trt_tensor.broadcast_across_batch == new_trt_tensor.broadcast_across_batch
    assert trt_tensor.location == new_trt_tensor.location
    assert trt_tensor.is_network_input == new_trt_tensor.is_network_input
    assert trt_tensor.is_network_output == new_trt_tensor.is_network_output
    assert trt_tensor.dynamic_range == new_trt_tensor.dynamic_range
    assert trt_tensor.is_shape_tensor == new_trt_tensor.is_shape_tensor
    assert trt_tensor.is_execution_tensor == new_trt_tensor.is_execution_tensor
    assert trt_tensor.allowed_formats == new_trt_tensor.allowed_formats


def compare_network(trt_network, new_trt_network):
    """Assert that two networks have equivalent inputs, outputs and layers,
    comparing layers in topological order."""
    assert trt_network.num_inputs == new_trt_network.num_inputs
    for i in range(trt_network.num_inputs):
        input = trt_network.get_input(i)
        new_input = new_trt_network.get_input(i)
        compare_tensor(input, new_input)
    assert trt_network.num_outputs == new_trt_network.num_outputs
    for i in range(trt_network.num_outputs):
        output = trt_network.get_output(i)
        new_output = new_trt_network.get_output(i)
        compare_tensor(output, new_output)
    assert trt_network.num_layers == new_trt_network.num_layers
    for index, new_index in zip(get_sorted_layer_ids(trt_network),
                                get_sorted_layer_ids(new_trt_network)):
        layer = trt_network.get_layer(index)
        new_layer = new_trt_network.get_layer(new_index)
        assert layer.name == new_layer.name
        assert layer.type == new_layer.type
        assert layer.precision_is_set == new_layer.precision_is_set
        assert layer.precision == new_layer.precision
        assert layer.num_inputs == new_layer.num_inputs
        for j in range(layer.num_inputs):
            input = layer.get_input(j)
            new_input = new_layer.get_input(j)
            if input is None:
                assert new_input is None
            else:
                assert new_input is not None
                compare_tensor(input, new_input)
        assert layer.num_outputs == new_layer.num_outputs
        for j in range(layer.num_outputs):
            output = layer.get_output(j)
            new_output = new_layer.get_output(j)
            compare_tensor(output, new_output)
            assert layer.output_type_is_set(j) == new_layer.output_type_is_set(
                j)
            if layer.output_type_is_set(j):
                assert layer.get_output_type(j) == new_layer.get_output_type(j)


def get_sorted_layer_ids(trt_network):
    """Return layer indices in topological order: a layer is emitted only after
    every tensor it consumes has been produced or is a network input."""
    inputs = set()
    for i in range(trt_network.num_inputs):
        inputs.add(trt_network.get_input(i).name)
    layer_ids = [*range(trt_network.num_layers)]
    sorted_layer_ids = []
    walked_tensors = set(inputs)
    while len(layer_ids) > 0:
        layer_id = layer_ids.pop(0)
        layer = trt_network.get_layer(layer_id)
        no_dependencies = True
        for j in range(layer.num_inputs):
            input = layer.get_input(j)
            if input is None:
                continue
            if input.name in walked_tensors:
                continue
            else:
                no_dependencies = False
                break
        if no_dependencies:
            sorted_layer_ids.append(layer_id)
            for j in range(layer.num_outputs):
                output = layer.get_output(j)
                if output is None:
                    continue
                walked_tensors.add(output.name)
        else:
            layer_ids.append(layer_id)
    assert len(sorted_layer_ids) == trt_network.num_layers
    return sorted_layer_ids

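# Example (illustrative sketch; `trt_network` is a hypothetical
# INetworkDefinition): visit layers so every producer comes before its consumers.
#
#   for layer_id in get_sorted_layer_ids(trt_network):
#       layer = trt_network.get_layer(layer_id)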

def to_tuple(values):
    """Recursively convert ``values`` into a hashable, tuple-based representation."""
    if isinstance(values, (int, float, str, bool, NoneType, ByteString)):
        return values
    elif isinstance(values, (trt.Dims, trt.Permutation)):
        # A negative length marks unknown dims; call __len__ directly because
        # the len() builtin rejects negative values.
        if values.__len__() < 0:
            return None
        else:
            return tuple(values)
    elif isinstance(values, MutableMapping):
        # Check mappings before generic iterables, otherwise dicts would be
        # reduced to a tuple of their keys only.
        return tuple((k, to_tuple(v)) for k, v in values.items())
    elif isinstance(values, Iterable):
        return tuple(to_tuple(v) for v in values)
    else:
        return values


_base_layer_attr_names = set(dir(trt.ILayer))


def get_cache_key(layer, shapes, values, dtypes=None, updated_attrs=None):
    """Build a hashable cache key for ``layer`` from its type and attributes, its
    input shapes/values and, optionally, its input/output dtypes."""
    updated_attrs = updated_attrs or {}
    layer_type = layer.type
    to_subclass_layer(layer)
    attr_names = set(dir(layer)) - _base_layer_attr_names
    if layer_type == trt.LayerType.CONSTANT:
        attr_names.remove("weights")
    elif layer_type == trt.LayerType.SHUFFLE:
        if layer.num_inputs >= 2:
            attr_names.remove("reshape_dims")
    elif layer_type == trt.LayerType.SLICE:
        if layer.num_inputs >= 2 and layer.get_input(1) is not None:
            attr_names.remove("start")
        if layer.num_inputs >= 3 and layer.get_input(2) is not None:
            attr_names.remove("shape")
        if layer.num_inputs >= 4 and layer.get_input(3) is not None:
            attr_names.remove("stride")
    elif layer_type == trt.LayerType.FILL:
        attr_names.remove("is_alpha_beta_int64")
        if layer.num_inputs >= 1 and layer.get_input(0) is not None:
            attr_names.remove("shape")
        if layer.num_inputs >= 2 and layer.get_input(1) is not None:
            attr_names.remove("alpha")
        if layer.num_inputs >= 3 and layer.get_input(2) is not None:
            attr_names.remove("beta")
    if layer_type != trt.LayerType.PLUGIN_V2:
        attr_key = tuple(
            (name, to_tuple(updated_attrs.get(name) or getattr(layer, name)))
            for name in sorted(attr_names))
    else:
        network = get_trt_network(layer)
        plugin_info = get_plugin_info(network, layer.name)
        assert plugin_info is not None, f"layer {layer.name} does not register plugin info"
        attr_key = tuple(
            (name, tuple(updated_attrs.get(name) or data))
            for name, data in sorted(plugin_info.pfc_as_list.items()))
    to_base_class_layer(layer)
    shape_key = ()
    value_key = ()
    dtype_key = ()
    for i in range(layer.num_inputs):
        input = layer.get_input(i)
        if input is not None:
            shape_key += (tuple(shapes[input.name]), )
            if input.name in values:
                value = values[input.name]
                # All torch tensors are derived from input shapes and pfc,
                # thus we ignore them in cache key
                if isinstance(value, torch.Tensor):
                    value = None
                else:
                    value = tuple(value)
                value_key += (value, )
            else:
                value_key += (None, )
            if dtypes is not None:
                dtype_key += (dtypes[input.name], )
        else:
            shape_key += (None, )
            value_key += (None, )
            dtype_key += (None, )
    if dtypes is not None:
        for i in range(layer.num_outputs):
            output = layer.get_output(i)
            dtype_key += (dtypes[output.name], )
    cache_key = (layer.type, attr_key, shape_key, value_key)
    if dtypes is not None:
        cache_key += (dtype_key, )
    return cache_key


def get_trt_network(layer: trt.ILayer):
    network = get_extra_attr(layer, "network")
    assert network is not None
    return network


def set_trt_network(layer: trt.ILayer, network: trt.INetworkDefinition):
    set_extra_attr(layer, "network", network)


def get_updated_plugin(plugin_info: PluginInfo, updated_attrs):
    """Re-create the plugin described by ``plugin_info`` with some of its plugin
    fields replaced by the values in ``updated_attrs``."""
    fields = []
    for field in plugin_info.pfc:
        name = field.name
        if name in updated_attrs:
            field = trt.PluginField(name, updated_attrs[name], field.type)
        else:
            field = trt.PluginField(name, plugin_info.pfc_as_ndarray[name],
                                    field.type)
        fields.append(field)
    pfc = trt.PluginFieldCollection(fields)
    plugin = plugin_info.plugin_creator.create_plugin(plugin_info.plugin_name,
                                                      pfc)
    new_plugin_info = PluginInfo(plugin_info.plugin_creator,
                                 plugin_info.plugin_name, pfc)
    return plugin, new_plugin_info


_builder_flags = threading.local()
_strongly_typed = threading.local()


def get_builder_flags():
    return getattr(_builder_flags, 'value', 0)


def get_strongly_typed():
    return getattr(_strongly_typed, 'value', False)


@contextlib.contextmanager
def current_flags(builder_flags, strongly_typed):
    """Set the thread-local builder flags and strongly-typed mode for the duration
    of the block, restoring the previous values on exit."""
    previous_builder_flags = get_builder_flags()
    _builder_flags.value = builder_flags
    previous_strongly_typed = get_strongly_typed()
    _strongly_typed.value = strongly_typed
    try:
        yield
    finally:
        # Restore the previous thread-local values even if the body raises.
        _builder_flags.value = previous_builder_flags
        _strongly_typed.value = previous_strongly_typed

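# Example (illustrative sketch; the flag value is an assumption): pin builder
# flags and strong typing for the current thread while a block runs.
#
#   fp16_flags = 1 << int(trt.BuilderFlag.FP16)
#   with current_flags(fp16_flags, strongly_typed=False):
#       assert get_builder_flags() == fp16_flags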

def get_engine_information(engine_file) -> str:
    """Deserialize ``engine_file`` and return its layer information as a JSON string."""
    with open(engine_file, "rb") as f:
        engine_buffer = f.read()
    runtime = trt.Runtime(logger.trt_logger)
    engine = runtime.deserialize_cuda_engine(engine_buffer)
    inspector = engine.create_engine_inspector()
    return inspector.get_engine_information(trt.LayerInformationFormat.JSON)

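# Example (illustrative sketch; the engine path is hypothetical): parse the
# per-layer engine information as JSON.
#
#   import json
#   layer_info = json.loads(get_engine_information("model.engine"))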

def print_engine_info(engine_file) -> None:
    """Load ``engine_file`` and print its engine information via the runtime Session."""
    with open(engine_file, "rb") as f:
        engine_buffer = f.read()
    from tensorrt_llm.runtime.session import Session
    Session.from_serialized_engine(engine_buffer)._print_engine_info()