|
|
import contextlib
import threading
from collections import deque

try:
    from types import NoneType
except ImportError:
    NoneType = type(None)

from typing import ByteString, Iterable, MutableMapping

import tensorrt as trt
import torch

from tensorrt_llm._utils import get_extra_attr, np_dtype_to_trt, set_extra_attr
from tensorrt_llm.logger import logger
from tensorrt_llm.network import PluginInfo, get_plugin_info
|
|
|
|
|
# Maps each trt.LayerType enum value to the concrete ILayer subclass that
# exposes that layer type's specific attributes. Used by to_subclass_layer()
# to rebind a generic trt.ILayer to its real subclass.
LAYER_TYPE_2_CLASS = {
    trt.LayerType.ACTIVATION: trt.IActivationLayer,
    trt.LayerType.CONCATENATION: trt.IConcatenationLayer,
    trt.LayerType.CONSTANT: trt.IConstantLayer,
    trt.LayerType.ELEMENTWISE: trt.IElementWiseLayer,
    trt.LayerType.FILL: trt.IFillLayer,
    trt.LayerType.GATHER: trt.IGatherLayer,
    trt.LayerType.MATRIX_MULTIPLY: trt.IMatrixMultiplyLayer,
    trt.LayerType.REDUCE: trt.IReduceLayer,
    trt.LayerType.SELECT: trt.ISelectLayer,
    trt.LayerType.SHUFFLE: trt.IShuffleLayer,
    trt.LayerType.SLICE: trt.ISliceLayer,
    trt.LayerType.SOFTMAX: trt.ISoftMaxLayer,
    trt.LayerType.UNARY: trt.IUnaryLayer,
    trt.LayerType.SHAPE: trt.IShapeLayer,
    trt.LayerType.ASSERTION: trt.IAssertionLayer,
    trt.LayerType.CAST: trt.ICastLayer,
    trt.LayerType.NORMALIZATION: trt.INormalizationLayer,
    trt.LayerType.IDENTITY: trt.IIdentityLayer,
    trt.LayerType.PLUGIN_V2: trt.IPluginV2Layer,
}
|
|
|
|
|
|
|
|
def to_subclass_layer(trt_layer):
    """Rebind ``trt_layer`` to its concrete ILayer subclass, in place.

    TensorRT hands layers out typed as the ``trt.ILayer`` base class;
    swapping ``__class__`` exposes the type-specific attributes. Undo with
    ``to_base_class_layer``. Raises KeyError for layer types not listed in
    ``LAYER_TYPE_2_CLASS``.
    """
    subclass = LAYER_TYPE_2_CLASS[trt_layer.type]
    trt_layer.__class__ = subclass
|
|
|
|
|
|
|
|
def to_base_class_layer(trt_layer):
    """Rebind ``trt_layer`` back to the ``trt.ILayer`` base class, in place.

    Inverse of ``to_subclass_layer``.
    """
    trt_layer.__class__ = trt.ILayer
|
|
|
|
|
|
|
|
def to_trt_weights(ndarray):
    """Wrap a numpy array as ``trt.Weights`` without copying the data.

    The array is also attached to the returned weights as an extra
    attribute — presumably to keep the buffer behind the raw pointer
    alive, since ``trt.Weights`` only stores the address.
    """
    trt_dtype = np_dtype_to_trt(ndarray.dtype)
    weights = trt.Weights(trt_dtype, ndarray.ctypes.data, ndarray.size)
    set_extra_attr(weights, "numpy", ndarray)
    return weights
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def silent_trt_logger():
    """Temporarily raise the TRT logger threshold to ERROR severity.

    The previous severity is restored on exit. Restoration happens in a
    ``finally`` block so that an exception raised inside the ``with`` body
    cannot leave the logger permanently silenced (the original code skipped
    restoration on exceptions).
    """
    min_severity = logger.trt_logger.min_severity
    logger.trt_logger.min_severity = trt.Logger.ERROR
    try:
        yield
    finally:
        logger.trt_logger.min_severity = min_severity
|
|
|
|
|
|
|
|
def compare_tensor(trt_tensor, new_trt_tensor):
    """Assert that two TRT tensors carry identical metadata.

    Raises AssertionError on the first mismatching property.
    """
    assert trt_tensor.name == new_trt_tensor.name
    assert trt_tensor.dtype == new_trt_tensor.dtype
    # Shape objects may not compare by value, so compare as plain tuples.
    assert tuple(trt_tensor.shape) == tuple(new_trt_tensor.shape)
    for attr in (
            "broadcast_across_batch",
            "location",
            "is_network_input",
            "is_network_output",
            "dynamic_range",
            "is_shape_tensor",
            "is_execution_tensor",
            "allowed_formats",
    ):
        assert getattr(trt_tensor, attr) == getattr(new_trt_tensor, attr)
