yuekai committed
Commit 98c8d47 · verified · 1 parent: 1f74c53

Upload folder using huggingface_hub
convert_checkpoint.py ADDED
@@ -0,0 +1,231 @@
+ import argparse
+ import json
+ import os
+ import time
+
+ import torch
+ from safetensors.torch import save_file
+
+ import tensorrt_llm
+ from tensorrt_llm.functional import LayerNormPositionType, LayerNormType
+ from tensorrt_llm.models.convert_utils import weight_only_quantize_dict
+ from tensorrt_llm.quantization import QuantAlgo
+
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--model_path', type=str, required=True,
+                         help="Path to the FireRedASR model.pth.tar checkpoint.")
+     parser.add_argument('--output_dir', type=str, default='tllm_checkpoint',
+                         help='The path to save the TensorRT-LLM checkpoint')
+     parser.add_argument('--dtype', type=str, default='float16',
+                         choices=['float32', 'bfloat16', 'float16'])
+     parser.add_argument('--logits_dtype', type=str, default='float16',
+                         choices=['float16', 'float32'])
+     parser.add_argument(
+         '--use_weight_only',
+         default=False,
+         action="store_true",
+         help='Quantize weights for the various GEMMs to INT4/INT8. '
+         'See --weight_only_precision to set the precision.')
+     parser.add_argument(
+         '--weight_only_precision',
+         const='int8',
+         type=str,
+         nargs='?',
+         default='int8',
+         choices=['int8', 'int4'],
+         help='Define the precision for the weights when using weight-only '
+         'quantization. You must also pass --use_weight_only for this '
+         'argument to take effect.')
+     return parser.parse_args()
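+
+ # Example invocation (paths are illustrative; they mirror the commented
+ # block in export_tensorrt.sh):
+ #   python3 convert_checkpoint.py \
+ #       --model_path pretrained_models/FireRedASR-AED-L/model.pth.tar \
+ #       --output_dir tllm_checkpoint_float16 \
+ #       --dtype float16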
+
+
+ def get_decoder_config(model_args, dtype: str, logits_dtype: str, quant_algo: QuantAlgo) -> dict:
+     return {
+         'architecture': "DecoderModel",
+         'dtype': dtype,
+         'logits_dtype': logits_dtype,
+         'num_hidden_layers': model_args.n_layers_dec,
+         'num_attention_heads': model_args.n_head,
+         'hidden_size': model_args.d_model,
+         'norm_epsilon': 1e-5,
+         'vocab_size': model_args.odim,
+         'hidden_act': "gelu",
+         'use_parallel_embedding': False,
+         'embedding_sharding_dim': 0,
+         'max_position_embeddings': model_args.pe_maxlen,
+         'use_prompt_tuning': False,
+         'head_size': model_args.d_model // model_args.n_head,
+         'has_position_embedding': True,
+         'layernorm_type': LayerNormType.LayerNorm,
+         'has_attention_qkvo_bias': True,
+         'has_mlp_bias': True,
+         'has_model_final_layernorm': True,
+         'has_embedding_layernorm': False,
+         'has_embedding_scale': True,  # FireRedASR scales the embedding
+         'ffn_hidden_size': 4 * model_args.d_model,
+         'q_scaling': 1.0,
+         'layernorm_position': LayerNormPositionType.pre_layernorm,
+         'relative_attention': False,
+         'max_distance': 0,
+         'num_buckets': 0,
+         'model_type': 'whisper',  # Align with the Whisper decoder architecture in TRT-LLM
+         'rescale_before_lm_head': False,
+         'encoder_hidden_size': model_args.d_model,
+         'encoder_num_heads': model_args.n_head,
+         'encoder_head_size': None,
+         'skip_cross_kv': False,
+         'quantization': {
+             'quant_algo': quant_algo
+         },
+     }
+
+
+ def remap_state_dict(original_state_dict):
+     new_state_dict = {}
+     for key, value in original_state_dict.items():
+         if key.startswith("decoder."):
+             new_key = key
+             # Top-level decoder module renames
+             new_key = new_key.replace("decoder.tgt_word_emb.", "decoder.token_embedding.")
+             new_key = new_key.replace("decoder.layer_stack.", "decoder.blocks.")
+             new_key = new_key.replace("decoder.layer_norm_out.", "decoder.ln.")
+             new_key = new_key.replace("decoder.tgt_word_prj.", "decoder.output_projection.")
+
+             # ResidualAttentionBlock internal layer renames
+             # (cross_attn submodules already use the target name)
+             new_key = new_key.replace(".self_attn_norm.", ".attn_ln.")
+             new_key = new_key.replace(".self_attn.", ".attn.")
+             new_key = new_key.replace(".cross_attn_norm.", ".cross_attn_ln.")
+             new_key = new_key.replace(".mlp_norm.", ".mlp_ln.")
+
+             # Inlined PositionwiseFeedForward renames
+             new_key = new_key.replace(".mlp.w_1.", ".mlp.0.")
+             new_key = new_key.replace(".mlp.w_2.", ".mlp.2.")
+
+             # MultiHeadAttention submodule renames
+             new_key = new_key.replace(".w_qs.", ".query.")
+             new_key = new_key.replace(".w_ks.", ".key.")
+             new_key = new_key.replace(".w_vs.", ".value.")
+             new_key = new_key.replace(".fc.", ".out.")
+
+             new_state_dict[new_key] = value
+
+     # Manually handle sinusoidal positional encoding -> learnable embedding
+     if "decoder.positional_encoding.pe" in original_state_dict:
+         new_state_dict["decoder.positional_embedding"] = original_state_dict["decoder.positional_encoding.pe"].squeeze(0)
+
+     return new_state_dict
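+
+ # Illustrative examples of the mapping above (assuming FireRedASR's AED
+ # checkpoint layout):
+ #   decoder.layer_stack.0.self_attn.w_qs.weight -> decoder.blocks.0.attn.query.weight
+ #   decoder.tgt_word_prj.weight                 -> decoder.output_projection.weight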
+
+
+ def convert_firered_decoder(model_args, model_params, quant_algo: QuantAlgo = None):
+     weights = {}
+
+     # The original model ties the embedding and output-projection weights;
+     # TRT-LLM's DecoderModel expects a separate lm_head.weight.
+     weights['transformer.vocab_embedding.weight'] = model_params['decoder.token_embedding.weight']
+     weights['lm_head.weight'] = model_params['decoder.output_projection.weight']
+     weights['transformer.position_embedding.weight'] = model_params['decoder.positional_embedding']
+
+     for i in range(model_args.n_layers_dec):
+         trtllm_layer_name_prefix = f'transformer.layers.{i}'
+
+         # Self attention
+         q_w = model_params[f'decoder.blocks.{i}.attn.query.weight']
+         k_w = model_params[f'decoder.blocks.{i}.attn.key.weight']
+         v_w = model_params[f'decoder.blocks.{i}.attn.value.weight']
+         weights[f'{trtllm_layer_name_prefix}.self_attention.qkv.weight'] = torch.cat([q_w, k_w, v_w], dim=0)
+
+         q_b = model_params[f'decoder.blocks.{i}.attn.query.bias']
+         # The key projection has no bias in Whisper's MultiHeadAttention
+         k_b = torch.zeros_like(q_b)
+         v_b = model_params[f'decoder.blocks.{i}.attn.value.bias']
+         weights[f'{trtllm_layer_name_prefix}.self_attention.qkv.bias'] = torch.cat([q_b, k_b, v_b], dim=0)
+
+         weights[f'{trtllm_layer_name_prefix}.self_attention.dense.weight'] = model_params[f'decoder.blocks.{i}.attn.out.weight']
+         weights[f'{trtllm_layer_name_prefix}.self_attention.dense.bias'] = model_params[f'decoder.blocks.{i}.attn.out.bias']
+         weights[f'{trtllm_layer_name_prefix}.self_attention_layernorm.weight'] = model_params[f'decoder.blocks.{i}.attn_ln.weight']
+         weights[f'{trtllm_layer_name_prefix}.self_attention_layernorm.bias'] = model_params[f'decoder.blocks.{i}.attn_ln.bias']
+
+         # Cross attention
+         q_w = model_params[f'decoder.blocks.{i}.cross_attn.query.weight']
+         k_w = model_params[f'decoder.blocks.{i}.cross_attn.key.weight']
+         v_w = model_params[f'decoder.blocks.{i}.cross_attn.value.weight']
+         weights[f'{trtllm_layer_name_prefix}.cross_attention.qkv.weight'] = torch.cat([q_w, k_w, v_w], dim=0)
+
+         q_b = model_params[f'decoder.blocks.{i}.cross_attn.query.bias']
+         # The key projection has no bias in Whisper's MultiHeadAttention
+         k_b = torch.zeros_like(q_b)
+         v_b = model_params[f'decoder.blocks.{i}.cross_attn.value.bias']
+         weights[f'{trtllm_layer_name_prefix}.cross_attention.qkv.bias'] = torch.cat([q_b, k_b, v_b], dim=0)
+
+         weights[f'{trtllm_layer_name_prefix}.cross_attention.dense.weight'] = model_params[f'decoder.blocks.{i}.cross_attn.out.weight']
+         weights[f'{trtllm_layer_name_prefix}.cross_attention.dense.bias'] = model_params[f'decoder.blocks.{i}.cross_attn.out.bias']
+         weights[f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight'] = model_params[f'decoder.blocks.{i}.cross_attn_ln.weight']
+         weights[f'{trtllm_layer_name_prefix}.cross_attention_layernorm.bias'] = model_params[f'decoder.blocks.{i}.cross_attn_ln.bias']
+
+         # MLP
+         weights[f'{trtllm_layer_name_prefix}.mlp.fc.weight'] = model_params[f'decoder.blocks.{i}.mlp.0.weight']
+         weights[f'{trtllm_layer_name_prefix}.mlp.fc.bias'] = model_params[f'decoder.blocks.{i}.mlp.0.bias']
+         weights[f'{trtllm_layer_name_prefix}.mlp.proj.weight'] = model_params[f'decoder.blocks.{i}.mlp.2.weight']
+         weights[f'{trtllm_layer_name_prefix}.mlp.proj.bias'] = model_params[f'decoder.blocks.{i}.mlp.2.bias']
+         weights[f'{trtllm_layer_name_prefix}.mlp_layernorm.weight'] = model_params[f'decoder.blocks.{i}.mlp_ln.weight']
+         weights[f'{trtllm_layer_name_prefix}.mlp_layernorm.bias'] = model_params[f'decoder.blocks.{i}.mlp_ln.bias']
+
+     weights['transformer.ln_f.weight'] = model_params['decoder.ln.weight']
+     weights['transformer.ln_f.bias'] = model_params['decoder.ln.bias']
+
+     if quant_algo is not None:
+         return weight_only_quantize_dict(weights, quant_algo=quant_algo)
+     return weights
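+
+ # Sketch of a post-conversion sanity check (shape assumes the fused QKV
+ # layout built above, i.e. (3 * d_model, d_model) per layer):
+ #   qkv = weights['transformer.layers.0.self_attention.qkv.weight']
+ #   assert qkv.shape == (3 * model_args.d_model, model_args.d_model)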
+
+
+ if __name__ == '__main__':
+     print(f"Using TensorRT-LLM version: {tensorrt_llm.__version__}")
+     args = parse_arguments()
+     tik = time.time()
+
+     if not os.path.exists(args.output_dir):
+         os.makedirs(args.output_dir)
+
+     quant_algo = None
+     if args.use_weight_only and args.weight_only_precision == 'int8':
+         quant_algo = QuantAlgo.W8A16
+     elif args.use_weight_only and args.weight_only_precision == 'int4':
+         quant_algo = QuantAlgo.W4A16
+
+     # Load the original checkpoint
+     package = torch.load(args.model_path, map_location='cpu', weights_only=False)
+     model_args = package["args"]
+     original_state_dict = package["model_state_dict"]
+     print(f"Successfully loaded checkpoint from {args.model_path}")
+     print("Original model args:", model_args)
+
+     # Remap state dict keys for Whisper compatibility
+     remapped_state_dict = remap_state_dict(original_state_dict)
+
+     # Cast all tensors to the target dtype
+     tensor_dtype = getattr(torch, args.dtype)
+     for key, value in remapped_state_dict.items():
+         remapped_state_dict[key] = value.to(tensor_dtype)
+
+     # Generate the config and convert the weights
+     print("Converting decoder checkpoint...")
+     decoder_config = get_decoder_config(model_args, args.dtype, args.logits_dtype, quant_algo)
+     decoder_weights = convert_firered_decoder(model_args, remapped_state_dict, quant_algo)
+
+     # Save the decoder config and weights
+     decoder_save_dir = os.path.join(args.output_dir, "decoder")
+     if not os.path.exists(decoder_save_dir):
+         os.makedirs(decoder_save_dir)
+
+     with open(os.path.join(decoder_save_dir, 'config.json'), 'w') as f:
+         json.dump(decoder_config, f, indent=4)
+
+     save_file(decoder_weights, os.path.join(decoder_save_dir, 'rank0.safetensors'))
+
+     tok = time.time()
+     t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
+     print(f'Checkpoint successfully converted and saved to {args.output_dir}.')
+     print(f'Total time of converting checkpoints: {t}')
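+
+ # Quick post-hoc check (a sketch, assuming the default output layout):
+ #   from safetensors import safe_open
+ #   with safe_open(os.path.join(decoder_save_dir, 'rank0.safetensors'),
+ #                  framework='pt') as f:
+ #       print(f"Saved {len(f.keys())} tensors")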
encoder.fp16.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:979d55f4cecfb651720b037802649f39acb6c235f048c62f7ddb8a1a30bebda8
+ size 1447173731
export_encoder_tensorrt.py ADDED
@@ -0,0 +1,257 @@
+ #!/usr/bin/env python3
+ # Copyright 2025 Xiaomi Corp. (authors: Zengwei Yao)
+ # Copyright 2025 Nvidia Corp. (authors: Yuekai Zhang)
+ #
+ # See ../../../../LICENSE for clarification regarding multiple authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ This script exports a pre-trained FireRedASR encoder model from PyTorch to
+ ONNX and TensorRT.
+
+ Usage:
+
+     python3 examples/export_encoder_tensorrt.py \
+         --model-dir /path/to/your/model_dir \
+         --tensorrt-model-dir ./tensorrt_models \
+         --trt-engine-file-name encoder.plan
+ """
+
+ import argparse
+ import logging
+ from pathlib import Path
+
+ import torch
+ import tensorrt as trt
+
+ from fireredasr.models.fireredasr import load_fireredasr_aed_model
+
+
+ def get_parser() -> argparse.ArgumentParser:
+     """Get the command-line argument parser."""
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     parser.add_argument(
+         "--model-dir",
+         type=str,
+         default=None,
+         help="The model directory that contains the model checkpoint.",
+     )
+
+     parser.add_argument(
+         "--onnx-model-path",
+         type=str,
+         default=None,
+         help="If specified, directly use this ONNX model to generate "
+         "the TensorRT engine.",
+     )
+
+     parser.add_argument(
+         "--idim",
+         type=int,
+         default=80,
+         help="The input feature dimension of the model. Only used when "
+         "--onnx-model-path is specified.",
+     )
+
+     parser.add_argument(
+         "--tensorrt-model-dir",
+         type=str,
+         default="exp",
+         help="Directory to save the exported models.",
+     )
+
+     parser.add_argument(
+         "--trt-engine-file-name",
+         type=str,
+         default="encoder.plan",
+         help="The name of the TensorRT engine file.",
+     )
+
+     parser.add_argument(
+         "--opset-version",
+         type=int,
+         default=17,
+         help="ONNX opset version.",
+     )
+
+     return parser
+
+
+ def export_encoder_onnx(
+     encoder: torch.nn.Module,
+     filename: str,
+     idim: int,
+     opset_version: int = 17,
+ ) -> None:
+     """Export the conformer encoder model to ONNX format."""
+     logging.info("Exporting encoder to ONNX")
+     encoder.half()
+
+     # Create dummy inputs
+     seq_len = 400  # A typical sequence length
+     batch_size = 1
+     padded_input = torch.randn(batch_size, seq_len, idim, dtype=torch.float16)
+     input_lengths = torch.tensor([seq_len] * batch_size, dtype=torch.int32)
+
+     # Export
+     torch.onnx.export(
+         encoder,
+         (padded_input, input_lengths),
+         filename,
+         opset_version=opset_version,
+         input_names=["padded_input", "input_lengths"],
+         output_names=["enc_output", "output_lengths", "src_mask"],
+         dynamic_axes={
+             "padded_input": {0: "batch_size", 1: "seq_len"},
+             "input_lengths": {0: "batch_size"},
+             "enc_output": {0: "batch_size", 1: "seq_len_out"},
+             "output_lengths": {0: "batch_size"},
+             "src_mask": {0: "batch_size", 2: "seq_len_out"},
+         },
+     )
+     logging.info(f"Exported encoder to {filename}")
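+
+ # Optional sanity check of the exported graph (a sketch; assumes
+ # onnxruntime-gpu is installed, since FP16 inputs generally need the CUDA EP):
+ #   import onnxruntime as ort
+ #   sess = ort.InferenceSession(filename, providers=["CUDAExecutionProvider"])
+ #   outs = sess.run(None, {"padded_input": padded_input.numpy(),
+ #                          "input_lengths": input_lengths.numpy()})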
+
+
+ def get_trt_kwargs_dynamic_batch(
+     idim: int,
+     min_batch_size: int = 1,
+     opt_batch_size: int = 4,
+     max_batch_size: int = 64,
+ ):
+     """Get keyword arguments for TensorRT with dynamic batch size."""
+     min_seq_len = 50
+     opt_seq_len = 400
+     max_seq_len = 3000
+
+     min_shape = [(min_batch_size, min_seq_len, idim), (min_batch_size,)]
+     opt_shape = [(opt_batch_size, opt_seq_len, idim), (opt_batch_size,)]
+     max_shape = [(max_batch_size, max_seq_len, idim), (max_batch_size,)]
+     input_names = ["padded_input", "input_lengths"]
+     return {
+         "min_shape": min_shape,
+         "opt_shape": opt_shape,
+         "max_shape": max_shape,
+         "input_names": input_names,
+     }
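+
+ # With the default idim=80, the "padded_input" profile spans
+ # (1, 50, 80) -> (4, 400, 80) -> (64, 3000, 80); TensorRT tunes kernels for
+ # the opt shape while remaining valid anywhere in the min/max range.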
+
+
+ def convert_onnx_to_trt(
+     trt_model: str, trt_kwargs: dict, onnx_model: str, dtype: torch.dtype = torch.float16
+ ) -> None:
+     """Convert an ONNX model to a TensorRT engine."""
+     logging.info("Converting ONNX to TensorRT engine...")
+     network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+     logger = trt.Logger(trt.Logger.INFO)
+     builder = trt.Builder(logger)
+     network = builder.create_network(network_flags)
+     parser = trt.OnnxParser(network, logger)
+     config = builder.create_builder_config()
+
+     if dtype == torch.float16:
+         config.set_flag(trt.BuilderFlag.FP16)
+
+     profile = builder.create_optimization_profile()
+
+     with open(onnx_model, "rb") as f:
+         if not parser.parse(f.read()):
+             for error in range(parser.num_errors):
+                 print(parser.get_error(error))
+             raise ValueError(f"Failed to parse {onnx_model}")
+
+     for i, name in enumerate(trt_kwargs["input_names"]):
+         profile.set_shape(
+             name,
+             trt_kwargs["min_shape"][i],
+             trt_kwargs["opt_shape"][i],
+             trt_kwargs["max_shape"][i],
+         )
+
+     config.add_optimization_profile(profile)
+
+     try:
+         engine_bytes = builder.build_serialized_network(network, config)
+     except Exception as e:
+         logging.error(f"TensorRT engine build failed: {e}")
+         return
+     # build_serialized_network returns None (rather than raising) on failure
+     if engine_bytes is None:
+         logging.error("TensorRT engine build failed: no engine produced.")
+         return
+
+     with open(trt_model, "wb") as f:
+         f.write(engine_bytes)
+     logging.info("Successfully converted ONNX to TensorRT.")
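+
+ # Loading the serialized engine back later (sketch):
+ #   runtime = trt.Runtime(logger)
+ #   with open(trt_model, "rb") as f:
+ #       engine = runtime.deserialize_cuda_engine(f.read())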
+
+
+ @torch.no_grad()
+ def main():
+     """Main function to export the model."""
+     parser = get_parser()
+     args = parser.parse_args()
+
+     tensorrt_model_dir = Path(args.tensorrt_model_dir)
+     tensorrt_model_dir.mkdir(parents=True, exist_ok=True)
+
+     if args.onnx_model_path:
+         logging.info(f"Using provided ONNX model: {args.onnx_model_path}")
+         if not args.idim:
+             raise ValueError("--idim is required when using --onnx-model-path")
+         idim = args.idim
+         encoder_onnx_file = Path(args.onnx_model_path)
+         if not encoder_onnx_file.is_file():
+             raise FileNotFoundError(f"ONNX model not found at {encoder_onnx_file}")
+     else:
+         if not args.model_dir:
+             raise ValueError(
+                 "--model-dir is required if --onnx-model-path is not provided"
+             )
+
+         logging.info("Exporting ONNX model from PyTorch checkpoint")
+         model_dir = Path(args.model_dir)
+         model_path = model_dir / "model.pth.tar"
+
+         # Load the checkpoint args to get the input dimension
+         package = torch.load(model_path, map_location="cpu", weights_only=False)
+         model_args = package["args"]
+         idim = model_args.idim
+         # We have to load the full AED model to get the encoder with weights
+         model = load_fireredasr_aed_model(str(model_path))
+         encoder = model.encoder
+         encoder.eval()
+
+         # Export ONNX
+         encoder_onnx_file = tensorrt_model_dir / "encoder.fp16.onnx"
+         export_encoder_onnx(
+             encoder=encoder,
+             filename=str(encoder_onnx_file),
+             idim=idim,
+             opset_version=args.opset_version,
+         )
+
+     # Convert ONNX to TensorRT
+     trt_engine_file = tensorrt_model_dir / args.trt_engine_file_name
+     trt_kwargs = get_trt_kwargs_dynamic_batch(idim=idim)
+     convert_onnx_to_trt(
+         trt_model=str(trt_engine_file),
+         trt_kwargs=trt_kwargs,
+         onnx_model=str(encoder_onnx_file),
+         dtype=torch.float16,
+     )
+
+     logging.info("Done!")
+
+
+ if __name__ == "__main__":
+     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+     logging.basicConfig(format=formatter, level=logging.INFO)
+     main()
export_tensorrt.sh ADDED
@@ -0,0 +1,54 @@
+
+ export PATH=$PWD/fireredasr/:$PWD/fireredasr/utils/:$PATH
+ export PYTHONPATH=$PWD/:$PYTHONPATH
+
+ # model_path=pretrained_models/FireRedASR-AED-L
+ # python3 export_encoder_tensorrt.py \
+ #     --model-dir $model_path \
+ #     --tensorrt-model-dir $TRT_ENGINE_OUTPUT_DIR \
+ #     --trt-engine-file-name encoder.plan
+
+ TRT_ENGINE_OUTPUT_DIR=./FireRedASR-AED-L-TensorRT
+ python3 export_encoder_tensorrt.py \
+     --onnx-model-path $TRT_ENGINE_OUTPUT_DIR/encoder.fp16.onnx \
+     --tensorrt-model-dir $TRT_ENGINE_OUTPUT_DIR \
+     --trt-engine-file-name encoder.plan
+
+
+ INFERENCE_PRECISION=float16
+ MAX_BEAM_WIDTH=4
+ MAX_BATCH_SIZE=64
+ checkpoint_dir=$TRT_ENGINE_OUTPUT_DIR/tllm_checkpoint_float16
+ output_dir=$TRT_ENGINE_OUTPUT_DIR/trt_engine_${INFERENCE_PRECISION}
+
+ # model_path=pretrained_models/FireRedASR-AED-L/model.pth.tar
+ # python3 convert_checkpoint.py \
+ #     --dtype ${INFERENCE_PRECISION} \
+ #     --model_path $model_path \
+ #     --output_dir $checkpoint_dir
+
+ trtllm-build --checkpoint_dir ${checkpoint_dir}/decoder \
+     --output_dir ${output_dir}/decoder \
+     --moe_plugin disable \
+     --max_beam_width ${MAX_BEAM_WIDTH} \
+     --max_batch_size ${MAX_BATCH_SIZE} \
+     --max_seq_len 512 \
+     --max_input_len 4 \
+     --max_encoder_input_len 1024 \
+     --gemm_plugin ${INFERENCE_PRECISION} \
+     --remove_input_padding disable \
+     --paged_kv_cache disable \
+     --gpt_attention_plugin ${INFERENCE_PRECISION}
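+
+ # Notes (our reading of these limits, not authoritative): --max_input_len 4
+ # bounds the short decoder start prompt, --max_seq_len 512 caps the total
+ # decoded length, and --max_encoder_input_len 1024 caps the number of
+ # encoder frames fed to cross attention.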
+
+
+ # Expected output layout:
+ # FireRedASR-AED-L-TensorRT/
+ # ├── encoder.fp16.onnx
+ # ├── encoder.plan
+ # ├── tllm_checkpoint_float16
+ # │   └── decoder
+ # │       ├── config.json
+ # │       └── rank0.safetensors
+ # └── trt_engine_float16
+ #     └── decoder
+ #         ├── config.json
+ #         └── rank0.engine
tllm_checkpoint_float16/decoder/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+     "architecture": "DecoderModel",
+     "dtype": "float16",
+     "logits_dtype": "float16",
+     "num_hidden_layers": 16,
+     "num_attention_heads": 20,
+     "hidden_size": 1280,
+     "norm_epsilon": 1e-05,
+     "vocab_size": 7832,
+     "hidden_act": "gelu",
+     "use_parallel_embedding": false,
+     "embedding_sharding_dim": 0,
+     "max_position_embeddings": 5000,
+     "use_prompt_tuning": false,
+     "head_size": 64,
+     "has_position_embedding": true,
+     "layernorm_type": 0,
+     "has_attention_qkvo_bias": true,
+     "has_mlp_bias": true,
+     "has_model_final_layernorm": true,
+     "has_embedding_layernorm": false,
+     "has_embedding_scale": true,
+     "ffn_hidden_size": 5120,
+     "q_scaling": 1.0,
+     "layernorm_position": 0,
+     "relative_attention": false,
+     "max_distance": 0,
+     "num_buckets": 0,
+     "model_type": "whisper",
+     "rescale_before_lm_head": false,
+     "encoder_hidden_size": 1280,
+     "encoder_num_heads": 20,
+     "encoder_head_size": null,
+     "skip_cross_kv": false,
+     "quantization": {
+         "quant_algo": null
+     }
+ }
tllm_checkpoint_float16/decoder/rank0.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fae4a3ce0ab15552d307ef960a579c25f479d490b65959cf4189e7a723463037
+ size 892578184