aspctu's picture
Upload folder using huggingface_hub
5000658 verified
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Set
import numpy as np
import tensorrt as trt
import torch
from tensorrt_llm._common import _is_building
from tensorrt_llm._utils import (trt_dtype_to_np, trt_dtype_to_str,
trt_dtype_to_torch)
from tensorrt_llm.logger import logger
from .pipeline_graph import PipelineGraph
from .utils import (get_builder_flags, get_cache_key, get_sorted_layer_ids,
set_trt_network, to_base_class_layer, to_subclass_layer,
to_trt_weights)
class ShapeType(Enum):
    """Which entry of an optimization profile's shape triple to use.

    The member value is used directly as an index into the sequences returned
    by ``profile.get_shape`` and ``profile.get_shape_input``.
    """

    MIN = 0  # minimum dimensions / values of the profile
    OPT = 1  # optimal dimensions / values of the profile
    MAX = 2  # maximum dimensions / values of the profile
# Maps a TensorRT dtype to the Python type used when converting values that
# were read back through an INT32 cast proxy (the BOOL -> INT32 casts added
# for shape outputs in this module) into their original dtype.
_trt_to_type_dict = {
    trt.int64: int,
    trt.bool: bool,
}
def get_shape_layers(trt_network):
    """Return the names of layers that operate purely on shape tensors.

    A layer qualifies when it has at least one input slot and every non-None
    input is a shape tensor, or when it has at least one output and every
    output is a shape tensor.

    Args:
        trt_network: The TensorRT network to scan.

    Returns:
        A set of layer names.
    """

    def is_shape_only(layer):
        inputs = [layer.get_input(j) for j in range(layer.num_inputs)]
        # Unset optional inputs (None) are ignored, but the layer must still
        # declare at least one input slot.
        if inputs and all(t.is_shape_tensor for t in inputs if t is not None):
            return True
        outputs = [layer.get_output(j) for j in range(layer.num_outputs)]
        return bool(outputs) and all(t.is_shape_tensor for t in outputs)

    candidates = (trt_network.get_layer(i) for i in range(trt_network.num_layers))
    return {layer.name for layer in candidates if is_shape_only(layer)}
def get_layers_in_shape_network(trt_network, shape_layers, sorted_layer_ids):
    """Collect the layers needed to evaluate the shape layers.

    ``sorted_layer_ids`` is walked from back to front. A layer is selected when
    it is itself in ``shape_layers`` or when one of its outputs is consumed by
    an already-selected layer. The inputs of every selected layer then become
    "needed" tensors for the layers visited later.

    Args:
        trt_network: The TensorRT network.
        shape_layers: Names of layers that operate purely on shape tensors.
        sorted_layer_ids: Layer indices in producer-before-consumer order.

    Returns:
        A set of layer names, including the shape layers themselves.
    """
    selected = set()
    needed_tensors = set()
    for layer_id in reversed(sorted_layer_ids):
        layer = trt_network.get_layer(layer_id)
        feeds_selection = layer.name in shape_layers or any(
            layer.get_output(j).name in needed_tensors
            for j in range(layer.num_outputs))
        if not feeds_selection:
            continue
        selected.add(layer.name)
        needed_tensors.update(
            tensor.name
            for tensor in map(layer.get_input, range(layer.num_inputs))
            if tensor is not None)
    return selected
