| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| This script is based on speech_to_text_eval.py and allows you to score the hypotheses |
| with sclite. A local installation from https://github.com/usnistgov/SCTK is required. |
| Hypotheses and references are first saved in trn format and are scored after applying a glm |
| file (if provided). |
| |
| # Usage |
| |
| python speech_to_text_sclite.py \ |
| --asr_model="<Path to ASR Model>" \ |
| --dataset="<Path to manifest file>" \ |
| --out_dir="<Path to output dir, should be unique per model evaluated>" \ |
| --sctk_dir="<Path to root directory where SCTK is installed>" \ |
| --glm="<OPTIONAL: Path to glm file>" \ |
| --batch_size=4 |
| |
| """ |
|
|
| import errno |
| import json |
| import os |
| import subprocess |
| from argparse import ArgumentParser |
|
|
| import torch |
|
|
| from nemo.collections.asr.models import ASRModel |
| from nemo.collections.asr.parts.utils.manifest_utils import read_manifest |
| from nemo.utils import logging |
|
|
# Prefer the real AMP autocast when this torch build provides it; otherwise
# install a no-op stand-in so call sites can use `with autocast():` unconditionally.
try:
    from torch.cuda.amp import autocast
except ImportError:
    from contextlib import contextmanager

    @contextmanager
    def autocast(enabled=None):
        # No-op substitute for torch.cuda.amp.autocast; `enabled` is accepted
        # only for signature compatibility and is ignored.
        yield
|
|
|
|
def score_with_sctk(sctk_dir, ref_fname, hyp_fname, out_dir, glm=""):
    """Score hypothesis transcripts against references using NIST sclite.

    Args:
        sctk_dir: Root of a local SCTK installation (must contain bin/sclite).
        ref_fname: Path to the reference transcripts in trn format.
        hyp_fname: Path to the hypothesis transcripts in trn format.
        out_dir: Directory where GLM-filtered copies of the trn files are written.
        glm: Optional path to a GLM filter file. When it exists, bin/rfilter1
            is applied to both trn files before scoring.

    Raises:
        FileNotFoundError: If sclite (or rfilter1 when a glm is given) is missing.
        subprocess.CalledProcessError: If rfilter1 or sclite exits non-zero.
    """
    sclite_path = os.path.join(sctk_dir, "bin", "sclite")
    if not os.path.exists(sclite_path):
        raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), sclite_path)

    # GLM filtering normalizes transcript variants before scoring; skipped when
    # no (existing) glm file is supplied.
    if os.path.exists(glm):
        rfilter_path = os.path.join(sctk_dir, "bin", "rfilter1")
        if not os.path.exists(rfilter_path):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), rfilter_path)
        rfilt_cmd = [rfilter_path, glm]
        hypglm = os.path.join(out_dir, os.path.basename(hyp_fname)) + ".glm"
        with open(hypglm, "w", encoding='utf-8') as hypf, open(hyp_fname, "r", encoding='utf-8') as hyp_in:
            # check=True: a silent rfilter1 failure would otherwise produce an
            # empty .glm file and corrupt the score.
            subprocess.run(rfilt_cmd, stdin=hyp_in, stdout=hypf, check=True)
        refglm = os.path.join(out_dir, os.path.basename(ref_fname)) + ".glm"
        with open(refglm, "w", encoding='utf-8') as reff, open(ref_fname, "r", encoding='utf-8') as ref_in:
            subprocess.run(rfilt_cmd, stdin=ref_in, stdout=reff, check=True)
    else:
        refglm = ref_fname
        hypglm = hyp_fname

    # Argument-list invocation (shell=False) avoids shell injection through file
    # paths, and check=True surfaces a non-zero sclite exit status instead of
    # discarding it.
    subprocess.run([sclite_path, "-h", hypglm, "-r", refglm, "-i", "wsj", "-o", "all"], check=True)
|
|
|
|
# Cache CUDA availability once at import time; main() uses it to decide
# whether to move the model to the GPU.
can_gpu = torch.cuda.is_available()
|
|
|
|
def get_utt_info(manifest_path):
    """Load all utterance entries from a JSON-lines manifest.

    Each line of the manifest is one JSON object describing an utterance;
    the parsed dicts are returned in file order.
    """
    with open(manifest_path, "r", encoding='utf-8') as utt_f:
        return [json.loads(line) for line in utt_f]
|
|
|
|
def main():
    """Transcribe a manifest with a NeMo ASR model and emit trn files for sclite.

    Writes hyp.trn / ref.trn into --out_dir and, when --sctk_dir points at a
    local SCTK installation, scores them with sclite (optionally after GLM
    filtering via --glm).
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model", type=str, default="QuartzNet15x5Base-En", required=False, help="Pass: 'QuartzNet15x5Base-En'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--out_dir", type=str, required=True, help="Destination dir for output files")
    parser.add_argument("--sctk_dir", type=str, required=False, default="", help="Path to sctk root dir")
    parser.add_argument("--glm", type=str, required=False, default="", help="Path to glm file")
    args = parser.parse_args()

    # Inference only: no gradients are needed anywhere below.
    torch.set_grad_enabled(False)

    # exist_ok=True already tolerates an existing dir; a pre-existence check is redundant.
    os.makedirs(args.out_dir, exist_ok=True)

    use_sctk = os.path.exists(args.sctk_dir)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = ASRModel.restore_from(restore_path=args.asr_model, map_location='cpu')
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = ASRModel.from_pretrained(model_name=args.asr_model, map_location='cpu')

    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Read the manifest once and reuse it for references, audio paths and utt ids
    # (the original re-read the same file through get_utt_info).
    manifest_data = read_manifest(args.dataset)
    references = [data['text'] for data in manifest_data]
    audio_filepaths = [data['audio_filepath'] for data in manifest_data]

    with autocast():
        hypotheses = asr_model.transcribe(audio_filepaths, batch_size=args.batch_size)

    # Some decoders return a (best_hypotheses, all_hypotheses) pair; keep best only.
    if isinstance(hypotheses, tuple) and len(hypotheses) == 2:
        hypotheses = hypotheses[0]

    hypfile = os.path.join(args.out_dir, "hyp.trn")
    reffile = os.path.join(args.out_dir, "ref.trn")
    with open(hypfile, "w", encoding='utf-8') as hyp_f, open(reffile, "w", encoding='utf-8') as ref_f:
        for hyp, ref, data in zip(hypotheses, references, manifest_data):
            # trn format: "<tokens> (<utterance id>)", one utterance per line;
            # the utterance id is the audio file name without its extension.
            utt_id = os.path.splitext(os.path.basename(data['audio_filepath']))[0]
            hyp_f.write(f" {hyp} ({utt_id})\n")
            ref_f.write(f" {ref} ({utt_id})\n")

    if use_sctk:
        score_with_sctk(args.sctk_dir, reffile, hypfile, args.out_dir, glm=args.glm)
|
|
|
|
# Script entry point: run only when executed directly, not when imported.
if __name__ == '__main__':
    main()
|
|