# FireRedASR-AED / test_ax_model.py
# Author: inoryQwQ
# Commit d56de90: "Optimize RTF and cer"
import argparse
import os
import time
import logging
from fireredasr_axmodel import FireRedASRAxModel
# Configure the root logger at import time: attach a StreamHandler (which
# writes to stderr by default) and let INFO-and-above records through, so
# every logger.info() call in this script is printed.
logger = logging.getLogger()
logger_stream_hander = logging.StreamHandler()
logger_stream_hander.setLevel("INFO")
logger.setLevel(logging.INFO)
logger.addHandler(logger_stream_hander)
def parse_args(argv=None):
    """Parse command-line options for the FireRedASR axmodel test script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour; passing a list is mainly useful in tests.

    Returns:
        argparse.Namespace with the model file paths (encoder, decoder_loop,
        cmvn, dict, spm_model), the I/O files (wavlist, hypo) and the
        decoding options (beam_size, nbest, decode_max_len, max_dur).
    """
    parser = argparse.ArgumentParser(description="FireRedASRAxModel Test")
    parser.add_argument(
        "--encoder",
        type=str,
        default="axmodel/encoder.axmodel",
        help="Path to axmodel encoder",
    )
    parser.add_argument(
        "--decoder_loop",
        type=str,
        default="axmodel/decoder_loop.axmodel",
        help="Path to axmodel decoder loop",
    )
    parser.add_argument(
        "--cmvn", type=str, default="axmodel/cmvn.ark", help="Path to cmvn"
    )
    parser.add_argument(
        "--dict", type=str, default="axmodel/dict.txt", help="Path to dict"
    )
    parser.add_argument(
        "--spm_model",
        type=str,
        default="axmodel/train_bpe1000.model",
        help="Path to spm model",
    )
    parser.add_argument(
        "--wavlist",
        type=str,
        default="wavlist.txt",
        help="Text file listing one wav path per line",
    )
    parser.add_argument(
        "--hypo", type=str, default="hypo_axmodel.txt", help="File of hypos"
    )
    # These two previously had empty help strings, so `--help` showed nothing.
    parser.add_argument(
        "--beam_size", type=int, default=1, help="Decoding beam size"
    )
    parser.add_argument(
        "--nbest", type=int, default=1, help="N-best count passed to the decoder"
    )
    parser.add_argument("--decode_max_len", type=int, default=128, help="max token len")
    parser.add_argument("--max_dur", type=int, default=10, help="max audio len")
    return parser.parse_args(argv)
def parse_wavlist(wavlist: str):
    """Read a wav list file and return the listed paths that exist on disk.

    The file is read as UTF-8 text, one path per line, with surrounding
    whitespace stripped. Blank lines are skipped silently; each path that
    does not exist is reported on stdout and skipped.

    Args:
        wavlist: Path to the list file.

    Returns:
        List of existing wav paths, in file order.
    """
    wavpaths = []
    # Explicit utf-8 so non-ASCII paths are decoded the same way regardless
    # of the platform's default locale encoding.
    with open(wavlist, encoding="utf-8") as f:
        for line in f:
            path = line.strip()
            if not path:
                # Blank lines (e.g. trailing empty lines) are not paths; the
                # old code reported them as " doesn't exist.".
                continue
            if not os.path.exists(path):
                print(f"{path} doesn't exist.")
                continue
            wavpaths.append(path)
    return wavpaths
def main():
    """Run the FireRedASR axmodel benchmark over a list of wav files.

    For every existing path in ``--wavlist``:
      - the model transcribes the file;
      - the hypothesis text is written as one line of ``--hypo``;
      - the audio duration, transcription time and real-time factor
        (RTF = transcription time / audio duration) are logged.

    Aggregate totals and the average RTF are logged at the end. RTF values
    are reported as N/A instead of raising when the audio duration is zero.
    """
    args = parse_args()
    print(args)
    model = FireRedASRAxModel(
        args.encoder,
        args.decoder_loop,
        args.cmvn,
        args.dict,
        args.spm_model,
        decode_max_len=args.decode_max_len,
        audio_dur=args.max_dur,
    )
    total_wav_durations = 0
    total_transcribe_durations = 0
    # `with` guarantees the hypothesis file is flushed and closed even if a
    # transcription raises midway. utf-8 is explicit because hypothesis text
    # may be non-ASCII and the platform default encoding may not encode it.
    with open(args.hypo, "wt", encoding="utf-8") as wf:
        wavlist = parse_wavlist(args.wavlist)
        for wav in wavlist:
            batch_wav = [wav]
            result, wav_durations, transcribe_durations = model.transcribe(
                batch_wav, args.beam_size, args.nbest
            )
            # wav_durations is summed, so presumably it holds one entry per
            # input wav (a single entry here, since the batch has one file).
            wav_durations = sum(wav_durations)
            total_wav_durations += wav_durations
            total_transcribe_durations += transcribe_durations
            logger.info(f"{batch_wav}")
            logger.info(f"Durations: {wav_durations}")
            logger.info(f"Transcribe Durations: {transcribe_durations}")
            if wav_durations > 0:
                rtf = transcribe_durations / wav_durations
                logger.info(f"(Real time factor) RTF: {rtf}")
            else:
                # Zero-length audio previously raised ZeroDivisionError and
                # aborted the whole run.
                logger.warning("(Real time factor) RTF: N/A (zero audio duration)")
            text = result["text"]
            logger.info(f"text: {text}")
            logger.info("")
            wf.write(f"{text}\n")
    logger.info(f"total wav durations: {total_wav_durations}")
    logger.info(f"total transcribe durations: {total_transcribe_durations}")
    if total_wav_durations > 0:
        avg_rtf = total_transcribe_durations / total_wav_durations
        logger.info(f"AVG RTF: {avg_rtf}")
    else:
        # Nothing was transcribed (e.g. an empty wavlist, or every listed
        # path missing), so there is no average to report.
        logger.warning("AVG RTF: N/A (no audio transcribed)")
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()