def get_shape_network(trt_network,
                      shapes,
                      values,
                      sorted_layer_ids,
                      profile=None,
                      shape_type: ShapeType = ShapeType.OPT):
    """Build a network that evaluates only the shape computations of ``trt_network``.

    Every input of ``trt_network`` is mirrored into the new network and its
    shape is recorded in ``shapes``. When ``profile`` is given:
      - inputs with dynamic (-1) dims take the ``shape_type`` entry of the
        profile, both in ``shapes`` and as the mirrored input's ``raw_shape``;
      - shape-tensor inputs take the ``shape_type`` entry of the profile's shape
        input values, which is stored in ``values`` and pinned (min = opt = max)
        in the returned profile.
    Layers in the shape-layer set are copied with their outputs marked as shape
    outputs; BOOL outputs go through an INT32 cast proxy. Other layers feeding
    them are copied as well, except constants, which become inputs.

    Returns:
        ``(shape_network, shape_profile, shape_layers, output_mapping)``.
        ``output_mapping`` maps each cast proxy output name to
        ``(original output name, original dtype)``.
    """
    shape_layers = get_shape_layers(trt_network)
    layers_in_shape_network = get_layers_in_shape_network(
        trt_network, shape_layers, sorted_layer_ids)

    shape_graph = PipelineGraph.create_graph()
    shape_network = shape_graph.as_trt()
    shape_profile = shape_network.builder.create_optimization_profile()
    profile_index = shape_type.value

    for input_index in range(trt_network.num_inputs):
        net_input = trt_network.get_input(input_index)
        shapes[net_input.name] = net_input.shape
        mirrored_input = shape_graph.add_input(net_input)
        if profile is None:
            continue
        if -1 in net_input.shape:
            picked_shape = profile.get_shape(net_input.name)[profile_index]
            shapes[net_input.name] = picked_shape
            mirrored_input.raw_shape = picked_shape
        if net_input.is_shape_tensor:
            picked_value = profile.get_shape_input(
                net_input.name)[profile_index]
            values[net_input.name] = picked_value
            shape_profile.set_shape_input(net_input.name, picked_value,
                                          picked_value, picked_value)

    output_mapping = {}
    for layer_id in sorted_layer_ids:
        layer = trt_network.get_layer(layer_id)
        if layer.name in shape_layers:
            copied_layer = shape_graph.add_layer(layer)
            for output_index in range(layer.num_outputs):
                output = layer.get_output(output_index)
                if output.dtype == trt.DataType.BOOL:
                    # Route BOOL shape outputs through an INT32 cast; the
                    # mapping lets callers restore the original name/dtype.
                    cast_layer = shape_network.add_cast(
                        copied_layer.as_trt().get_output(output_index),
                        trt.DataType.INT32,
                    )
                    cast_output = cast_layer.get_output(0)
                    shape_graph.register_layer(cast_layer)
                    shape_graph.add_output_shape(cast_output)
                    output_mapping[cast_output.name] = (output.name,
                                                        output.dtype)
                else:
                    shape_graph.add_output_shape(output)
        elif layer.name in layers_in_shape_network:
            if layer.type == trt.LayerType.CONSTANT:
                shape_graph.add_input(layer.get_output(0))
            else:
                shape_graph.add_layer(layer)
    return shape_network, shape_profile, shape_layers, output_mapping