|
|
|
|
|
|
|
|
def compare_network(trt_network, new_trt_network):
    """Assert that two TRT networks are structurally equivalent.

    Compares network inputs, network outputs, then every layer (name, type,
    precision, and all per-layer input/output tensors). Layers are visited
    in topological order on both sides so that networks whose layers were
    created in a different order still compare equal. Raises AssertionError
    on the first mismatch.
    """
    assert trt_network.num_inputs == new_trt_network.num_inputs
    for i in range(trt_network.num_inputs):
        input = trt_network.get_input(i)
        new_input = new_trt_network.get_input(i)
        compare_tensor(input, new_input)
    assert trt_network.num_outputs == new_trt_network.num_outputs
    for i in range(trt_network.num_outputs):
        output = trt_network.get_output(i)
        new_output = new_trt_network.get_output(i)
        compare_tensor(output, new_output)
    assert trt_network.num_layers == new_trt_network.num_layers
    # Walk both networks in topological order rather than raw layer index.
    for index, new_index in zip(get_sorted_layer_ids(trt_network),
                                get_sorted_layer_ids(new_trt_network)):
        layer = trt_network.get_layer(index)
        new_layer = new_trt_network.get_layer(new_index)
        assert layer.name == new_layer.name
        assert layer.type == new_layer.type
        assert layer.precision_is_set == new_layer.precision_is_set
        assert layer.precision == new_layer.precision
        assert layer.num_inputs == new_layer.num_inputs
        for j in range(layer.num_inputs):
            input = layer.get_input(j)
            new_input = new_layer.get_input(j)
            # Optional layer inputs may be None; both sides must agree.
            if input is None:
                assert new_input is None
            else:
                assert new_input is not None
                compare_tensor(input, new_input)
        assert layer.num_outputs == new_layer.num_outputs
        for j in range(layer.num_outputs):
            output = layer.get_output(j)
            new_output = new_layer.get_output(j)
            compare_tensor(output, new_output)
            assert layer.output_type_is_set(j) == new_layer.output_type_is_set(
                j)
            if layer.output_type_is_set(j):
                assert layer.get_output_type(j) == new_layer.get_output_type(j)
|
|
|
|
|
|
|
|
def get_sorted_layer_ids(trt_network):
    """Return the layer indices of ``trt_network`` in topological order.

    A layer is emitted once all of its (non-None) input tensors are either
    network inputs or outputs of already-emitted layers.

    Improvements over the previous version:
    - uses ``collections.deque`` instead of ``list.pop(0)``, which was O(n)
      per pop and made the sort accidentally quadratic;
    - asserts instead of looping forever when no progress can be made
      (dependency cycle, or a tensor that is neither a network input nor
      produced by any layer).

    Returns:
        A list of layer indices covering every layer exactly once.
    """
    # Tensors available so far: start with the network inputs.
    available_tensors = {
        trt_network.get_input(i).name
        for i in range(trt_network.num_inputs)
    }
    pending = deque(range(trt_network.num_layers))
    sorted_layer_ids = []
    deferred = 0  # consecutive layers re-queued without progress
    while pending:
        layer_id = pending.popleft()
        layer = trt_network.get_layer(layer_id)
        ready = True
        for j in range(layer.num_inputs):
            input = layer.get_input(j)
            # Optional inputs may be None and impose no dependency.
            if input is not None and input.name not in available_tensors:
                ready = False
                break
        if ready:
            sorted_layer_ids.append(layer_id)
            for j in range(layer.num_outputs):
                output = layer.get_output(j)
                if output is not None:
                    available_tensors.add(output.name)
            deferred = 0
        else:
            pending.append(layer_id)
            deferred += 1
            # A full pass over the queue without emitting anything means we
            # can never make progress; fail loudly instead of hanging.
            assert deferred <= len(pending), (
                "network contains a dependency cycle or a tensor that is "
                "neither a network input nor produced by any layer")
    assert len(sorted_layer_ids) == trt_network.num_layers
    return sorted_layer_ids
