| import argparse |
| import os |
|
|
| from ort_common import WenetONNXRunner |
|
|
|
|
def get_args(argv=None):
    """Build the command-line parser and parse the options for this script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing an explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", "-i", type=str, required=True, help="Input audio file")
    parser.add_argument("--config", type=str, required=True, help="yaml file in checkpoint path")
    parser.add_argument(
        "--vocab",
        type=str,
        required=True,
        help="pretrained units.txt, for example pretrained/<model>/units.txt",
    )
    parser.add_argument("--online", action="store_true")
    parser.add_argument("--offline_seq_len", type=int, default=1024)
    parser.add_argument("--online_seq_len", type=int, default=67)
    parser.add_argument("--decoder_len", type=int, default=32)
    parser.add_argument("--onnx_dir", type=str, default="onnx_model")
    parser.add_argument(
        "--mode",
        choices=["ctc_greedy_search", "ctc_prefix_beam_search", "attention_rescoring"],
        default="ctc_prefix_beam_search",
        help="decoding mode",
    )
    # NOTE: run_ort maps an empty string to None (no directory is created),
    # so `--calib_data_path ""` is the way to opt out of the default path.
    parser.add_argument(
        "--calib_data_path",
        type=str,
        default="calibration_dataset",
        help="Generated calibration data path",
    )
    return parser.parse_args(argv)
|
|
|
|
def run_ort():
    """Parse the CLI options, transcribe the input audio and print the text.

    Side effects: prints the chosen settings, creates the calibration data
    directory when a non-empty ``--calib_data_path`` is given, and prints the
    transcription returned by ``WenetONNXRunner.transcribe``.
    """
    opts = get_args()
    print(f"online: {opts.online}")
    print(f"mode: {opts.mode}")
    print(f"calib_data_path: {opts.calib_data_path}")

    # An empty --calib_data_path collapses to None, so no directory is made
    # and None is handed to transcribe().
    dump_dir = opts.calib_data_path if opts.calib_data_path else None
    if dump_dir:
        os.makedirs(dump_dir, exist_ok=True)

    # NOTE(review): --online_seq_len is parsed by get_args but is not passed
    # to the runner here; confirm whether WenetONNXRunner expects it.
    runner_kwargs = dict(
        onnx_dir=opts.onnx_dir,
        offline_seq_len=opts.offline_seq_len,
        decoder_len=opts.decoder_len,
    )
    runner = WenetONNXRunner(opts.config, opts.vocab, **runner_kwargs)

    text = runner.transcribe(
        opts.input,
        online=opts.online,
        mode=opts.mode,
        calib_data_path=dump_dir,
    )
    print(f"ASR Result: {text}")
|
|
if __name__ == "__main__":
    # Run inference only when executed as a script, not on import.
    run_ort()
|
|