File size: 2,000 Bytes
3c50954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import argparse
import os

from ort_common import WenetONNXRunner


def get_args(argv=None):
    """Parse the command-line options for the WeNet ONNX runner script.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            ``--mode`` is not one of the allowed choices.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", "-i", type=str, required=True, help="Input audio file")
    parser.add_argument("--config", type=str, required=True, help="yaml file in checkpoint path")
    parser.add_argument(
        "--vocab",
        type=str,
        required=True,
        help="pretrained units.txt, for example pretrained/<model>/units.txt",
    )
    parser.add_argument("--online", action="store_true")
    parser.add_argument("--offline_seq_len", type=int, default=1024)
    # NOTE(review): --online_seq_len is parsed but run_ort() in this file never
    # reads it; confirm whether it should be forwarded to the runner.
    parser.add_argument("--online_seq_len", type=int, default=67)
    parser.add_argument("--decoder_len", type=int, default=32)
    parser.add_argument("--onnx_dir", type=str, default="onnx_model")
    parser.add_argument(
        "--mode",
        choices=["ctc_greedy_search", "ctc_prefix_beam_search", "attention_rescoring"],
        default="ctc_prefix_beam_search",
        help="decoding mode",
    )
    parser.add_argument(
        "--calib_data_path",
        type=str,
        default="calibration_dataset",
        help="Generated calibration data path",
    )
    # Passing argv=None lets argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)


def run_ort():
    """Run a single transcription using options taken from the command line.

    Prints the ``online``, ``mode`` and ``calib_data_path`` options, builds a
    ``WenetONNXRunner`` from the config/vocab/ONNX settings, transcribes
    ``--input`` and prints the returned result.

    An empty ``--calib_data_path`` is turned into ``None`` before being handed
    to ``transcribe``. A non-empty path is created on disk first if it does
    not already exist.
    """
    opts = get_args()
    print(f"online: {opts.online}")
    print(f"mode: {opts.mode}")
    print(f"calib_data_path: {opts.calib_data_path}")

    # An empty string means "no calibration output".
    calib_dir = opts.calib_data_path if opts.calib_data_path else None
    if calib_dir is not None:
        os.makedirs(calib_dir, exist_ok=True)

    # NOTE(review): --online_seq_len is parsed by get_args() but not passed
    # here; confirm whether WenetONNXRunner is meant to receive it.
    runner_kwargs = dict(
        onnx_dir=opts.onnx_dir,
        offline_seq_len=opts.offline_seq_len,
        decoder_len=opts.decoder_len,
    )
    runner = WenetONNXRunner(opts.config, opts.vocab, **runner_kwargs)

    transcript = runner.transcribe(
        opts.input,
        online=opts.online,
        mode=opts.mode,
        calib_data_path=calib_dir,
    )
    print(f"ASR Result: {transcript}")

if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    run_ort()