|
|
|
|
|
|
|
|
def to_tuple(values): |
|
|
if isinstance(values, (int, float, str, bool, NoneType, ByteString)): |
|
|
return values |
|
|
elif isinstance(values, (trt.Dims, trt.Permutation)): |
|
|
if values.__len__() < 0: |
|
|
return None |
|
|
else: |
|
|
return tuple(values) |
|
|
elif isinstance(values, Iterable): |
|
|
return tuple(to_tuple(v) for v in values) |
|
|
elif isinstance(values, MutableMapping): |
|
|
return tuple((k, to_tuple(v)) for k, v in values.items()) |
|
|
else: |
|
|
return values |
|
|
|
|
|
|
|
|
# Attribute names common to every layer (inherited from trt.ILayer); used by
# get_cache_key() to isolate the attributes specific to a layer subclass.
_base_layer_attr_names = set(dir(trt.ILayer))
|
|
|
|
|
|
|
|
def get_cache_key(layer, shapes, values, dtypes=None, updated_attrs=None):
    """Build a hashable key that identifies a TRT layer for caching.

    The key combines the layer type, its type-specific attributes (or the
    registered plugin field collection for PLUGIN_V2 layers), the shapes and
    known constant values of its inputs and, when ``dtypes`` is given, the
    dtypes of all inputs and outputs.

    Args:
        layer: the trt.ILayer to fingerprint.
        shapes: mapping from tensor name to its shape.
        values: mapping from tensor name to a known constant value, if any.
        dtypes: optional mapping from tensor name to dtype.
        updated_attrs: optional attribute overrides taking precedence over
            the values read from the layer / plugin info.

    Returns:
        A tuple usable as a dict key.
    """
    updated_attrs = updated_attrs or {}
    layer_type = layer.type
    to_subclass_layer(layer)
    # Type-specific attributes are whatever the subclass adds on top of
    # trt.ILayer.
    attr_names = set(dir(layer)) - _base_layer_attr_names
    # Drop attributes that are unhashable (constant weights) or superseded
    # by dynamic layer inputs, which are already captured by the
    # shape/value part of the key.
    if layer_type == trt.LayerType.CONSTANT:
        attr_names.remove("weights")
    elif layer_type == trt.LayerType.SHUFFLE:
        if layer.num_inputs >= 2:
            attr_names.remove("reshape_dims")
    elif layer_type == trt.LayerType.SLICE:
        if layer.num_inputs >= 2 and layer.get_input(1) is not None:
            attr_names.remove("start")
        if layer.num_inputs >= 3 and layer.get_input(2) is not None:
            attr_names.remove("shape")
        if layer.num_inputs >= 4 and layer.get_input(3) is not None:
            attr_names.remove("stride")
    elif layer_type == trt.LayerType.FILL:
        attr_names.remove("is_alpha_beta_int64")
        if layer.num_inputs >= 1 and layer.get_input(0) is not None:
            attr_names.remove("shape")
        if layer.num_inputs >= 2 and layer.get_input(1) is not None:
            attr_names.remove("alpha")
        if layer.num_inputs >= 3 and layer.get_input(2) is not None:
            attr_names.remove("beta")
    if layer_type != trt.LayerType.PLUGIN_V2:
        # Use an explicit membership test instead of the previous
        # `updated_attrs.get(name) or getattr(layer, name)`, which silently
        # ignored falsy overrides (0, False, empty tuple).
        attr_key = tuple((name,
                          to_tuple(updated_attrs[name] if name in
                                   updated_attrs else getattr(layer, name)))
                         for name in sorted(attr_names))
    else:
        # Plugin attributes live in the registered PluginInfo, not on the
        # layer object itself.
        network = get_trt_network(layer)
        plugin_info = get_plugin_info(network, layer.name)
        assert plugin_info is not None, f"layer {layer.name} does not register plugin info"
        # Same falsy-override fix as above.
        attr_key = tuple((name,
                          tuple(updated_attrs[name] if name in
                                updated_attrs else data))
                         for name, data in sorted(
                             plugin_info.pfc_as_list.items()))
    to_base_class_layer(layer)
    shape_key = ()
    value_key = ()
    dtype_key = ()
    for i in range(layer.num_inputs):
        input = layer.get_input(i)
        if input is not None:
            shape_key += (tuple(shapes[input.name]), )
            if input.name in values:
                value = values[input.name]
                # Torch tensors are treated as opaque and not folded into
                # the key; other values are assumed to be small sequences.
                if isinstance(value, torch.Tensor):
                    value = None
                else:
                    value = tuple(value)
                value_key += (value, )
            else:
                value_key += (None, )
            if dtypes is not None:
                dtype_key += (dtypes[input.name], )
        else:
            # Missing optional input: keep positional alignment in the key.
            shape_key += (None, )
            value_key += (None, )
            dtype_key += (None, )
    if dtypes is not None:
        for i in range(layer.num_outputs):
            output = layer.get_output(i)
            dtype_key += (dtypes[output.name], )
    cache_key = (layer.type, attr_key, shape_key, value_key)
    if dtypes is not None:
        cache_key += (dtype_key, )
    return cache_key
