Spaces:
Running
Running
| #!/usr/bin/env python3 -u | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Implement unsupervised metric for decoding hyperparameter selection: | |
| $$ alpha * LM_PPL + ViterbiUER(%) * 100 $$ | |
| """ | |
| import argparse | |
| import logging | |
| import sys | |
| import editdistance | |
# Emit log records on stdout at INFO and above. The root level is also set
# explicitly so it holds even when the root logger already has handlers, in
# which case basicConfig() leaves the existing configuration untouched.
# Note: the logger.debug() calls in this file are therefore hidden by default.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger(__name__)
def get_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: a parser with two required string options,
        ``-s/--hypo`` and ``-r/--reference``.
    """
    cli = argparse.ArgumentParser()
    # (option flags, help text) for each required option.
    required_options = (
        (("-s", "--hypo"), "hypo transcription"),
        (("-r", "--reference"), "reference transcription"),
    )
    for flags, help_text in required_options:
        cli.add_argument(*flags, help=help_text, required=True)
    return cli
def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p):
    """Compute the corpus-level error rate of hypotheses against references.

    For every utterance id in ``hyp_uid_to_tra`` the edit distance between the
    reference and hypothesis token sequences is accumulated, and the total is
    divided by the total number of reference tokens.

    Args:
        ref_uid_to_tra: mapping from utterance id to reference transcription
            string; it is split on whitespace. Must contain every id present
            in ``hyp_uid_to_tra``.
        hyp_uid_to_tra: mapping from utterance id to hypothesis transcription
            string.
        g2p: optional callable applied to each hypothesis string. Its result
            is iterated as a sequence of string tokens. When ``None`` the
            hypothesis is simply split on whitespace.

    Returns:
        float: summed edit distance divided by the summed reference length
        (a ratio, not a percentage).

    Raises:
        KeyError: if a hypothesis id is missing from ``ref_uid_to_tra``.
        ZeroDivisionError: if the scored references contain no tokens at all
            (e.g. ``hyp_uid_to_tra`` is empty), so the rate is undefined.
    """
    d_cnt = 0
    w_cnt = 0
    w_cnt_h = 0
    for uid in hyp_uid_to_tra:
        ref = ref_uid_to_tra[uid].split()
        if g2p is not None:
            hyp = g2p(hyp_uid_to_tra[uid])
            # Drop apostrophe and space tokens produced by g2p.
            hyp = [p for p in hyp if p != "'" and p != " "]
            # Strip one trailing numeric character from each token
            # (presumably a stress marker on phonemes -- confirm with g2p).
            hyp = [p[:-1] if p[-1].isnumeric() else p for p in hyp]
        else:
            hyp = hyp_uid_to_tra[uid].split()
        d_cnt += editdistance.eval(ref, hyp)
        w_cnt += len(ref)
        w_cnt_h += len(hyp)

    if w_cnt == 0:
        # Same exception type the bare division used to raise, but with a
        # message that explains what went wrong.
        raise ZeroDivisionError(
            "cannot compute WER: the scored references contain no tokens "
            f"(num. of hyp words = {w_cnt_h}; "
            f"num. of hypotheses = {len(hyp_uid_to_tra)})"
        )

    wer = float(d_cnt) / w_cnt
    logger.debug(
        (
            f"wer = {wer * 100:.2f}%; num. of ref words = {w_cnt}; "
            f"num. of hyp words = {w_cnt_h}; num. of sentences = {len(ref_uid_to_tra)}"
        )
    )
    return wer
def main():
    """Compute and log the unit error rate (UER) between two text files.

    The hypothesis (``--hypo``) and reference (``--reference``) files are read
    in lockstep, one utterance per line. Each line is split on whitespace, the
    edit distance between the two token sequences is accumulated, and the
    total is reported as a percentage of the total number of reference tokens.

    Raises:
        ZeroDivisionError: if no reference tokens were read (for example one
            of the files is empty), so the rate is undefined.
    """
    args = get_parser().parse_args()

    errs = 0
    count = 0
    with open(args.hypo, "r") as hf, open(args.reference, "r") as rf:
        # zip() stops at the shorter file: surplus lines in the longer one
        # are not scored.
        for h, r in zip(hf, rf):
            h = h.rstrip().split()
            r = r.rstrip().split()
            errs += editdistance.eval(r, h)
            count += len(r)

    if count == 0:
        # Same exception type the bare division used to raise, but with a
        # message that points at the offending inputs.
        raise ZeroDivisionError(
            "cannot compute UER: no reference tokens found in the lines "
            f"paired from {args.hypo} and {args.reference}"
        )

    logger.info(f"UER: {errs / count * 100:.2f}%")


# NOTE: this guard sits above load_tra(), so when the file is run as a script
# main() executes before load_tra is defined; main() does not use it.
if __name__ == "__main__":
    main()
def load_tra(tra_path):
    """Load a transcription file into a ``{uid: transcription}`` dict.

    Each non-blank line is ``<uid> <transcription>``: the first
    whitespace-delimited field is the utterance id and the remainder of the
    line is the transcription, stored as-is (including the trailing newline).

    Blank lines are skipped, and a line holding only a uid maps to an empty
    transcription instead of raising ``ValueError``. If a uid appears more
    than once, the last occurrence wins.

    Args:
        tra_path: path of the transcription file to read.

    Returns:
        dict: mapping from utterance id to transcription string.
    """
    uid_to_tra = {}
    with open(tra_path, "r") as f:
        for line in f:
            fields = line.split(None, 1)
            if not fields:
                # Blank or whitespace-only line: nothing to record.
                continue
            # A uid with no following text is an empty transcription.
            uid_to_tra[fields[0]] = fields[1] if len(fields) > 1 else ""
    logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}")
    return uid_to_tra