File size: 3,364 Bytes
f21b604 c6f1198 f21b604 5d4703e f21b604 c6f1198 f21b604 5d4703e f21b604 5d4703e f21b604 5d4703e c6f1198 5d4703e f21b604 5d4703e f21b604 5d4703e f21b604 5d4703e f21b604 90f0b29 5d4703e c6f1198 5d4703e 90f0b29 d56de90 5d4703e f21b604 5d4703e f21b604 5d4703e f21b604 5d4703e f21b604 5d4703e f21b604 5d4703e f21b604 5d4703e f21b604 5d4703e f21b604 5d4703e f21b604 5d4703e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import argparse
import os
import time
import logging
from fireredasr_axmodel import FireRedASRAxModel
# Root-logger setup: attach a stream handler and let INFO-level records
# through on both the handler and the logger itself.
logger = logging.getLogger()
logger_stream_hander = logging.StreamHandler()
logger_stream_hander.setLevel("INFO")
logger.addHandler(logger_stream_hander)
logger.setLevel(logging.INFO)
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            calling ``parse_args()`` behaves exactly as before.

    Returns:
        argparse.Namespace with the attributes ``encoder``,
        ``decoder_loop``, ``cmvn``, ``dict``, ``spm_model``, ``wavlist``,
        ``hypo``, ``beam_size``, ``nbest``, ``decode_max_len`` and
        ``max_dur``.
    """
    parser = argparse.ArgumentParser(description="FireRedASRAxModel Test")

    # Model and resource file locations.
    parser.add_argument(
        "--encoder",
        type=str,
        default="axmodel/encoder.axmodel",
        help="Path to axmodel encoder",
    )
    parser.add_argument(
        "--decoder_loop",
        type=str,
        default="axmodel/decoder_loop.axmodel",
        help="Path to axmodel decoder loop",
    )
    parser.add_argument(
        "--cmvn", type=str, default="axmodel/cmvn.ark", help="Path to cmvn"
    )
    parser.add_argument(
        "--dict", type=str, default="axmodel/dict.txt", help="Path to dict"
    )
    parser.add_argument(
        "--spm_model",
        type=str,
        default="axmodel/train_bpe1000.model",
        help="Path to spm model",
    )

    # Input list and output file.
    parser.add_argument(
        "--wavlist", type=str, default="wavlist.txt", help="File to wav path list"
    )
    parser.add_argument(
        "--hypo", type=str, default="hypo_axmodel.txt", help="File of hypos"
    )

    # Decoding options.
    parser.add_argument(
        "--beam_size",
        type=int,
        default=1,
        help="beam size passed to model.transcribe",
    )
    parser.add_argument(
        "--nbest",
        type=int,
        default=1,
        help="nbest value passed to model.transcribe",
    )
    parser.add_argument("--decode_max_len", type=int, default=128, help="max token len")
    parser.add_argument("--max_dur", type=int, default=10, help="max audio len")

    # argv=None makes argparse read sys.argv[1:], preserving the CLI behaviour.
    return parser.parse_args(argv)
def parse_wavlist(wavlist: str):
    """Read a wav list file and return the listed paths that exist.

    The file is read line by line; each line is stripped of surrounding
    whitespace and treated as one path. Blank lines are ignored. A path
    that does not exist is reported on stdout and left out of the result.

    Args:
        wavlist: Path to a text file with one wav path per line.

    Returns:
        List of existing paths, in file order.

    Raises:
        OSError: If ``wavlist`` itself cannot be opened.
    """
    wavpaths = []
    # Explicit UTF-8 so paths with non-ASCII characters decode the same way
    # regardless of the platform's locale default encoding.
    with open(wavlist, encoding="utf-8") as f:
        for line in f:
            path = line.strip()
            if not path:
                # A blank line is not a missing file; skip it quietly instead
                # of printing " doesn't exist.".
                continue
            if not os.path.exists(path):
                print(f"{path} doesn't exist.")
                continue
            wavpaths.append(path)
    return wavpaths
def main():
    """Transcribe every wav in the list and log real-time factors.

    Parses the command-line arguments, builds a ``FireRedASRAxModel``,
    and calls ``model.transcribe`` once per path returned by
    ``parse_wavlist``. Each recognised text is written as one line to the
    ``--hypo`` file, and per-file as well as aggregate durations and RTF
    (transcribe duration divided by wav duration) are logged.
    """
    args = parse_args()
    print(args)

    model = FireRedASRAxModel(
        args.encoder,
        args.decoder_loop,
        args.cmvn,
        args.dict,
        args.spm_model,
        decode_max_len=args.decode_max_len,
        audio_dur=args.max_dur,
    )

    total_wav_durations = 0
    total_transcribe_durations = 0

    # Context manager: the hypothesis file is closed even if transcription
    # raises. Explicit UTF-8 so non-ASCII transcripts can always be written,
    # independent of the platform's locale default encoding.
    with open(args.hypo, "wt", encoding="utf-8") as wf:
        wavlist = parse_wavlist(args.wavlist)
        for wav in wavlist:
            # transcribe() is called with a list; one file at a time here.
            batch_wav = [wav]
            result, wav_durations, transcribe_durations = model.transcribe(
                batch_wav, args.beam_size, args.nbest
            )
            wav_duration = sum(wav_durations)
            total_wav_durations += wav_duration
            total_transcribe_durations += transcribe_durations

            logger.info(f"{batch_wav}")
            logger.info(f"Durations: {wav_duration}")
            logger.info(f"Transcribe Durations: {transcribe_durations}")
            if wav_duration > 0:
                rtf = transcribe_durations / wav_duration
                logger.info(f"(Real time factor) RTF: {rtf}")
            else:
                # Avoid ZeroDivisionError on a zero-length audio duration.
                logger.warning(
                    "(Real time factor) RTF: undefined (wav duration is 0)"
                )

            text = result["text"]
            logger.info(f"text: {text}")
            logger.info("")
            wf.write(f"{text}\n")

    logger.info(f"total wav durations: {total_wav_durations}")
    logger.info(f"total transcribe durations: {total_transcribe_durations}")
    if total_wav_durations > 0:
        avg_rtf = total_transcribe_durations / total_wav_durations
        logger.info(f"AVG RTF: {avg_rtf}")
    else:
        # Empty wav list (or every listed path missing): nothing to average.
        logger.warning("AVG RTF: undefined (total wav duration is 0)")
# Run the transcription test only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|