File size: 2,774 Bytes
3c50954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78224f0
 
 
3c50954
 
 
 
 
 
 
 
 
 
 
 
 
 
78224f0
3c50954
 
 
 
 
 
 
 
 
 
 
78224f0
3c50954
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse

from ax_common import WenetAXRunner


def get_args(argv=None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``, so existing ``get_args()``
            callers behave exactly as before.

    Returns:
        The populated ``argparse.Namespace``. ``--input`` is the only
        required option; every other option has a default.
    """
    parser = argparse.ArgumentParser()

    # Input and run-mode switches.
    parser.add_argument("--input", "-i", type=str, required=True, help="Input audio file")
    parser.add_argument("--online", action="store_true")

    # Files taken from the pretrained checkpoint directory.
    parser.add_argument("--config",
                        type=str,
                        default="pretrained/aishell_u2pp_conformer_exp/train.yaml",
                        help="yaml file in checkpoint path")
    parser.add_argument("--vocab",
                        type=str,
                        default="pretrained/aishell_u2pp_conformer_exp/units.txt",
                        help="pretrained units.txt")

    # Paths of the compiled .axmodel files.
    parser.add_argument("--encoder_online",
                        type=str,
                        default="axmodel/encoder_online/encoder_online.axmodel")
    parser.add_argument("--encoder_offline",
                        type=str,
                        default="axmodel/encoder_offline/encoder_offline.axmodel")
    parser.add_argument("--decoder",
                        type=str,
                        default="axmodel/decoder/decoder.axmodel")

    # Integer size/length settings.
    parser.add_argument("--offline_seq_len", type=int, default=1024)
    parser.add_argument("--online_seq_len", type=int, default=67)
    parser.add_argument("--decoder_len", type=int, default=32)
    parser.add_argument("--decoding_chunk_size", type=int, default=16)
    parser.add_argument("--num_decoding_left_chunks", type=int, default=5)

    parser.add_argument("--provider",
                        type=str,
                        default="AxEngineExecutionProvider")
    parser.add_argument("--mode",
                        choices=[
                            "ctc_greedy_search", "ctc_prefix_beam_search",
                            "attention_rescoring"
                        ],
                        default="ctc_prefix_beam_search",
                        help="decoding mode")
    return parser.parse_args(argv)


def main():
    """Parse the CLI options, build a WenetAXRunner and print the transcript.

    Echoes the ``online``, ``mode`` and ``provider`` settings first, then
    calls ``runner.transcribe`` on the input path and prints its result.

    NOTE(review): ``--online_seq_len`` is parsed by ``get_args`` but is not
    forwarded to ``WenetAXRunner`` here -- confirm whether that is intended.
    """
    opts = get_args()

    # Echo the settings that select the execution path.
    for field in ("online", "mode", "provider"):
        print(f"{field}: {getattr(opts, field)}")

    runner_kwargs = {
        "encoder_offline_path": opts.encoder_offline,
        "encoder_online_path": opts.encoder_online,
        "decoder_path": opts.decoder,
        "offline_seq_len": opts.offline_seq_len,
        "decoder_len": opts.decoder_len,
        "decoding_chunk_size": opts.decoding_chunk_size,
        "num_decoding_left_chunks": opts.num_decoding_left_chunks,
        "provider": opts.provider,
    }
    asr = WenetAXRunner(opts.config, opts.vocab, **runner_kwargs)

    transcript = asr.transcribe(opts.input, online=opts.online, mode=opts.mode)
    print(f"ASR Result: {transcript}")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()