Upload folder using huggingface_hub
- convert_checkpoint.py +231 -0
- encoder.fp16.onnx +3 -0
- export_encoder_tensorrt.py +257 -0
- export_tensorrt.sh +54 -0
- tllm_checkpoint_float16/decoder/config.json +38 -0
- tllm_checkpoint_float16/decoder/rank0.safetensors +3 -0
convert_checkpoint.py
ADDED
@@ -0,0 +1,231 @@
import argparse
import json
import os
import time

import torch
from safetensors.torch import save_file

import tensorrt_llm
from tensorrt_llm.functional import LayerNormPositionType, LayerNormType
from tensorrt_llm.models.convert_utils import weight_only_quantize_dict
from tensorrt_llm.quantization import QuantAlgo


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True,
                        help="Path to the FireRedASR model.pth.tar checkpoint.")
    parser.add_argument('--output_dir', type=str, default='tllm_checkpoint',
                        help='The path to save the TensorRT-LLM checkpoint.')
    parser.add_argument('--dtype', type=str, default='float16',
                        choices=['float32', 'bfloat16', 'float16'])
    parser.add_argument('--logits_dtype', type=str, default='float16',
                        choices=['float16', 'float32'])
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision.')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help='Define the precision for the weights when using weight-only quantization. '
        'You must also pass --use_weight_only for this argument to have an effect.')
    return parser.parse_args()


def get_decoder_config(model_args, dtype: str, logits_dtype: str, quant_algo: QuantAlgo) -> dict:
    return {
        'architecture': "DecoderModel",
        'dtype': dtype,
        'logits_dtype': logits_dtype,
        'num_hidden_layers': model_args.n_layers_dec,
        'num_attention_heads': model_args.n_head,
        'hidden_size': model_args.d_model,
        'norm_epsilon': 1e-5,
        'vocab_size': model_args.odim,
        'hidden_act': "gelu",
        'use_parallel_embedding': False,
        'embedding_sharding_dim': 0,
        'max_position_embeddings': model_args.pe_maxlen,
        'use_prompt_tuning': False,
        'head_size': model_args.d_model // model_args.n_head,
        'has_position_embedding': True,
        'layernorm_type': LayerNormType.LayerNorm,
        'has_attention_qkvo_bias': True,
        'has_mlp_bias': True,
        'has_model_final_layernorm': True,
        'has_embedding_layernorm': False,
        'has_embedding_scale': True,  # FireRedASR scales the embedding
        'ffn_hidden_size': 4 * model_args.d_model,
        'q_scaling': 1.0,
        'layernorm_position': LayerNormPositionType.pre_layernorm,
        'relative_attention': False,
        'max_distance': 0,
        'num_buckets': 0,
        'model_type': 'whisper',  # To align with the Whisper decoder architecture in TRT-LLM
        'rescale_before_lm_head': False,
        'encoder_hidden_size': model_args.d_model,
        'encoder_num_heads': model_args.n_head,
        'encoder_head_size': None,
        'skip_cross_kv': False,
        'quantization': {
            'quant_algo': quant_algo
        },
    }


def remap_state_dict(original_state_dict):
    new_state_dict = {}
    for key, value in original_state_dict.items():
        if key.startswith("decoder."):
            new_key = key
            # Top-level decoder module renames
            new_key = new_key.replace("decoder.tgt_word_emb.", "decoder.token_embedding.")
            new_key = new_key.replace("decoder.layer_stack.", "decoder.blocks.")
            new_key = new_key.replace("decoder.layer_norm_out.", "decoder.ln.")
            new_key = new_key.replace("decoder.tgt_word_prj.", "decoder.output_projection.")

            # ResidualAttentionBlock internal layer renames
            new_key = new_key.replace(".self_attn_norm.", ".attn_ln.")
            new_key = new_key.replace(".self_attn.", ".attn.")
            new_key = new_key.replace(".cross_attn_norm.", ".cross_attn_ln.")
            new_key = new_key.replace(".mlp_norm.", ".mlp_ln.")

            # Inlined PositionwiseFeedForward renames
            new_key = new_key.replace(".mlp.w_1.", ".mlp.0.")
            new_key = new_key.replace(".mlp.w_2.", ".mlp.2.")

            # MultiHeadAttention submodule renames
            new_key = new_key.replace(".w_qs.", ".query.")
            new_key = new_key.replace(".w_ks.", ".key.")
            new_key = new_key.replace(".w_vs.", ".value.")
            new_key = new_key.replace(".fc.", ".out.")

            new_state_dict[new_key] = value

    # Manually handle sinusoidal positional encoding -> learnable embedding
    if "decoder.positional_encoding.pe" in original_state_dict:
        new_state_dict["decoder.positional_embedding"] = original_state_dict["decoder.positional_encoding.pe"].squeeze(0)

    return new_state_dict


def convert_firered_decoder(model_args, model_params, quant_algo: str = None):
    weights = {}

    # The original model shares embedding and projection weights.
    # TRT-LLM's DecoderModel expects a separate lm_head.weight.
    weights['transformer.vocab_embedding.weight'] = model_params['decoder.token_embedding.weight']
    weights['lm_head.weight'] = model_params['decoder.output_projection.weight']
    weights['transformer.position_embedding.weight'] = model_params['decoder.positional_embedding']

    for i in range(model_args.n_layers_dec):
        trtllm_layer_name_prefix = f'transformer.layers.{i}'

        # Self attention
        q_w = model_params[f'decoder.blocks.{i}.attn.query.weight']
        k_w = model_params[f'decoder.blocks.{i}.attn.key.weight']
        v_w = model_params[f'decoder.blocks.{i}.attn.value.weight']
        weights[f'{trtllm_layer_name_prefix}.self_attention.qkv.weight'] = torch.cat([q_w, k_w, v_w], dim=0)

        q_b = model_params[f'decoder.blocks.{i}.attn.query.bias']
        # The key projection has no bias in Whisper's MultiHeadAttention
        k_b = torch.zeros_like(q_b)
        v_b = model_params[f'decoder.blocks.{i}.attn.value.bias']
        weights[f'{trtllm_layer_name_prefix}.self_attention.qkv.bias'] = torch.cat([q_b, k_b, v_b], dim=0)

        weights[f'{trtllm_layer_name_prefix}.self_attention.dense.weight'] = model_params[f'decoder.blocks.{i}.attn.out.weight']
        weights[f'{trtllm_layer_name_prefix}.self_attention.dense.bias'] = model_params[f'decoder.blocks.{i}.attn.out.bias']
        weights[f'{trtllm_layer_name_prefix}.self_attention_layernorm.weight'] = model_params[f'decoder.blocks.{i}.attn_ln.weight']
        weights[f'{trtllm_layer_name_prefix}.self_attention_layernorm.bias'] = model_params[f'decoder.blocks.{i}.attn_ln.bias']

        # Cross attention
        q_w = model_params[f'decoder.blocks.{i}.cross_attn.query.weight']
        k_w = model_params[f'decoder.blocks.{i}.cross_attn.key.weight']
        v_w = model_params[f'decoder.blocks.{i}.cross_attn.value.weight']
        weights[f'{trtllm_layer_name_prefix}.cross_attention.qkv.weight'] = torch.cat([q_w, k_w, v_w], dim=0)

        q_b = model_params[f'decoder.blocks.{i}.cross_attn.query.bias']
        # The key projection has no bias in Whisper's MultiHeadAttention
        k_b = torch.zeros_like(q_b)
        v_b = model_params[f'decoder.blocks.{i}.cross_attn.value.bias']
        weights[f'{trtllm_layer_name_prefix}.cross_attention.qkv.bias'] = torch.cat([q_b, k_b, v_b], dim=0)

        weights[f'{trtllm_layer_name_prefix}.cross_attention.dense.weight'] = model_params[f'decoder.blocks.{i}.cross_attn.out.weight']
        weights[f'{trtllm_layer_name_prefix}.cross_attention.dense.bias'] = model_params[f'decoder.blocks.{i}.cross_attn.out.bias']
        weights[f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight'] = model_params[f'decoder.blocks.{i}.cross_attn_ln.weight']
        weights[f'{trtllm_layer_name_prefix}.cross_attention_layernorm.bias'] = model_params[f'decoder.blocks.{i}.cross_attn_ln.bias']

        # MLP
        weights[f'{trtllm_layer_name_prefix}.mlp.fc.weight'] = model_params[f'decoder.blocks.{i}.mlp.0.weight']
        weights[f'{trtllm_layer_name_prefix}.mlp.fc.bias'] = model_params[f'decoder.blocks.{i}.mlp.0.bias']
        weights[f'{trtllm_layer_name_prefix}.mlp.proj.weight'] = model_params[f'decoder.blocks.{i}.mlp.2.weight']
        weights[f'{trtllm_layer_name_prefix}.mlp.proj.bias'] = model_params[f'decoder.blocks.{i}.mlp.2.bias']
        weights[f'{trtllm_layer_name_prefix}.mlp_layernorm.weight'] = model_params[f'decoder.blocks.{i}.mlp_ln.weight']
        weights[f'{trtllm_layer_name_prefix}.mlp_layernorm.bias'] = model_params[f'decoder.blocks.{i}.mlp_ln.bias']

    weights['transformer.ln_f.weight'] = model_params['decoder.ln.weight']
    weights['transformer.ln_f.bias'] = model_params['decoder.ln.bias']

    if quant_algo is not None:
        return weight_only_quantize_dict(weights, quant_algo=quant_algo)
    return weights


if __name__ == '__main__':
    print(f"Using TensorRT-LLM version: {tensorrt_llm.__version__}")
    args = parse_arguments()
    tik = time.time()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    quant_algo = None
    if args.use_weight_only and args.weight_only_precision == 'int8':
        quant_algo = QuantAlgo.W8A16
    elif args.use_weight_only and args.weight_only_precision == 'int4':
        quant_algo = QuantAlgo.W4A16

    # Load the original checkpoint
    package = torch.load(args.model_path, map_location='cpu', weights_only=False)
    model_args = package["args"]
    original_state_dict = package["model_state_dict"]
    print(f"Successfully loaded checkpoint from {args.model_path}")
    print("Original model args:", model_args)

    # Remap state dict keys for Whisper compatibility
    remapped_state_dict = remap_state_dict(original_state_dict)

    # Cast tensors to the target dtype
    tensor_dtype = getattr(torch, args.dtype)
    for key, value in remapped_state_dict.items():
        remapped_state_dict[key] = value.to(tensor_dtype)

    # Generate config and convert weights
    print("Converting decoder checkpoint...")
    decoder_config = get_decoder_config(model_args, args.dtype, args.logits_dtype, quant_algo)
    decoder_weights = convert_firered_decoder(model_args, remapped_state_dict, quant_algo)

    # Save the decoder config and weights
    decoder_save_dir = os.path.join(args.output_dir, "decoder")
    if not os.path.exists(decoder_save_dir):
        os.makedirs(decoder_save_dir)

    with open(os.path.join(decoder_save_dir, 'config.json'), 'w') as f:
        json.dump(decoder_config, f, indent=4)

    save_file(decoder_weights, os.path.join(decoder_save_dir, 'rank0.safetensors'))

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Checkpoint successfully converted and saved to {args.output_dir}.')
    print(f'Total time of converting checkpoints: {t}')
encoder.fp16.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:979d55f4cecfb651720b037802649f39acb6c235f048c62f7ddb8a1a30bebda8
size 1447173731
export_encoder_tensorrt.py
ADDED
@@ -0,0 +1,257 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Zengwei Yao)
# Copyright 2025 Nvidia Corp. (authors: Yuekai Zhang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script exports a pre-trained FireRedASR encoder model from PyTorch to
ONNX and TensorRT.

Usage:

python3 examples/export_encoder_tensorrt.py \
    --model-dir /path/to/your/model_dir \
    --tensorrt-model-dir ./tensorrt_models \
    --trt-engine-file-name encoder.plan
"""

import argparse
import logging
from pathlib import Path

import torch
import tensorrt as trt

from fireredasr.models.fireredasr import load_fireredasr_aed_model


def get_parser() -> argparse.ArgumentParser:
    """Get the command-line argument parser."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-dir",
        type=str,
        default=None,
        help="The model directory that contains the model checkpoint.",
    )

    parser.add_argument(
        "--onnx-model-path",
        type=str,
        default=None,
        help="If specified, we will directly use this ONNX model to generate "
        "the TensorRT engine.",
    )

    parser.add_argument(
        "--idim",
        type=int,
        default=80,
        help="The input dimension of the model. This is required when "
        "--onnx-model-path is specified.",
    )

    parser.add_argument(
        "--tensorrt-model-dir",
        type=str,
        default="exp",
        help="Directory to save the exported models.",
    )

    parser.add_argument(
        "--trt-engine-file-name",
        type=str,
        default="encoder.plan",
        help="The name of the TensorRT engine file.",
    )

    parser.add_argument(
        "--opset-version",
        type=int,
        default=17,
        help="ONNX opset version.",
    )

    return parser


def export_encoder_onnx(
    encoder: torch.nn.Module,
    filename: str,
    idim: int,
    opset_version: int = 17,
) -> None:
    """Export the conformer encoder model to ONNX format."""
    logging.info("Exporting encoder to ONNX")
    encoder.half()

    # Create dummy inputs
    seq_len = 400  # A typical sequence length
    batch_size = 1
    padded_input = torch.randn(batch_size, seq_len, idim, dtype=torch.float16)
    input_lengths = torch.tensor([seq_len] * batch_size, dtype=torch.int32)

    # Export
    torch.onnx.export(
        encoder,
        (padded_input, input_lengths),
        filename,
        opset_version=opset_version,
        input_names=["padded_input", "input_lengths"],
        output_names=["enc_output", "output_lengths", "src_mask"],
        dynamic_axes={
            "padded_input": {0: "batch_size", 1: "seq_len"},
            "input_lengths": {0: "batch_size"},
            "enc_output": {0: "batch_size", 1: "seq_len_out"},
            "output_lengths": {0: "batch_size"},
            "src_mask": {0: "batch_size", 2: "seq_len_out"},
        },
    )
    logging.info(f"Exported encoder to {filename}")


def get_trt_kwargs_dynamic_batch(
    idim: int,
    min_batch_size: int = 1,
    opt_batch_size: int = 4,
    max_batch_size: int = 64,
):
    """Get keyword arguments for TensorRT with dynamic batch size."""
    min_seq_len = 50
    opt_seq_len = 400
    max_seq_len = 3000

    min_shape = [(min_batch_size, min_seq_len, idim), (min_batch_size,)]
    opt_shape = [(opt_batch_size, opt_seq_len, idim), (opt_batch_size,)]
    max_shape = [(max_batch_size, max_seq_len, idim), (max_batch_size,)]
    input_names = ["padded_input", "input_lengths"]
    return {
        "min_shape": min_shape,
        "opt_shape": opt_shape,
        "max_shape": max_shape,
        "input_names": input_names,
    }


def convert_onnx_to_trt(
    trt_model: str, trt_kwargs: dict, onnx_model: str, dtype: torch.dtype = torch.float16
) -> None:
    """Convert an ONNX model to a TensorRT engine."""
    logging.info("Converting ONNX to TensorRT engine...")
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()

    if dtype == torch.float16:
        config.set_flag(trt.BuilderFlag.FP16)

    profile = builder.create_optimization_profile()

    with open(onnx_model, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            raise ValueError(f'Failed to parse {onnx_model}')

    for i, name in enumerate(trt_kwargs['input_names']):
        profile.set_shape(
            name,
            trt_kwargs['min_shape'][i],
            trt_kwargs['opt_shape'][i],
            trt_kwargs['max_shape'][i]
        )

    config.add_optimization_profile(profile)

    try:
        engine_bytes = builder.build_serialized_network(network, config)
    except Exception as e:
        logging.error(f"TensorRT engine build failed: {e}")
        return

    with open(trt_model, "wb") as f:
        f.write(engine_bytes)
    logging.info("Successfully converted ONNX to TensorRT.")


@torch.no_grad()
def main():
    """Main function to export the model."""
    parser = get_parser()
    args = parser.parse_args()

    tensorrt_model_dir = Path(args.tensorrt_model_dir)
    tensorrt_model_dir.mkdir(parents=True, exist_ok=True)

    if args.onnx_model_path:
        logging.info(f"Using provided ONNX model: {args.onnx_model_path}")
        if not args.idim:
            raise ValueError("--idim is required when using --onnx-model-path")
        idim = args.idim
        encoder_onnx_file = Path(args.onnx_model_path)
        if not encoder_onnx_file.is_file():
            raise FileNotFoundError(f"ONNX model not found at {encoder_onnx_file}")
    else:
        if not args.model_dir:
            raise ValueError(
                "--model-dir is required if --onnx-model-path is not provided"
            )

        logging.info("Exporting ONNX model from PyTorch checkpoint")
        model_dir = Path(args.model_dir)
        model_path = model_dir / "model.pth.tar"

        # Load model to get the encoder
        package = torch.load(model_path, map_location="cpu", weights_only=False)
        model_args = package["args"]
        idim = model_args.idim
        # We have to load the full AED model to get the encoder with weights
        model = load_fireredasr_aed_model(str(model_path))
        encoder = model.encoder
        encoder.eval()

        # Export ONNX
        encoder_onnx_file = tensorrt_model_dir / "encoder.fp16.onnx"
        export_encoder_onnx(
            encoder=encoder,
            filename=str(encoder_onnx_file),
            idim=idim,
            opset_version=args.opset_version,
        )

    # Convert ONNX to TensorRT
    trt_engine_file = tensorrt_model_dir / args.trt_engine_file_name
    trt_kwargs = get_trt_kwargs_dynamic_batch(idim=idim)
    convert_onnx_to_trt(
        trt_model=str(trt_engine_file),
        trt_kwargs=trt_kwargs,
        onnx_model=str(encoder_onnx_file),
        dtype=torch.float16,
    )

    logging.info("Done!")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
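
Before building the engine, it can help to smoke-test the exported ONNX graph. Below is a minimal sketch with onnxruntime; it is not part of this upload, assumes onnxruntime-gpu is installed (the fp16 graph generally needs the CUDA execution provider), and reuses the input/output names and the 80-dim feature default from the script above. The output path ./tensorrt_models matches the usage example in the docstring.

# smoke-test the exported encoder -- illustrative sketch, assumes onnxruntime-gpu
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "tensorrt_models/encoder.fp16.onnx",  # assumed export location
    providers=["CUDAExecutionProvider"],
)
batch, seq_len, idim = 2, 400, 80
feed = {
    "padded_input": np.random.randn(batch, seq_len, idim).astype(np.float16),
    "input_lengths": np.array([seq_len, seq_len - 40], dtype=np.int32),
}
enc_output, output_lengths, src_mask = sess.run(None, feed)
print(enc_output.shape, output_lengths, src_mask.shape)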
export_tensorrt.sh
ADDED
@@ -0,0 +1,54 @@
export PATH=$PWD/fireredasr/:$PWD/fireredasr/utils/:$PATH
export PYTHONPATH=$PWD/:$PYTHONPATH

# model_path=pretrained_models/FireRedASR-AED-L
# python3 export_encoder_tensorrt.py \
#     --model-dir $model_path \
#     --tensorrt-model-dir $TRT_ENGINE_OUTPUT_DIR \
#     --trt-engine-file-name encoder.plan

TRT_ENGINE_OUTPUT_DIR=./FireRedASR-AED-L-TensorRT
python3 export_encoder_tensorrt.py \
    --onnx-model-path $TRT_ENGINE_OUTPUT_DIR/encoder.fp16.onnx \
    --tensorrt-model-dir $TRT_ENGINE_OUTPUT_DIR \
    --trt-engine-file-name encoder.plan


INFERENCE_PRECISION=float16
MAX_BEAM_WIDTH=4
MAX_BATCH_SIZE=64
checkpoint_dir=$TRT_ENGINE_OUTPUT_DIR/tllm_checkpoint_float16
output_dir=$TRT_ENGINE_OUTPUT_DIR/trt_engine_${INFERENCE_PRECISION}

# model_path=pretrained_models/FireRedASR-AED-L/model.pth.tar
# python3 convert_checkpoint.py \
#     --dtype ${INFERENCE_PRECISION} \
#     --model_path $model_path \
#     --output_dir $checkpoint_dir

trtllm-build --checkpoint_dir ${checkpoint_dir}/decoder \
    --output_dir ${output_dir}/decoder \
    --moe_plugin disable \
    --max_beam_width ${MAX_BEAM_WIDTH} \
    --max_batch_size ${MAX_BATCH_SIZE} \
    --max_seq_len 512 \
    --max_input_len 4 \
    --max_encoder_input_len 1024 \
    --gemm_plugin ${INFERENCE_PRECISION} \
    --remove_input_padding disable \
    --paged_kv_cache disable \
    --gpt_attention_plugin ${INFERENCE_PRECISION}


# FireRedASR-AED-L-TensorRT/
# ├── encoder.fp16.onnx
# ├── encoder.plan
# ├── tllm_checkpoint_float16
# │   └── decoder
# │       ├── config.json
# │       └── rank0.safetensors
# └── trt_engine_float16
#     └── decoder
#         ├── config.json
#         └── rank0.engine
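
To confirm that the built encoder.plan deserializes and kept the dynamic shapes set in the optimization profile, a small inspection sketch follows. It is illustrative only and assumes TensorRT 8.5 or newer, whose tensor-based engine API is used below; the engine path matches $TRT_ENGINE_OUTPUT_DIR from the script above.

# inspect the built engine -- illustrative sketch, assumes TensorRT >= 8.5
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
with open("FireRedASR-AED-L-TensorRT/encoder.plan", "rb") as f:
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())

for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    # -1 entries correspond to the dynamic batch/sequence axes from the profile
    print(engine.get_tensor_mode(name), name, engine.get_tensor_shape(name))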
tllm_checkpoint_float16/decoder/config.json
ADDED
@@ -0,0 +1,38 @@
{
    "architecture": "DecoderModel",
    "dtype": "float16",
    "logits_dtype": "float16",
    "num_hidden_layers": 16,
    "num_attention_heads": 20,
    "hidden_size": 1280,
    "norm_epsilon": 1e-05,
    "vocab_size": 7832,
    "hidden_act": "gelu",
    "use_parallel_embedding": false,
    "embedding_sharding_dim": 0,
    "max_position_embeddings": 5000,
    "use_prompt_tuning": false,
    "head_size": 64,
    "has_position_embedding": true,
    "layernorm_type": 0,
    "has_attention_qkvo_bias": true,
    "has_mlp_bias": true,
    "has_model_final_layernorm": true,
    "has_embedding_layernorm": false,
    "has_embedding_scale": true,
    "ffn_hidden_size": 5120,
    "q_scaling": 1.0,
    "layernorm_position": 0,
    "relative_attention": false,
    "max_distance": 0,
    "num_buckets": 0,
    "model_type": "whisper",
    "rescale_before_lm_head": false,
    "encoder_hidden_size": 1280,
    "encoder_num_heads": 20,
    "encoder_head_size": null,
    "skip_cross_kv": false,
    "quantization": {
        "quant_algo": null
    }
}
tllm_checkpoint_float16/decoder/rank0.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fae4a3ce0ab15552d307ef960a579c25f479d490b65959cf4189e7a723463037
size 892578184