| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import json |
|
|
| import git |
| from omegaconf import OmegaConf |
| from utils import cal_target_metadata_wer, cal_write_wer, run_asr_inference |
|
|
| from nemo.core.config import hydra_runner |
| from nemo.utils import logging |
|
|
|
|
"""
This script serves as an evaluator of ASR models.

Usage:
    python asr_evaluator.py \
        engine.pretrained_name="stt_en_conformer_transducer_large" \
        engine.inference.mode="offline" \
        engine.test_ds.augmentor.noise.manifest_path=<manifest file for noise data> \
        ...  (any further config overrides)

See ./conf/eval.yaml for all available parameters.
"""
|
|
|
|
@hydra_runner(config_path="conf", config_name="eval.yaml")
def main(cfg):
    """Run ASR inference, score the output and append a JSON report line.

    The run is organised in three stages:

    1. Engine: ``run_asr_inference`` is called on ``cfg.engine`` and the config
       it returns replaces ``cfg.engine``.
    2. Analyst: ``cal_write_wer`` returns an updated ``cfg``, the overall
       results ``total_res`` (indexed by metric name) and the name of the metric
       used (``eval_metric``). For every entry of ``cfg.analyst.metadata`` whose
       ``enable`` flag is set, ``cal_target_metadata_wer`` adds a breakdown
       under that entry's name.
    3. Writer: the report is extended with the engine and metric-calculator
       configs, logged, and appended as a single JSON line to
       ``cfg.writer.report_filename`` (``report.json`` when unset or empty).

    Args:
        cfg: OmegaConf config supplied by ``hydra_runner`` from
            ``conf/eval.yaml`` (presumably merged with CLI overrides, as in the
            module usage example — confirm against ``hydra_runner``).
    """
    report = {}
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if cfg.env.save_git_hash:
        # Record the code revision that produced these numbers.
        repo = git.Repo(search_parent_directories=True)
        report['git_hash'] = repo.head.object.hexsha

    # Engine stage: the returned config replaces the original engine config.
    cfg.engine = run_asr_inference(cfg=cfg.engine)

    # Analyst stage: overall results first, then optional per-metadata breakdowns.
    cfg, total_res, eval_metric = cal_write_wer(cfg)
    report["res"] = total_res

    metadata = cfg.analyst.metadata
    for target in metadata:
        meta_cfg = metadata[target]
        if not meta_cfg.enable:
            continue
        report[target] = cal_target_metadata_wer(
            manifest=cfg.analyst.metric_calculator.output_filename,
            target=target,
            meta_cfg=meta_cfg,
            eval_metric=eval_metric,
        )

    # Flatten the resolved configs into the report so each JSON line is
    # self-describing. NOTE: config keys overwrite existing report keys on
    # collision (including "res" or a metadata target name).
    report.update(OmegaConf.to_object(cfg.engine))
    report.update(OmegaConf.to_object(cfg.analyst.metric_calculator))

    logging.info(json.dumps(report, indent=4))
    # Read the headline score from total_res rather than report["res"], which
    # may have been overwritten by the config merge above.
    overall = total_res[eval_metric] * 100
    logging.info(f"Overall {eval_metric} is {overall:.3f} %")

    report_file = "report.json"
    if "report_filename" in cfg.writer and cfg.writer.report_filename:
        report_file = cfg.writer.report_filename

    # Append mode: repeated runs accumulate one JSON object per line.
    # Closing the file on leaving the `with` block flushes it.
    with open(report_file, "a", encoding="utf-8") as fout:
        json.dump(report, fout)
        fout.write('\n')


if __name__ == "__main__":
    main()
|
|