Spaces:
No application file
No application file
File size: 4,987 Bytes
f113e60 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | import os
import tensorrt as trt
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
def initialize_builder(use_fp16=False, workspace_size=(1 << 31)): # 2GB expressed using bit shift
"""
Khởi tạo và cấu hình builder cho TensorRT.
Args:
use_fp16 (bool): Sử dụng FP16 nếu có hỗ trợ và được yêu cầu.
workspace_size (int): Kích thước workspace tối đa cho builder.
Returns:
Tuple[trt.Builder, trt.BuilderConfig]: Trả về builder và cấu hình builder.
"""
builder = trt.Builder(TRT_LOGGER)
config = builder.create_builder_config()
config.set_tactic_sources(trt.TacticSource.CUBLAS_LT)
config.max_workspace_size = workspace_size # 2GB using bit shift
if builder.platform_has_fast_fp16 and use_fp16:
config.set_flag(trt.BuilderFlag.FP16)
return builder, config
def parse_onnx_model(builder, onnx_file_path):
"""
Phân tích mô hình ONNX và tạo network trong TensorRT.
Args:
builder (trt.Builder): Builder TensorRT.
onnx_file_path (str): Đường dẫn tới file mô hình ONNX.
Returns:
trt.INetworkDefinition: Trả về network TensorRT.
"""
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER )
with open(onnx_file_path, 'rb') as model:
if not parser.parse(model.read()):
print('❌ Failed to parse the ONNX file.')
for error in range(parser.num_errors):
print(parser.get_error(error))
return None
print("✅ Completed parsing ONNX file")
return network
def parse_onnx_model_static(builder, onnx_file_path, batch_size=2):
"""
Phân tích mô hình ONNX và tạo network trong TensorRT với kích thước batch cố định.
Args:
builder (trt.Builder): Builder TensorRT.
onnx_file_path (str): Đường dẫn tới file mô hình ONNX.
batch_size (int): Kích thước batch cố định.
Returns:
trt.INetworkDefinition: Trả về network TensorRT.
"""
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)
with open(onnx_file_path, 'rb') as model:
if not parser.parse(model.read()):
print('❌ Failed to parse the ONNX file.')
for error in range(parser.num_errors):
print(parser.get_error(error))
return None
print("✅ Completed parsing ONNX file")
# Thiết lập kích thước batch cố định cho tất cả các input
for i in range(network.num_inputs):
shape = list(network.get_input(i).shape)
shape[0] = batch_size
network.get_input(i).shape = shape
return network
def set_dynamic_shapes(builder, config, dynamic_shapes):
"""
Thiết lập các kích thước động cho mô hình.
Args:
builder (trt.Builder): Builder TensorRT.
network (trt.INetworkDefinition): Network TensorRT.
config (trt.BuilderConfig): Cấu hình builder.
dynamic_shapes (dict): Từ điển các kích thước động cho mô hình.
"""
if dynamic_shapes:
print(f"===> Using dynamic shapes: {str(dynamic_shapes)}")
profile = builder.create_optimization_profile()
for binding_name, dynamic_shape in dynamic_shapes.items():
min_shape, opt_shape, max_shape = dynamic_shape
profile.set_shape(binding_name, min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)
def build_and_save_engine(builder, network, config, engine_file_path):
"""
Xây dựng và lưu engine TensorRT.
Args:
builder (trt.Builder): Builder TensorRT.
network (trt.INetworkDefinition): Network TensorRT.
config (trt.BuilderConfig): Cấu hình builder.
engine_file_path (str): Đường dẫn để lưu engine.
"""
if os.path.isfile(engine_file_path):
try:
os.remove(engine_file_path)
except Exception as e:
print(f"Cannot remove existing file: {engine_file_path}. Error: {e}")
print("Creating TensorRT Engine...")
serialized_engine = builder.build_serialized_network(network, config)
if serialized_engine:
with open(engine_file_path, "wb") as f:
f.write(serialized_engine)
print(f"===> Serialized Engine Saved at: {engine_file_path}")
else:
print("❌ Failed to build engine")
# Fix batch_size
def main_fixed():
batch_size = 1
onnx_file_path = "models/tusimple_18.onnx"
engine_file_path = "models/tusimple_18_FP16.trt"
builder, config = initialize_builder(use_fp16=True)
network = parse_onnx_model_static(builder, onnx_file_path, batch_size=batch_size)
if network:
build_and_save_engine(builder, network, config, engine_file_path)
if __name__ == "__main__":
main_fixed() |