FireRedASR-AED-L-TensorRT / export_encoder_tensorrt.py
yuekai's picture
Upload folder using huggingface_hub
98c8d47 verified
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Zengwei Yao)
# Copyright 2025 Nvidia Corp. (authors: Yuekai Zhang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script exports a pre-trained FireRedASR encoder model from PyTorch to
ONNX and TensorRT.
Usage:
python3 examples/export_encoder_tensorrt.py \
--model-dir /path/to/your/model_dir \
--tensorrt-model-dir ./tensorrt_models \
--trt-engine-file-name encoder.plan
"""
import argparse
import logging
from pathlib import Path
import torch
import tensorrt as trt
from fireredasr.models.fireredasr import load_fireredasr_aed_model
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the encoder export script.

    Returns:
        A parser that accepts ``--model-dir``, ``--onnx-model-path``,
        ``--idim``, ``--tensorrt-model-dir``, ``--trt-engine-file-name`` and
        ``--opset-version``. The help formatter appends each option's
        default value to its help text.
    """
    # One (flag, type, default, help) row per option, in ``--help`` order.
    option_table = (
        (
            "--model-dir",
            str,
            None,
            "The model directory that contains model checkpoint.",
        ),
        (
            "--onnx-model-path",
            str,
            None,
            "If specified, we will directly use this onnx model to generate "
            "the tensorrt engine",
        ),
        (
            "--idim",
            int,
            80,
            "The input dimension of the model. This is required when "
            "--onnx-model-path is specified.",
        ),
        (
            "--tensorrt-model-dir",
            str,
            "exp",
            "Directory to save the exported models.",
        ),
        (
            "--trt-engine-file-name",
            str,
            "encoder.plan",
            "The name of the TensorRT engine file.",
        ),
        (
            "--opset-version",
            int,
            17,
            "ONNX opset version.",
        ),
    )

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )
    return cli
def export_encoder_onnx(
    encoder: torch.nn.Module,
    filename: str,
    idim: int,
    opset_version: int = 17,
) -> None:
    """Export the encoder to an ONNX file using FP16 dummy inputs.

    The encoder is cast to half precision in place (``encoder.half()``)
    before export, so the caller's module is FP16 after this call.

    Args:
        encoder: Module to export; it is invoked with
            ``(padded_input, input_lengths)``.
        filename: Path of the ONNX file to write.
        idim: Size of the last dimension of the dummy ``padded_input``.
        opset_version: ONNX opset version passed to ``torch.onnx.export``.
    """
    logging.info("Exporting encoder to ONNX")
    encoder.half()

    # Dummy inputs for the export call. The batch and time axes are declared
    # dynamic below, so these concrete sizes are not baked into the graph.
    n_utts, n_frames = 1, 400
    feats = torch.randn(n_utts, n_frames, idim, dtype=torch.float16)
    feat_lens = torch.full((n_utts,), n_frames, dtype=torch.int32)

    in_names = ["padded_input", "input_lengths"]
    out_names = ["enc_output", "output_lengths", "src_mask"]
    # Axis index -> symbolic name for every dimension that may vary.
    axes = {
        "padded_input": {0: "batch_size", 1: "seq_len"},
        "input_lengths": {0: "batch_size"},
        "enc_output": {0: "batch_size", 1: "seq_len_out"},
        "output_lengths": {0: "batch_size"},
        "src_mask": {0: "batch_size", 2: "seq_len_out"},
    }

    torch.onnx.export(
        encoder,
        (feats, feat_lens),
        filename,
        opset_version=opset_version,
        input_names=in_names,
        output_names=out_names,
        dynamic_axes=axes,
    )
    logging.info(f"Exported encoder to {filename}")
def get_trt_kwargs_dynamic_batch(
    idim: int,
    min_batch_size: int = 1,
    opt_batch_size: int = 4,
    max_batch_size: int = 64,
    *,
    min_seq_len: int = 50,
    opt_seq_len: int = 400,
    max_seq_len: int = 3000,
) -> dict:
    """Describe the min/opt/max input shapes for a dynamic-shape TensorRT build.

    Args:
        idim: Size of the last dimension of ``padded_input``.
        min_batch_size: Smallest batch size the profile covers.
        opt_batch_size: Batch size used for the ``opt`` shape.
        max_batch_size: Largest batch size the profile covers.
        min_seq_len: Smallest sequence length the profile covers.
        opt_seq_len: Sequence length used for the ``opt`` shape.
        max_seq_len: Largest sequence length the profile covers.

    Returns:
        A dict with keys ``min_shape``, ``opt_shape``, ``max_shape`` and
        ``input_names``. Each shape entry is a two-element list holding the
        shape of ``padded_input`` (``(batch, seq_len, idim)``) followed by
        the shape of ``input_lengths`` (``(batch,)``), in the same order as
        ``input_names``.
    """

    def shapes_for(batch_size: int, seq_len: int) -> list:
        # Order must match ``input_names`` below.
        return [(batch_size, seq_len, idim), (batch_size,)]

    return {
        "min_shape": shapes_for(min_batch_size, min_seq_len),
        "opt_shape": shapes_for(opt_batch_size, opt_seq_len),
        "max_shape": shapes_for(max_batch_size, max_seq_len),
        "input_names": ["padded_input", "input_lengths"],
    }
