File size: 2,774 Bytes
3c50954 78224f0 3c50954 78224f0 3c50954 78224f0 3c50954 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | import argparse
from ax_common import WenetAXRunner
def get_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is used, exactly as before; passing
            an explicit list makes the parser usable from tests or other code.

    Returns:
        argparse.Namespace with one attribute per option below.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments
            (e.g. no ``--input`` or an unknown ``--mode``).
    """
    parser = argparse.ArgumentParser(
        description="Transcribe an audio file with WenetAXRunner.",
        # Show each option's default in --help output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--input", "-i", type=str, required=True, help="Input audio file")
    parser.add_argument("--online", action="store_true",
                        help="transcribe in online mode")
    parser.add_argument("--config",
                        type=str,
                        default="pretrained/aishell_u2pp_conformer_exp/train.yaml",
                        help="yaml file in checkpoint path")
    parser.add_argument("--vocab",
                        type=str,
                        default="pretrained/aishell_u2pp_conformer_exp/units.txt",
                        help="pretrained units.txt")
    parser.add_argument("--encoder_online",
                        type=str,
                        default="axmodel/encoder_online/encoder_online.axmodel",
                        help="path to the online encoder .axmodel")
    parser.add_argument("--encoder_offline",
                        type=str,
                        default="axmodel/encoder_offline/encoder_offline.axmodel",
                        help="path to the offline encoder .axmodel")
    parser.add_argument("--decoder",
                        type=str,
                        default="axmodel/decoder/decoder.axmodel",
                        help="path to the decoder .axmodel")
    parser.add_argument("--offline_seq_len", type=int, default=1024,
                        help="offline encoder sequence length")
    # NOTE(review): main() does not currently forward this value to
    # WenetAXRunner; confirm whether the runner needs it.
    parser.add_argument("--online_seq_len", type=int, default=67,
                        help="online encoder sequence length")
    parser.add_argument("--decoder_len", type=int, default=32,
                        help="decoder length")
    parser.add_argument("--decoding_chunk_size", type=int, default=16,
                        help="decoding chunk size")
    parser.add_argument("--num_decoding_left_chunks", type=int, default=5,
                        help="number of decoding left chunks")
    parser.add_argument("--provider",
                        type=str,
                        default="AxEngineExecutionProvider",
                        help="execution provider")
    parser.add_argument("--mode",
                        choices=[
                            "ctc_greedy_search", "ctc_prefix_beam_search",
                            "attention_rescoring"
                        ],
                        default="ctc_prefix_beam_search",
                        help="decoding mode")
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def main():
    """Entry point: parse the CLI, build a WenetAXRunner and print the transcript.

    First prints the ``online``, ``mode`` and ``provider`` settings, one per
    line. Then prints the value returned by ``WenetAXRunner.transcribe``,
    prefixed with ``"ASR Result: "``.
    """
    opts = get_args()

    # Echo the run settings so the output records how the result was produced.
    for label, value in (
        ("online", opts.online),
        ("mode", opts.mode),
        ("provider", opts.provider),
    ):
        print(f"{label}: {value}")

    # Keyword arguments for the runner; config and vocab go positionally.
    # NOTE(review): --online_seq_len is parsed by get_args() but is not passed
    # here -- confirm whether WenetAXRunner needs it.
    runner_options = {
        "encoder_offline_path": opts.encoder_offline,
        "encoder_online_path": opts.encoder_online,
        "decoder_path": opts.decoder,
        "offline_seq_len": opts.offline_seq_len,
        "decoder_len": opts.decoder_len,
        "decoding_chunk_size": opts.decoding_chunk_size,
        "num_decoding_left_chunks": opts.num_decoding_left_chunks,
        "provider": opts.provider,
    }
    runner = WenetAXRunner(opts.config, opts.vocab, **runner_options)

    transcript = runner.transcribe(opts.input, online=opts.online, mode=opts.mode)
    print(f"ASR Result: {transcript}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|