File size: 8,029 Bytes
98c8d47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Zengwei Yao)
# Copyright 2025 Nvidia Corp. (authors: Yuekai Zhang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script exports a pre-trained FireRedASR encoder model from PyTorch to
ONNX and TensorRT.
Usage:
python3 examples/export_encoder_tensorrt.py \
--model-dir /path/to/your/model_dir \
--tensorrt-model-dir ./tensorrt_models \
--trt-engine-file-name encoder.plan
"""
import argparse
import logging
from pathlib import Path
import torch
import tensorrt as trt
from fireredasr.models.fireredasr import load_fireredasr_aed_model
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by this export script.

    Returns:
      An ``argparse.ArgumentParser`` that prints default values in its help
      text and defines the options ``--model-dir``, ``--onnx-model-path``,
      ``--idim``, ``--tensorrt-model-dir``, ``--trt-engine-file-name`` and
      ``--opset-version``.
    """
    # One row per option: (flag, type, default, help). The rows are registered
    # in order, so the generated --help output lists them in this sequence.
    option_table = (
        (
            "--model-dir",
            str,
            None,
            "The model directory that contains model checkpoint.",
        ),
        (
            "--onnx-model-path",
            str,
            None,
            (
                "If specified, we will directly use this onnx model to generate "
                "the tensorrt engine"
            ),
        ),
        (
            "--idim",
            int,
            80,
            (
                "The input dimension of the model. This is required when "
                "--onnx-model-path is specified."
            ),
        ),
        (
            "--tensorrt-model-dir",
            str,
            "exp",
            "Directory to save the exported models.",
        ),
        (
            "--trt-engine-file-name",
            str,
            "encoder.plan",
            "The name of the TensorRT engine file.",
        ),
        (
            "--opset-version",
            int,
            17,
            "ONNX opset version.",
        ),
    )

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )
    return cli
def export_encoder_onnx(
    encoder: torch.nn.Module,
    filename: str,
    idim: int,
    opset_version: int = 17,
) -> None:
    """Trace the encoder with FP16 dummy inputs and write it out as ONNX.

    Args:
      encoder:
        The encoder module to export. ``encoder.half()`` is called on it, so
        the module passed in is converted to half precision.
        NOTE(review): presumably its forward takes (features, lengths) and
        returns three tensors, matching the input/output names used below --
        confirm against the model definition.
      filename:
        Path of the ONNX file to write.
      idim:
        Size of the last dimension of the dummy feature tensor.
      opset_version:
        ONNX opset version passed to ``torch.onnx.export``.
    """
    logging.info("Exporting encoder to ONNX")
    encoder.half()

    # Example inputs used only for tracing: one utterance of 400 frames
    # (a typical sequence length).
    trace_batch = 1
    trace_frames = 400
    dummy_feats = torch.randn(
        trace_batch, trace_frames, idim, dtype=torch.float16
    )
    dummy_lens = torch.full((trace_batch,), trace_frames, dtype=torch.int32)

    # Axes declared dynamic in the exported graph, keyed by tensor name.
    variable_axes = {
        "padded_input": {0: "batch_size", 1: "seq_len"},
        "input_lengths": {0: "batch_size"},
        "enc_output": {0: "batch_size", 1: "seq_len_out"},
        "output_lengths": {0: "batch_size"},
        "src_mask": {0: "batch_size", 2: "seq_len_out"},
    }

    torch.onnx.export(
        encoder,
        (dummy_feats, dummy_lens),
        filename,
        opset_version=opset_version,
        input_names=["padded_input", "input_lengths"],
        output_names=["enc_output", "output_lengths", "src_mask"],
        dynamic_axes=variable_axes,
    )
    logging.info(f"Exported encoder to {filename}")
def get_trt_kwargs_dynamic_batch(
    idim: int,
    min_batch_size: int = 1,
    opt_batch_size: int = 4,
    max_batch_size: int = 64,
    min_seq_len: int = 50,
    opt_seq_len: int = 400,
    max_seq_len: int = 3000,
) -> dict:
    """Get keyword arguments for TensorRT with dynamic batch size.

    The result describes the min / opt / max shapes of the two encoder
    inputs, ``padded_input`` with shape (batch, seq_len, idim) and
    ``input_lengths`` with shape (batch,).

    Args:
      idim:
        Size of the last dimension of ``padded_input``.
      min_batch_size:
        Smallest batch size covered by the shape range.
      opt_batch_size:
        Batch size used for the "opt" shape.
      max_batch_size:
        Largest batch size covered by the shape range.
      min_seq_len:
        Smallest sequence length covered by the shape range.
      opt_seq_len:
        Sequence length used for the "opt" shape.
      max_seq_len:
        Largest sequence length covered by the shape range.

    Returns:
      A dict with the keys ``min_shape``, ``opt_shape``, ``max_shape`` (each
      a list holding one shape tuple per input) and ``input_names`` (the
      input names, in the same order as the shapes).

    Raises:
      ValueError: If a min / opt / max triple is not in non-decreasing order.
    """
    # Reject inverted ranges here, where the offending values can be named,
    # instead of handing an inconsistent shape range to the engine builder.
    for label, low, opt, high in (
        ("batch size", min_batch_size, opt_batch_size, max_batch_size),
        ("sequence length", min_seq_len, opt_seq_len, max_seq_len),
    ):
        if not low <= opt <= high:
            raise ValueError(
                f"Invalid {label} range: expected min <= opt <= max, "
                f"got min={low}, opt={opt}, max={high}"
            )

    def shapes_for(batch_size: int, seq_len: int) -> list:
        # One shape per entry of input_names, in the same order.
        return [(batch_size, seq_len, idim), (batch_size,)]

    return {
        "min_shape": shapes_for(min_batch_size, min_seq_len),
        "opt_shape": shapes_for(opt_batch_size, opt_seq_len),
        "max_shape": shapes_for(max_batch_size, max_seq_len),
        "input_names": ["padded_input", "input_lengths"],
    }