def convert_onnx_to_trt(
    trt_model: str, trt_kwargs: dict, onnx_model: str, dtype: torch.dtype = torch.float16
) -> None:
    """Build a serialized TensorRT engine from an ONNX model and save it.

    Args:
        trt_model: Path the serialized engine is written to.
        trt_kwargs: Dict with keys ``input_names``, ``min_shape``,
            ``opt_shape`` and ``max_shape``; the three shape lists are
            indexed in parallel with ``input_names``.
        onnx_model: Path of the ONNX model to parse.
        dtype: The FP16 builder flag is enabled only when this is
            ``torch.float16``.

    Raises:
        ValueError: If the ONNX parser rejects the model.
        RuntimeError: If TensorRT fails to build the engine. Nothing is
            written to ``trt_model`` in that case.
    """
    logging.info("Converting ONNX to TensorRT engine...")

    # EXPLICIT_BATCH is deprecated in newer TensorRT releases (networks are
    # always explicit-batch there), so only set it while the enum has it.
    explicit_batch = getattr(
        trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH", None
    )
    network_flags = 0 if explicit_batch is None else 1 << int(explicit_batch)

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()
    if dtype == torch.float16:
        config.set_flag(trt.BuilderFlag.FP16)
    profile = builder.create_optimization_profile()

    with open(onnx_model, "rb") as f:
        parsed_ok = parser.parse(f.read())
    if not parsed_ok:
        # Report every parser error through logging, like the rest of the file.
        for error_idx in range(parser.num_errors):
            logging.error("%s", parser.get_error(error_idx))
        raise ValueError(f'Failed to parse {onnx_model}')

    # Register the allowed shape range for every dynamic input.
    for i, name in enumerate(trt_kwargs['input_names']):
        profile.set_shape(
            name,
            trt_kwargs['min_shape'][i],
            trt_kwargs['opt_shape'][i],
            trt_kwargs['max_shape'][i],
        )
    config.add_optimization_profile(profile)

    try:
        engine_bytes = builder.build_serialized_network(network, config)
    except Exception as e:
        logging.error(f"TensorRT engine build failed: {e}")
        # Propagate so the caller does not report success without an engine.
        raise RuntimeError("TensorRT engine build failed") from e
    # A failed build is reported as a None return rather than an exception;
    # check before opening the output file so no empty engine is left behind.
    if engine_bytes is None:
        logging.error("TensorRT engine build failed: builder returned None")
        raise RuntimeError(f"TensorRT engine build failed for {onnx_model}")

    with open(trt_model, "wb") as f:
        f.write(engine_bytes)
    logging.info("Successfully converted ONNX to TensorRT.")
@torch.no_grad()
def main():
    """Export the encoder to ONNX (unless one is supplied) and build a TensorRT engine.

    With ``--onnx-model-path`` the given ONNX file is used directly and
    ``--idim`` supplies the input feature dimension. Otherwise the checkpoint
    ``<model-dir>/model.pth.tar`` is loaded, its encoder is exported to
    ``<tensorrt-model-dir>/encoder.fp16.onnx``, and the feature dimension is
    read from the checkpoint's ``args``. In both cases the engine is written
    to ``<tensorrt-model-dir>/<trt-engine-file-name>``.

    Raises:
        ValueError: If a required command-line option is missing.
        FileNotFoundError: If the ONNX model or the checkpoint does not exist.
    """
    parser = get_parser()
    args = parser.parse_args()
    tensorrt_model_dir = Path(args.tensorrt_model_dir)
    use_existing_onnx = bool(args.onnx_model_path)

    # Validate all inputs first so that bad arguments do not leave an empty
    # output directory behind.
    if use_existing_onnx:
        logging.info(f"Using provided ONNX model: {args.onnx_model_path}")
        if not args.idim:
            raise ValueError("--idim is required when using --onnx-model-path")
        idim = args.idim
        encoder_onnx_file = Path(args.onnx_model_path)
        if not encoder_onnx_file.is_file():
            raise FileNotFoundError(f"ONNX model not found at {encoder_onnx_file}")
    else:
        if not args.model_dir:
            raise ValueError(
                "--model-dir is required if --onnx-model-path is not provided"
            )
        model_path = Path(args.model_dir) / "model.pth.tar"
        if not model_path.is_file():
            raise FileNotFoundError(f"Model checkpoint not found at {model_path}")

    tensorrt_model_dir.mkdir(parents=True, exist_ok=True)

    if not use_existing_onnx:
        logging.info("Exporting ONNX model from PyTorch checkpoint")
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only run this on checkpoints from a trusted source.
        package = torch.load(model_path, map_location="cpu", weights_only=False)
        idim = package["args"].idim
        # Only idim is needed from this copy; drop it before the model is
        # loaded below so the checkpoint is not kept alive during the export
        # and engine build.
        del package

        # We have to load the full AED model to get the encoder with weights
        model = load_fireredasr_aed_model(str(model_path))
        encoder = model.encoder
        encoder.eval()

        # Export ONNX
        encoder_onnx_file = tensorrt_model_dir / "encoder.fp16.onnx"
        export_encoder_onnx(
            encoder=encoder,
            filename=str(encoder_onnx_file),
            idim=idim,
            opset_version=args.opset_version,
        )

    # Convert ONNX to TensorRT
    trt_engine_file = tensorrt_model_dir / args.trt_engine_file_name
    trt_kwargs = get_trt_kwargs_dynamic_batch(idim=idim)
    convert_onnx_to_trt(
        trt_model=str(trt_engine_file),
        trt_kwargs=trt_kwargs,
        onnx_model=str(encoder_onnx_file),
        dtype=torch.float16,
    )
    logging.info("Done!")
if __name__ == "__main__":
    # Configure root logging once for the whole script: timestamp, level and
    # source location precede every message.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()