| import argparse | |
| from ax_common import WenetAXRunner | |
def get_args(argv=None):
    """Parse the command-line options of this transcription script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script does when started from a shell. Passing an
            explicit list lets other code (or tests) reuse the parser
            without touching the process-wide ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option declared below.

    Raises:
        SystemExit: raised by argparse when ``--input`` is missing, when
            ``--mode`` is not one of the listed choices, or when an integer
            option cannot be converted.
    """
    parser = argparse.ArgumentParser()

    # Audio to transcribe and the streaming switch.
    parser.add_argument("--input", "-i", type=str, required=True,
                        help="Input audio file")
    parser.add_argument("--online", action="store_true")

    # Files shipped with the pretrained checkpoint.
    parser.add_argument(
        "--config",
        type=str,
        default="pretrained/aishell_u2pp_conformer_exp/train.yaml",
        help="yaml file in checkpoint path",
    )
    parser.add_argument(
        "--vocab",
        type=str,
        default="pretrained/aishell_u2pp_conformer_exp/units.txt",
        help="pretrained units.txt",
    )

    # Compiled model locations.
    parser.add_argument(
        "--encoder_online",
        type=str,
        default="axmodel/encoder_online/encoder_online.axmodel",
    )
    parser.add_argument(
        "--encoder_offline",
        type=str,
        default="axmodel/encoder_offline/encoder_offline.axmodel",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        default="axmodel/decoder/decoder.axmodel",
    )

    # Integer sizing options.
    parser.add_argument("--offline_seq_len", type=int, default=1024)
    parser.add_argument("--online_seq_len", type=int, default=67)
    parser.add_argument("--decoder_len", type=int, default=32)
    parser.add_argument("--decoding_chunk_size", type=int, default=16)
    parser.add_argument("--num_decoding_left_chunks", type=int, default=5)

    # Inference backend and decoding strategy.
    parser.add_argument(
        "--provider",
        type=str,
        default="AxEngineExecutionProvider",
    )
    parser.add_argument(
        "--mode",
        choices=[
            "ctc_greedy_search",
            "ctc_prefix_beam_search",
            "attention_rescoring",
        ],
        default="ctc_prefix_beam_search",
        help="decoding mode",
    )

    # argv=None makes argparse read sys.argv[1:], preserving the old behaviour.
    return parser.parse_args(argv)
def main():
    """Run one transcription from the command line and print the result.

    Parses the options with ``get_args()``, echoes the selected online flag,
    decoding mode and provider, constructs a ``WenetAXRunner`` from the
    parsed paths and sizes, calls its ``transcribe`` method on the input
    file and prints whatever it returns.
    """
    opts = get_args()

    print(f"online: {opts.online}")
    print(f"mode: {opts.mode}")
    print(f"provider: {opts.provider}")

    # NOTE(review): --online_seq_len is parsed by get_args() but is not
    # forwarded to the runner here — confirm whether that is intentional.
    runner_options = dict(
        encoder_offline_path=opts.encoder_offline,
        encoder_online_path=opts.encoder_online,
        decoder_path=opts.decoder,
        offline_seq_len=opts.offline_seq_len,
        decoder_len=opts.decoder_len,
        decoding_chunk_size=opts.decoding_chunk_size,
        num_decoding_left_chunks=opts.num_decoding_left_chunks,
        provider=opts.provider,
    )
    asr = WenetAXRunner(opts.config, opts.vocab, **runner_options)

    transcript = asr.transcribe(opts.input, online=opts.online, mode=opts.mode)
    print(f"ASR Result: {transcript}")
# Run the transcription only when executed as a script, not when imported.
if __name__ == "__main__":
    main()