def convert_onnx_to_trt(
    trt_model: str, trt_kwargs: dict, onnx_model: str, dtype: torch.dtype = torch.float16
) -> None:
    """Convert an ONNX model to a serialized TensorRT engine file.

    Args:
      trt_model:
        Path of the TensorRT engine file to write.
      trt_kwargs:
        Dict with the keys ``input_names``, ``min_shape``, ``opt_shape`` and
        ``max_shape``. The three shape lists are indexed in parallel with
        ``input_names`` and define the optimization profile.
      onnx_model:
        Path of the ONNX model to read.
      dtype:
        When equal to ``torch.float16`` the FP16 builder flag is enabled;
        any other value leaves the builder flags untouched.

    Raises:
      ValueError: If the ONNX model cannot be parsed.
      RuntimeError: If TensorRT fails to build the engine. Nothing is written
        to ``trt_model`` in that case.
    """
    logging.info("Converting ONNX to TensorRT engine...")
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()
    if dtype == torch.float16:
        config.set_flag(trt.BuilderFlag.FP16)

    with open(onnx_model, "rb") as f:
        onnx_bytes = f.read()
    if not parser.parse(onnx_bytes):
        # Report every parser error through logging, like the rest of the file.
        for index in range(parser.num_errors):
            logging.error("%s", parser.get_error(index))
        raise ValueError(f"Failed to parse {onnx_model}")

    # A single optimization profile covering the dynamic range of each input.
    profile = builder.create_optimization_profile()
    for i, name in enumerate(trt_kwargs["input_names"]):
        profile.set_shape(
            name,
            trt_kwargs["min_shape"][i],
            trt_kwargs["opt_shape"][i],
            trt_kwargs["max_shape"][i],
        )
    config.add_optimization_profile(profile)

    try:
        engine_bytes = builder.build_serialized_network(network, config)
    except Exception as e:
        logging.error(f"TensorRT engine build failed: {e}")
        # Propagate the failure so the caller does not report success
        # without an engine file having been produced.
        raise RuntimeError(
            f"TensorRT engine build failed for {onnx_model}"
        ) from e

    # A failed build is reported by a None result rather than an exception;
    # check before opening the output so no empty engine file is left behind.
    if engine_bytes is None:
        logging.error("TensorRT engine build failed: builder returned None")
        raise RuntimeError(f"TensorRT engine build failed for {onnx_model}")

    with open(trt_model, "wb") as f:
        f.write(engine_bytes)
    logging.info("Successfully converted ONNX to TensorRT.")
@torch.no_grad()
def main():
    """Run the export: obtain an encoder ONNX file, then build a TensorRT engine.

    The output directory given by ``--tensorrt-model-dir`` is created first.
    When ``--onnx-model-path`` is given, that file is used directly and
    ``--idim`` supplies the input dimension. Otherwise the checkpoint
    ``model.pth.tar`` under ``--model-dir`` is loaded, its encoder is exported
    to ``encoder.fp16.onnx`` in the output directory, and the input dimension
    is read from the arguments stored in the checkpoint.

    Raises:
      ValueError: If neither a usable ``--onnx-model-path`` / ``--idim`` pair
        nor ``--model-dir`` is provided.
      FileNotFoundError: If ``--onnx-model-path`` does not point to a file.
    """
    args = get_parser().parse_args()

    out_dir = Path(args.tensorrt_model_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if not args.onnx_model_path:
        # No ONNX model supplied: export one from the PyTorch checkpoint.
        if not args.model_dir:
            raise ValueError(
                "--model-dir is required if --onnx-model-path is not provided"
            )
        logging.info("Exporting ONNX model from PyTorch checkpoint")
        checkpoint_path = Path(args.model_dir) / "model.pth.tar"

        # The raw checkpoint is read only for the input dimension stored in
        # its saved arguments.
        checkpoint = torch.load(
            checkpoint_path, map_location="cpu", weights_only=False
        )
        idim = checkpoint["args"].idim

        # The full AED model has to be loaded to get the encoder with weights.
        aed_model = load_fireredasr_aed_model(str(checkpoint_path))
        encoder = aed_model.encoder
        encoder.eval()

        encoder_onnx_file = out_dir / "encoder.fp16.onnx"
        export_encoder_onnx(
            encoder=encoder,
            filename=str(encoder_onnx_file),
            idim=idim,
            opset_version=args.opset_version,
        )
    else:
        # Reuse an ONNX model that was exported earlier.
        logging.info(f"Using provided ONNX model: {args.onnx_model_path}")
        if not args.idim:
            raise ValueError("--idim is required when using --onnx-model-path")
        idim = args.idim
        encoder_onnx_file = Path(args.onnx_model_path)
        if not encoder_onnx_file.is_file():
            raise FileNotFoundError(f"ONNX model not found at {encoder_onnx_file}")

    # Build the TensorRT engine from the ONNX model.
    engine_path = out_dir / args.trt_engine_file_name
    convert_onnx_to_trt(
        trt_model=str(engine_path),
        trt_kwargs=get_trt_kwargs_dynamic_batch(idim=idim),
        onnx_model=str(encoder_onnx_file),
        dtype=torch.float16,
    )
    logging.info("Done!")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
|