def get_per_layer_graph(
    layer,
    shapes,
    values,
    updated_attrs=None,
    is_shape_io: bool = None,
):
    """Wrap a single layer into its own network for isolated shape inference.

    Each non-None input of ``layer`` becomes one of:
      - a constant, when ``values`` holds a known non-tensor value for it;
      - a network input with the shape recorded in ``shapes``, unless an input
        of that name was already added.
    The layer is then copied into the graph. Each output is exposed either as a
    shape-tensor output or as a regular output; BOOL shape outputs are routed
    through an INT32 cast.

    Args:
        layer: The layer to isolate.
        shapes: Tensor name -> shape; must contain every non-None input.
        values: Tensor name -> known value.
        updated_attrs: Forwarded to ``graph.add_layer``.
        is_shape_io: If not None, forces whether outputs are treated as shape
            tensors, overriding the heuristic below.

    Returns:
        ``(graph, output_mapping)``. ``output_mapping`` maps each cast proxy
        output name to ``(original output name, original dtype)``.
    """
    graph = PipelineGraph.create_graph()
    network = graph.as_trt()

    # Set once some input had to become a real network input instead of a
    # folded constant.
    has_graph_input = False
    for index in range(layer.num_inputs):
        tensor = layer.get_input(index)
        if tensor is None:
            continue
        shape = shapes[tensor.name]
        known = values.get(tensor.name)
        if known is not None and not isinstance(known, torch.Tensor):
            host_array = np.asarray(known,
                                    dtype=trt_dtype_to_np(tensor.dtype))
            const_layer = network.add_constant(shape,
                                               to_trt_weights(host_array))
            const_output = const_layer.get_output(0)
            const_output.name = tensor.name
            graph.register_layer(const_layer)
        elif graph.get_input(tensor.name) is None:
            graph_input = graph.add_input(tensor)
            graph_input.raw_shape = shape
            has_graph_input = True

    new_layer = graph.add_layer(
        layer,
        updated_attrs=updated_attrs,
    )

    # Decide whether the outputs are shape tensors. Precedence: explicit
    # override, then "no inputs => not a shape layer", then SHAPE layers,
    # then "every input was folded into a constant".
    if is_shape_io is not None:
        treat_as_shape = is_shape_io
    elif layer.num_inputs == 0:
        treat_as_shape = False
    elif layer.type == trt.LayerType.SHAPE:
        treat_as_shape = True
    else:
        treat_as_shape = not has_graph_input

    output_mapping = {}
    for index in range(layer.num_outputs):
        output = layer.get_output(index)
        known = values.get(output.name)
        # Outputs whose value is already a torch tensor stay regular outputs.
        if isinstance(known, torch.Tensor) or not treat_as_shape:
            graph.add_output(output)
            continue
        if output.dtype == trt.DataType.BOOL:
            cast_layer = network.add_cast(
                new_layer.as_trt().get_output(index),
                trt.DataType.INT32,
            )
            cast_output = cast_layer.get_output(0)
            graph.register_layer(cast_layer)
            output_mapping[cast_output.name] = (output.name, output.dtype)
            output = cast_output
        graph.add_output_shape(output)
    return graph, output_mapping
@_is_building
def infer_shapes(network, shapes, values, profile=None):
    """Run TensorRT shape inference on ``network`` and record the results.

    A throwaway engine is built at optimization level 0, and an execution
    context created from it is used to propagate shapes. For every network
    output, the inferred shape is stored in ``shapes``. For shape-tensor
    outputs, the computed value is also stored in ``values`` as a Python list.

    Args:
        network: TensorRT network to analyze; nothing happens if it has no
            outputs.
        shapes: Tensor name -> shape; updated in place with output shapes.
        values: Tensor name -> value; must contain every shape-tensor input's
            value, and is updated in place with shape-tensor output values.
        profile: Optimization profile to build with; a fresh profile from the
            builder is used when this is falsy.

    Raises:
        RuntimeError: If the engine cannot be built.
    """
    if network.num_outputs == 0:
        return

    builder = network.builder
    config = builder.create_builder_config()
    config.builder_optimization_level = 0
    config.flags = get_builder_flags()
    profile = profile or builder.create_optimization_profile()
    config.add_optimization_profile(profile)

    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        raise RuntimeError(
            'Engine building failed when inferring shapes, please check the error log.'
        )

    # Keep the runtime in a local for as long as the engine/context are used.
    runtime = trt.Runtime(logger.trt_logger)
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    context = engine.create_execution_context()

    for input_index in range(network.num_inputs):
        net_input = network.get_input(input_index)
        if not net_input.is_shape_tensor:
            continue
        input_value = values[net_input.name]
        context.set_shape_input(engine[net_input.name], input_value)

    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for output in outputs:
        output_shape = context.get_tensor_shape(output.name)
        shapes[output.name] = output_shape
        if not output.is_shape_tensor:
            continue
        if output_shape == [0]:
            # Empty shape tensor: nothing to receive.
            values[output.name] = []
            continue
        # A scalar shape tensor still needs a one-element host buffer.
        buffer_dims = [1] if output_shape == [] else list(output_shape)
        host_buffer = torch.empty(
            buffer_dims,
            dtype=trt_dtype_to_torch(output.dtype),
            device="cpu",
        )
        values[output.name] = host_buffer
        context.set_tensor_address(output.name, host_buffer.data_ptr())

    context.infer_shapes()
    assert context.all_binding_shapes_specified

    # Convert the host buffers registered above into plain Python lists.
    for output in outputs:
        produced = values.get(output.name)
        if isinstance(produced, torch.Tensor):
            values[output.name] = produced.tolist()