|
|
|
|
|
|
|
|
def get_trt_network(layer: trt.ILayer):
    """Return the network previously attached to ``layer``.

    Asserts that ``set_trt_network`` was called on this layer first.
    """
    attached_network = get_extra_attr(layer, "network")
    assert attached_network is not None
    return attached_network
|
|
|
|
|
|
|
|
def set_trt_network(layer: trt.ILayer, network: trt.INetworkDefinition):
    """Attach ``network`` to ``layer`` as an extra attribute so that it can
    be recovered later via ``get_trt_network``.
    """
    set_extra_attr(layer, "network", network)
|
|
|
|
|
|
|
|
def get_updated_plugin(plugin_info: PluginInfo, updated_attrs):
    """Recreate a plugin with some of its creation-time fields replaced.

    Fields named in ``updated_attrs`` take their new data from there; all
    other fields keep the data recorded in ``plugin_info``. Field names,
    order, and types are preserved.

    Returns:
        (plugin, new_plugin_info): the freshly created plugin and a
        PluginInfo describing how it was created.
    """
    fields = []
    for field in plugin_info.pfc:
        name = field.name
        if name in updated_attrs:
            data = updated_attrs[name]
        else:
            data = plugin_info.pfc_as_ndarray[name]
        fields.append(trt.PluginField(name, data, field.type))
    pfc = trt.PluginFieldCollection(fields)
    plugin = plugin_info.plugin_creator.create_plugin(plugin_info.plugin_name,
                                                      pfc)
    new_plugin_info = PluginInfo(plugin_info.plugin_creator,
                                 plugin_info.plugin_name, pfc)
    return plugin, new_plugin_info
|
|
|
|
|
|
|
|
# Thread-local holders for the builder flags and strongly-typed mode of the
# build currently running on this thread; written by current_flags() and
# read by get_builder_flags() / get_strongly_typed().
_builder_flags = threading.local()

_strongly_typed = threading.local()
|
|
|
|
|
|
|
|
def get_builder_flags():
    """Return this thread's builder flags, or 0 if none were set."""
    try:
        return _builder_flags.value
    except AttributeError:
        return 0
|
|
|
|
|
|
|
|
def get_strongly_typed():
    """Return this thread's strongly-typed mode, or False if unset."""
    try:
        return _strongly_typed.value
    except AttributeError:
        return False
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def current_flags(builder_flags, strongly_typed):
    """Set this thread's builder flags and strongly-typed mode for the
    duration of the ``with`` block, restoring the previous values on exit.

    Restoration now happens in a ``finally`` block, so an exception raised
    inside the ``with`` body can no longer leak the temporary flags to
    subsequent code on this thread (the original skipped restoration on
    exceptions).
    """
    previous_builder_flags = get_builder_flags()
    previous_strongly_typed = get_strongly_typed()
    _builder_flags.value = builder_flags
    _strongly_typed.value = strongly_typed
    try:
        yield
    finally:
        _builder_flags.value = previous_builder_flags
        _strongly_typed.value = previous_strongly_typed
|
|
|
|
|
|
|
|
def get_engine_information(engine_file) -> str:
    """Deserialize the TRT engine stored at ``engine_file`` and return its
    layer information as a JSON string.
    """
    with open(engine_file, "rb") as f:
        serialized_engine = f.read()
    runtime = trt.Runtime(logger.trt_logger)
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    inspector = engine.create_engine_inspector()
    return inspector.get_engine_information(trt.LayerInformationFormat.JSON)
|
|
|
|
|
|
|
|
def print_engine_info(engine_file) -> None:
    """Print engine information for the serialized engine at ``engine_file``.

    Fixed return annotation: the function prints via the session helper and
    returns nothing; the previous ``-> dict`` annotation was incorrect.
    """
    with open(engine_file, "rb") as f:
        engine_buffer = f.read()
    # Imported locally — presumably to avoid a circular import at module
    # load time; confirm before hoisting to the top of the file.
    from tensorrt_llm.runtime.session import Session
    Session.from_serialized_engine(engine_buffer)._print_engine_info()
|
|
|