WeNet / run_ort.py
inoryQwQ's picture
First commit
3c50954
Raw
History Blame Contribute Delete
2 kB
import argparse
import os
from ort_common import WenetONNXRunner
def get_args(argv=None):
    """Parse the command-line options of the WeNet ONNX runner script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, so
            existing ``get_args()`` callers behave exactly as before.
            Passing an explicit list makes the parser usable from other
            code and from tests.

    Returns:
        An ``argparse.Namespace`` with the attributes ``input``,
        ``config``, ``vocab``, ``online``, ``offline_seq_len``,
        ``online_seq_len``, ``decoder_len``, ``onnx_dir``, ``mode`` and
        ``calib_data_path``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing
            or a value is rejected (e.g. an unknown ``--mode``).
    """
    parser = argparse.ArgumentParser()

    # Required inputs: the audio to transcribe plus the model description.
    parser.add_argument(
        "--input", "-i", type=str, required=True, help="Input audio file"
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml file in checkpoint path"
    )
    parser.add_argument(
        "--vocab",
        type=str,
        required=True,
        help="pretrained units.txt, for example pretrained/<model>/units.txt",
    )

    # Optional runtime settings.
    parser.add_argument("--online", action="store_true")
    parser.add_argument("--offline_seq_len", type=int, default=1024)
    parser.add_argument("--online_seq_len", type=int, default=67)
    parser.add_argument("--decoder_len", type=int, default=32)
    parser.add_argument("--onnx_dir", type=str, default="onnx_model")
    parser.add_argument(
        "--mode",
        choices=[
            "ctc_greedy_search",
            "ctc_prefix_beam_search",
            "attention_rescoring",
        ],
        default="ctc_prefix_beam_search",
        help="decoding mode",
    )
    parser.add_argument(
        "--calib_data_path",
        type=str,
        default="calibration_dataset",
        help="Generated calibration data path",
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def run_ort():
    """Transcribe the requested audio file and print the recognised text.

    Reads the command-line options via ``get_args()``, echoes the
    ``online``, ``mode`` and ``calib_data_path`` settings, creates the
    calibration directory when a non-empty path was given, builds a
    ``WenetONNXRunner`` and prints the value returned by its
    ``transcribe`` method.
    """
    cli = get_args()

    # Echo the settings that steer this run.
    for label, value in (
        ("online", cli.online),
        ("mode", cli.mode),
        ("calib_data_path", cli.calib_data_path),
    ):
        print(f"{label}: {value}")

    # An empty path is normalised to None; otherwise make sure the
    # directory exists before the runner is asked to use it.
    calib_dir = cli.calib_data_path if cli.calib_data_path else None
    if calib_dir is not None:
        os.makedirs(calib_dir, exist_ok=True)

    # NOTE(review): --online_seq_len is parsed by get_args() but is not
    # forwarded here — confirm whether WenetONNXRunner should receive it.
    asr_runner = WenetONNXRunner(
        cli.config,
        cli.vocab,
        onnx_dir=cli.onnx_dir,
        offline_seq_len=cli.offline_seq_len,
        decoder_len=cli.decoder_len,
    )

    text = asr_runner.transcribe(
        cli.input,
        online=cli.online,
        mode=cli.mode,
        calib_data_path=calib_dir,
    )
    print(f"ASR Result: {text}")
# Script entry point: run the transcription only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_ort()