@dataclass
class ShapeInfo:
    """Container for the results of network-wide shape inference."""

    # Tensor name -> inferred shape.
    shapes: Dict[str, trt.Dims]
    # Tensor name -> known value (shape-tensor values and constants).
    values: Dict[str, List[int]]
    # Names of layers whose inputs or outputs are all shape tensors.
    shape_layers: Set[str]
    # Optional tensor name -> shape mapping; defaults to None.
    max_shapes: Dict[str, trt.Dims] = None
def set_constant_value(layer, values):
    """Store a constant layer's weights as a Python list in ``values``.

    The entry is keyed by the name of the layer's first output. ``trt.Weights``
    are converted via ``.numpy()`` before listing. The layer is passed through
    ``to_subclass_layer`` before reading ``weights`` (presumably to expose the
    constant-layer attributes) and through ``to_base_class_layer`` afterwards.
    """
    to_subclass_layer(layer)
    target_name = layer.get_output(0).name
    raw = layer.weights
    array = raw.numpy() if isinstance(raw, trt.Weights) else raw
    values[target_name] = list(array)
    to_base_class_layer(layer)
def infer_per_layer_shapes(
    layer: trt.ILayer,
    shapes,
    values,
    cache=None,
    is_shape_io=False,
):
    """Infer the output shapes (and, where applicable, values) of one layer.

    Results are written into ``shapes`` / ``values``, keyed by output tensor
    name. CONSTANT and SHAPE layers are resolved directly in Python. Every
    other layer is wrapped into a one-layer network by ``get_per_layer_graph``
    and run through ``infer_shapes``.

    Args:
        layer: Layer whose outputs should be inferred. The shapes of its
            non-None inputs must already be in ``shapes``.
        shapes: Tensor name -> shape; updated in place.
        values: Tensor name -> known value; updated in place.
        cache: Optional dict memoizing results by ``get_cache_key``. Consulted
            before inference and filled on a miss.
        is_shape_io: For CONSTANT layers, also record the constant's value in
            ``values``.

    Raises:
        RuntimeError: If inference of the per-layer network fails. The error
            is chained and describes the layer.
    """
    if layer.type == trt.LayerType.CONSTANT:
        to_subclass_layer(layer)
        output_name = layer.get_output(0).name
        shape = layer.shape
        shapes[output_name] = shape
        if is_shape_io:
            set_constant_value(layer, values)
        to_base_class_layer(layer)
        return
    elif layer.type == trt.LayerType.SHAPE:
        # The output of a SHAPE layer is the input's dims as a 1-D tensor.
        input_name = layer.get_input(0).name
        output_name = layer.get_output(0).name
        shape = [*shapes[input_name]]
        shapes[output_name] = trt.Dims([len(shape)])
        values[output_name] = shape
        return

    # Computed unconditionally: the key also describes the layer in the debug
    # log and the error message below, which previously hit an unbound name
    # whenever ``cache`` was None.
    cache_key = get_cache_key(layer, shapes, values)
    if cache is not None and cache_key in cache:
        output_shapes, output_values = cache[cache_key]
        for i in range(layer.num_outputs):
            output = layer.get_output(i)
            shapes[output.name] = output_shapes[i]
            if output_values[i] is not None:
                values[output.name] = output_values[i]
        return

    graph, output_mapping = get_per_layer_graph(layer, shapes, values)
    dtypes = []
    for i in range(layer.num_inputs):
        input_tensor = layer.get_input(i)
        # Optional inputs may be unset (None); keep a placeholder so that the
        # list still lines up with the layer's input indices.
        dtypes.append(None if input_tensor is None else
                      trt_dtype_to_str(input_tensor.dtype))
    layer_info = (f"type={cache_key[0]}, "
                  f"attrs={dict(cache_key[1])}, "
                  f"dtypes={dtypes}, "
                  f"shapes={list(cache_key[2])}, "
                  f"values={list(cache_key[3])}")
    logger.debug(f"infer shapes for layer {layer.name} ({layer_info})")
    try:
        infer_shapes(graph.as_trt(), shapes, values)
    except RuntimeError as e:
        raise RuntimeError(
            f"infer shapes failed for layer {layer.name} ({layer_info})") from e

    # Move results of the INT32 cast proxies back to the original outputs,
    # converting values to the original dtype's Python type.
    for proxy_output, (output, dtype) in output_mapping.items():
        shapes[output] = shapes[proxy_output]
        del shapes[proxy_output]
        if proxy_output in values:
            values[output] = [
                *map(_trt_to_type_dict[dtype], values[proxy_output])
            ]
            del values[proxy_output]

    if cache is not None:
        logger.debug(
            f"shape inference cache miss, layer: {layer.name}, cache key: {cache_key}"
        )
        output_shapes = []
        output_values = []
        for i in range(layer.num_outputs):
            output = layer.get_output(i)
            output_shapes.append(shapes[output.name])
            output_values.append(values.get(output.name))
        cache[cache_key] = (output_shapes, output_values)
