| import argparse | |
| import os | |
| import time | |
| import logging | |
| from fireredasr_axmodel import FireRedASRAxModel | |
# Module logging setup: attach a stream handler (stderr by default) to the
# root logger and let INFO-and-above records through both the logger and
# the handler. `main` logs its per-file and summary metrics through `logger`.
logger_stream_hander = logging.StreamHandler()
logger_stream_hander.setLevel(logging.INFO)
logger = logging.getLogger()
logger.addHandler(logger_stream_hander)
logger.setLevel("INFO")
def parse_args(argv=None):
    """Parse command-line options for the FireRedASRAxModel test script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing no-argument callers behave exactly as before.

    Returns:
        argparse.Namespace holding the model file paths (encoder, decoder
        loop, CMVN, dictionary, SentencePiece model), the input wav-list and
        output hypothesis file paths, and the decoding options.
    """
    parser = argparse.ArgumentParser(description="FireRedASRAxModel Test")
    parser.add_argument(
        "--encoder",
        type=str,
        default="axmodel/encoder.axmodel",
        help="Path to axmodel encoder",
    )
    parser.add_argument(
        "--decoder_loop",
        type=str,
        default="axmodel/decoder_loop.axmodel",
        help="Path to axmodel decoder loop",
    )
    parser.add_argument(
        "--cmvn", type=str, default="axmodel/cmvn.ark", help="Path to cmvn"
    )
    parser.add_argument(
        "--dict", type=str, default="axmodel/dict.txt", help="Path to dict"
    )
    parser.add_argument(
        "--spm_model",
        type=str,
        default="axmodel/train_bpe1000.model",
        help="Path to spm model",
    )
    parser.add_argument(
        "--wavlist", type=str, default="wavlist.txt", help="File to wav path list"
    )
    parser.add_argument(
        "--hypo", type=str, default="hypo_axmodel.txt", help="File of hypos"
    )
    # These two previously had empty help strings; both are forwarded
    # unchanged to model.transcribe().
    parser.add_argument(
        "--beam_size",
        type=int,
        default=1,
        help="Beam size used for decoding (passed to model.transcribe)",
    )
    parser.add_argument(
        "--nbest",
        type=int,
        default=1,
        help="N-best count used for decoding (passed to model.transcribe)",
    )
    parser.add_argument(
        "--decode_max_len", type=int, default=128, help="max token len"
    )
    parser.add_argument("--max_dur", type=int, default=10, help="max audio len")
    return parser.parse_args(argv)
def parse_wavlist(wavlist: str):
    """Read a text file that lists one audio path per line.

    Surrounding whitespace is stripped from every line. Blank lines are
    skipped silently. A path that does not exist on disk is reported on
    stdout and left out of the result.

    Args:
        wavlist: Path to the list file (read as UTF-8).

    Returns:
        list of str: the existing paths, in file order.
    """
    wavpaths = []
    with open(wavlist, encoding="utf-8") as f:
        for line in f:
            path = line.strip()
            if not path:
                # Blank/whitespace-only line: previously this fell through to
                # os.path.exists("") and printed a bogus " doesn't exist.".
                continue
            if not os.path.exists(path):
                print(f"{path} doesn't exist.")
                continue
            wavpaths.append(path)
    return wavpaths
def main():
    """Transcribe every file listed in ``--wavlist`` and report RTF.

    Each input is transcribed as a single-item batch. Its ``result["text"]``
    is written, one line per input and in list order, to the ``--hypo``
    file. Per-file and aggregate audio duration, transcription duration and
    real-time factor (transcribe duration / audio duration) are logged at
    INFO level.

    When the audio duration is zero, RTF is reported as ``nan`` instead of
    raising ZeroDivisionError. This covers a zero-length clip and an empty
    or fully-missing wav list.
    """
    args = parse_args()
    print(args)
    model = FireRedASRAxModel(
        args.encoder,
        args.decoder_loop,
        args.cmvn,
        args.dict,
        args.spm_model,
        decode_max_len=args.decode_max_len,
        audio_dur=args.max_dur,
    )
    total_wav_durations = 0
    total_transcribe_durations = 0
    # The context manager guarantees the hypothesis file is flushed and
    # closed even if transcription raises midway. UTF-8 is explicit so the
    # output does not depend on the platform's locale encoding.
    with open(args.hypo, "wt", encoding="utf-8") as wf:
        wavlist = parse_wavlist(args.wavlist)
        for wav in wavlist:
            batch_wav = [wav]
            result, wav_durations, transcribe_durations = model.transcribe(
                batch_wav, args.beam_size, args.nbest
            )
            # wav_durations is summed, so presumably it is per-item for the
            # batch -- TODO confirm against FireRedASRAxModel.transcribe.
            wav_durations = sum(wav_durations)
            total_wav_durations += wav_durations
            total_transcribe_durations += transcribe_durations
            logger.info(f"{batch_wav}")
            logger.info(f"Durations: {wav_durations}")
            logger.info(f"Transcribe Durations: {transcribe_durations}")
            rtf = (
                transcribe_durations / wav_durations
                if wav_durations
                else float("nan")
            )
            logger.info(f"(Real time factor) RTF: {rtf}")
            text = result["text"]
            logger.info(f"text: {text}")
            logger.info("")
            wf.write(f"{text}\n")
    logger.info(f"total wav durations: {total_wav_durations}")
    logger.info(f"total transcribe durations: {total_transcribe_durations}")
    avg_rtf = (
        total_transcribe_durations / total_wav_durations
        if total_wav_durations
        else float("nan")
    )
    logger.info(f"AVG RTF: {avg_rtf}")


if __name__ == "__main__":
    main()