File size: 1,899 Bytes
98c8d47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

# Put the project's fireredasr/ and fireredasr/utils/ directories at the front
# of PATH, and the current directory at the front of PYTHONPATH, for the
# commands launched below. Paths are relative to $PWD, so run this from the
# directory that contains fireredasr/.
export PATH="${PWD}/fireredasr/:${PWD}/fireredasr/utils/:${PATH}"
# ${PYTHONPATH:+...} appends the previous value only when there is one, so an
# unset or empty PYTHONPATH does not leave a dangling ':' (an empty entry).
export PYTHONPATH="${PWD}/${PYTHONPATH:+:${PYTHONPATH}}"

# Directory that holds encoder.fp16.onnx and receives the TensorRT artifacts.
# Assigned before the commented-out variant below so that variant also sees a
# non-empty value if it is re-enabled.
TRT_ENGINE_OUTPUT_DIR=./FireRedASR-AED-L-TensorRT

# Alternative invocation, kept for reference: passes --model-dir (the
# pretrained model directory) instead of --onnx-model-path.
# model_path=pretrained_models/FireRedASR-AED-L
# python3 export_encoder_tensorrt.py \
#     --model-dir "$model_path" \
#     --tensorrt-model-dir "$TRT_ENGINE_OUTPUT_DIR" \
#     --trt-engine-file-name encoder.plan

# Run the encoder export on the ONNX model; per the flags, the result is
# written as encoder.plan inside the output directory. Abort on failure so the
# script's exit status cannot hide a broken encoder export.
python3 export_encoder_tensorrt.py \
    --onnx-model-path "${TRT_ENGINE_OUTPUT_DIR}/encoder.fp16.onnx" \
    --tensorrt-model-dir "${TRT_ENGINE_OUTPUT_DIR}" \
    --trt-engine-file-name encoder.plan \
  || { printf 'error: encoder TensorRT export failed\n' >&2; exit 1; }


# Settings for the decoder engine build.
INFERENCE_PRECISION=float16
MAX_BEAM_WIDTH=4
MAX_BATCH_SIZE=64

# Both directories live under TRT_ENGINE_OUTPUT_DIR (assigned earlier in this
# script). Stop with a clear message if it is empty, otherwise the paths below
# would silently resolve under '/'.
: "${TRT_ENGINE_OUTPUT_DIR:?TRT_ENGINE_OUTPUT_DIR must be set}"
# Both names follow INFERENCE_PRECISION so the checkpoint and the engine
# directories stay in step when the precision is changed.
checkpoint_dir="${TRT_ENGINE_OUTPUT_DIR}/tllm_checkpoint_${INFERENCE_PRECISION}"
output_dir="${TRT_ENGINE_OUTPUT_DIR}/trt_engine_${INFERENCE_PRECISION}"

# Checkpoint conversion step, kept for reference; it is expected to have been
# run already so that ${checkpoint_dir}/decoder exists.
# model_path=pretrained_models/FireRedASR-AED-L/model.pth.tar
# python3 convert_checkpoint.py \
#                 --dtype "${INFERENCE_PRECISION}" \
#                 --model_path "$model_path" \
#                 --output_dir "$checkpoint_dir"

# Build the decoder engine from the converted checkpoint into
# ${output_dir}/decoder.
trtllm-build  --checkpoint_dir "${checkpoint_dir}/decoder" \
              --output_dir "${output_dir}/decoder" \
              --moe_plugin disable \
              --max_beam_width "${MAX_BEAM_WIDTH}" \
              --max_batch_size "${MAX_BATCH_SIZE}" \
              --max_seq_len 512 \
              --max_input_len 4 \
              --max_encoder_input_len 1024 \
              --gemm_plugin "${INFERENCE_PRECISION}" \
              --remove_input_padding disable \
              --paged_kv_cache disable \
              --gpt_attention_plugin "${INFERENCE_PRECISION}"


# FireRedASR-AED-L-TensorRT/
# ├── encoder.fp16.onnx
# ├── encoder.plan
# ├── tllm_checkpoint_float16
# │   └── decoder
# │       ├── config.json
# │       └── rank0.safetensors
# └── trt_engine_float16
#     └── decoder
#         ├── config.json
#         └── rank0.engine