def get_shape_info(trt_network, profile, shape_type: ShapeType = ShapeType.OPT):
    """Infer shapes and shape-tensor values for every tensor of ``trt_network``.

    Steps:
      1. Evaluate all shape layers at once through a dedicated shape network
         (``get_shape_network`` + ``infer_shapes``).
      2. Infer every remaining layer one at a time, in the order returned by
         ``get_sorted_layer_ids``, sharing a single cache.
    If step 1 raises RuntimeError, the shape layers are also handled in step 2.

    Args:
        trt_network: Network to analyze.
        profile: Optimization profile used to resolve dynamic input dims and
            shape-input values; may be None.
        shape_type: Which profile entry (MIN/OPT/MAX) to use.

    Returns:
        ShapeInfo with the collected shapes, values and shape-layer names.
    """
    shapes = {}
    values = {}
    sorted_layer_ids = get_sorted_layer_ids(trt_network)
    shape_network, shape_profile, shape_layers, output_mapping = get_shape_network(
        trt_network,
        shapes,
        values,
        sorted_layer_ids,
        profile=profile,
        shape_type=shape_type)
    try:
        infer_shapes(shape_network, shapes, values, shape_profile)
    except RuntimeError as e:
        # Best effort: fall back to per-layer inference of the shape layers,
        # but keep a trace of why the fast path was not taken.
        logger.debug(f"shape network inference failed ({e}), "
                     "falling back to per-layer shape inference")
        infer_shape_layers = True
    else:
        infer_shape_layers = False
        # Move results of the INT32 cast proxies back to the original BOOL
        # outputs, converting values to the original dtype's Python type.
        for proxy_output, (output, dtype) in output_mapping.items():
            shapes[output] = shapes.pop(proxy_output)
            values[output] = [
                *map(_trt_to_type_dict[dtype], values.pop(proxy_output))
            ]

    cache = {}
    for layer_id in sorted_layer_ids:
        layer = trt_network.get_layer(layer_id)
        is_shape_io = layer.name in shape_layers
        if is_shape_io and not infer_shape_layers:
            # Already resolved by the shape network above.
            continue
        set_trt_network(layer, trt_network)
        infer_per_layer_shapes(layer,
                               shapes,
                               values,
                               cache,
                               is_shape_io=is_shape_io)
    return ShapeInfo(shapes, values, shape_layers)