import os import tensorrt as trt TRT_LOGGER = trt.Logger(trt.Logger.WARNING) def initialize_builder(use_fp16=False, workspace_size=(1 << 31)): # 2GB expressed using bit shift """ Khởi tạo và cấu hình builder cho TensorRT. Args: use_fp16 (bool): Sử dụng FP16 nếu có hỗ trợ và được yêu cầu. workspace_size (int): Kích thước workspace tối đa cho builder. Returns: Tuple[trt.Builder, trt.BuilderConfig]: Trả về builder và cấu hình builder. """ builder = trt.Builder(TRT_LOGGER) config = builder.create_builder_config() config.set_tactic_sources(trt.TacticSource.CUBLAS_LT) config.max_workspace_size = workspace_size # 2GB using bit shift if builder.platform_has_fast_fp16 and use_fp16: config.set_flag(trt.BuilderFlag.FP16) return builder, config def parse_onnx_model(builder, onnx_file_path): """ Phân tích mô hình ONNX và tạo network trong TensorRT. Args: builder (trt.Builder): Builder TensorRT. onnx_file_path (str): Đường dẫn tới file mô hình ONNX. Returns: trt.INetworkDefinition: Trả về network TensorRT. """ network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) parser = trt.OnnxParser(network, TRT_LOGGER ) with open(onnx_file_path, 'rb') as model: if not parser.parse(model.read()): print('❌ Failed to parse the ONNX file.') for error in range(parser.num_errors): print(parser.get_error(error)) return None print("✅ Completed parsing ONNX file") return network def parse_onnx_model_static(builder, onnx_file_path, batch_size=2): """ Phân tích mô hình ONNX và tạo network trong TensorRT với kích thước batch cố định. Args: builder (trt.Builder): Builder TensorRT. onnx_file_path (str): Đường dẫn tới file mô hình ONNX. batch_size (int): Kích thước batch cố định. Returns: trt.INetworkDefinition: Trả về network TensorRT. """ network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) parser = trt.OnnxParser(network, TRT_LOGGER) with open(onnx_file_path, 'rb') as model: if not parser.parse(model.read()): print('❌ Failed to parse the ONNX file.') for error in range(parser.num_errors): print(parser.get_error(error)) return None print("✅ Completed parsing ONNX file") # Thiết lập kích thước batch cố định cho tất cả các input for i in range(network.num_inputs): shape = list(network.get_input(i).shape) shape[0] = batch_size network.get_input(i).shape = shape return network def set_dynamic_shapes(builder, config, dynamic_shapes): """ Thiết lập các kích thước động cho mô hình. Args: builder (trt.Builder): Builder TensorRT. network (trt.INetworkDefinition): Network TensorRT. config (trt.BuilderConfig): Cấu hình builder. dynamic_shapes (dict): Từ điển các kích thước động cho mô hình. """ if dynamic_shapes: print(f"===> Using dynamic shapes: {str(dynamic_shapes)}") profile = builder.create_optimization_profile() for binding_name, dynamic_shape in dynamic_shapes.items(): min_shape, opt_shape, max_shape = dynamic_shape profile.set_shape(binding_name, min_shape, opt_shape, max_shape) config.add_optimization_profile(profile) def build_and_save_engine(builder, network, config, engine_file_path): """ Xây dựng và lưu engine TensorRT. Args: builder (trt.Builder): Builder TensorRT. network (trt.INetworkDefinition): Network TensorRT. config (trt.BuilderConfig): Cấu hình builder. engine_file_path (str): Đường dẫn để lưu engine. """ if os.path.isfile(engine_file_path): try: os.remove(engine_file_path) except Exception as e: print(f"Cannot remove existing file: {engine_file_path}. Error: {e}") print("Creating TensorRT Engine...") serialized_engine = builder.build_serialized_network(network, config) if serialized_engine: with open(engine_file_path, "wb") as f: f.write(serialized_engine) print(f"===> Serialized Engine Saved at: {engine_file_path}") else: print("❌ Failed to build engine") # Fix batch_size def main_fixed(): batch_size = 1 onnx_file_path = "models/tusimple_18.onnx" engine_file_path = "models/tusimple_18_FP16.trt" builder, config = initialize_builder(use_fp16=True) network = parse_onnx_model_static(builder, onnx_file_path, batch_size=batch_size) if network: build_and_save_engine(builder, network, config, engine_file_path) if __name__ == "__main__": main_fixed()