Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from __future__ import absolute_import, division, print_function, unicode_literals | |
| import re | |
| from collections import deque | |
| from enum import Enum | |
| import numpy as np | |
| """ | |
| Utility modules for computation of Word Error Rate, | |
| Alignments, as well as more granular metrics like | |
| deletion, insertion and substitutions. | |
| """ | |
class Code(Enum):
    """Edit operation assigned to one step of a reference/hypothesis alignment."""

    # Reference and hypothesis labels are equal.
    match = 1
    # Reference and hypothesis labels differ.
    substitution = 2
    # A hypothesis token with no reference counterpart.
    insertion = 3
    # A reference token with no hypothesis counterpart.
    deletion = 4
class Token(object):
    """A labelled token with a start and end value.

    When ``st`` is NaN (the default) the token is treated as empty: the label
    is discarded and both times are set to 0.0, regardless of ``lbl``/``en``.
    Otherwise the three arguments are stored as given.
    """

    def __init__(self, lbl="", st=np.nan, en=np.nan):
        is_empty = np.isnan(st)
        self.label = "" if is_empty else lbl
        self.start = 0.0 if is_empty else st
        self.end = 0.0 if is_empty else en
class AlignmentResult(object):
    """Plain container for the outcome of one alignment.

    Attributes:
        refs: sequence of reference indices, one per alignment step.
        hyps: sequence of hypothesis indices, one per alignment step.
        codes: sequence of ``Code`` values, one per alignment step.
        score: total alignment cost.
    """

    def __init__(self, refs, hyps, codes, score):
        self.refs, self.hyps = refs, hyps
        self.codes, self.score = codes, score
def coordinate_to_offset(row, col, ncols):
    """Flatten a (row, col) position into a row-major offset for ``ncols`` columns."""
    flat = row * ncols + col
    return int(flat)
def offset_to_row(offset, ncols):
    """Return the row index of a row-major ``offset`` in a table with ``ncols`` columns.

    Floor division is used instead of true division so that integer offsets
    are never routed through a float, which would silently round once the
    values exceed 2**53. Offsets are expected to be non-negative.
    """
    return int(offset // ncols)
def offset_to_col(offset, ncols):
    """Return the column index of a row-major ``offset`` in a table with ``ncols`` columns."""
    remainder = offset % ncols
    return int(remainder)
def trimWhitespace(str):
    """Strip leading/trailing spaces and collapse runs of spaces to one.

    Only the space character is affected; tabs and newlines are left alone.
    """
    no_leading = re.sub("^ *", "", str)
    no_trailing = re.sub(" *$", "", no_leading)
    return re.sub(" +", " ", no_trailing)
def str2toks(str):
    """Split a string on single spaces (after trimming) into zero-timed Tokens."""
    return [Token(word, 0.0, 0.0) for word in trimWhitespace(str).split(" ")]
class EditDistance(object):
    """Dynamic-programming aligner between reference and hypothesis tokens.

    ``align`` fills a score table and a backtrace table of shape
    ``(len(refs) + 1, len(hyps) + 1)`` and then walks the backtraces to
    produce an ``AlignmentResult``. Substitutions seen while backtracking are
    tallied in ``confusion_pairs_`` (keyed ``"<ref> -> <hyp>"``); the tally
    accumulates across calls on the same instance.
    """

    def __init__(self, time_mediated):
        """
        Args:
            time_mediated: when truthy, step costs are derived from the
                tokens' ``start``/``end`` values; otherwise fixed costs are
                used (see ``cost``).
        """
        self.time_mediated_ = time_mediated
        # Placeholders; align() replaces both with freshly allocated tables.
        self.scores_ = np.nan
        self.backtraces_ = np.nan
        self.confusion_pairs_ = {}

    def cost(self, ref, hyp, code):
        """Return the cost of one alignment step.

        Args:
            ref: reference token (unused, may be None, for insertions).
            hyp: hypothesis token (unused, may be None, for deletions).
            code: the ``Code`` of the step; anything that is not match,
                insertion or deletion is priced as a substitution.
        """
        if not self.time_mediated_:
            if code == Code.match:
                return 0
            if code == Code.insertion or code == Code.deletion:
                return 3
            return 4  # substitution

        if code == Code.insertion:
            return hyp.end - hyp.start
        if code == Code.deletion:
            return ref.end - ref.start
        time_gap = abs(ref.start - hyp.start) + abs(ref.end - hyp.end)
        if code == Code.match:
            return time_gap
        return time_gap + 0.1  # substitution

    def get_result(self, refs, hyps):
        """Backtrack through the tables filled by ``align``.

        Starts at the bottom-right cell and follows the stored offsets until
        offset 0 is reached, prepending one entry per step. Each step records
        ``row - 1`` and ``col - 1`` of the cell being left, so an insertion
        on the first row (or a deletion on the first column) records -1 for
        the side that has no token.

        Returns:
            AlignmentResult with deques ``refs``, ``hyps``, ``codes`` and the
            total ``score`` taken from the bottom-right cell.
        """
        res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan)
        num_rows, num_cols = self.scores_.shape
        res.score = self.scores_[num_rows - 1, num_cols - 1]
        curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)

        while curr_offset != 0:
            curr_row = offset_to_row(curr_offset, num_cols)
            curr_col = offset_to_col(curr_offset, num_cols)
            prev_offset = int(self.backtraces_[curr_row, curr_col])
            prev_row = offset_to_row(prev_offset, num_cols)
            prev_col = offset_to_col(prev_offset, num_cols)

            res.refs.appendleft(curr_row - 1)
            res.hyps.appendleft(curr_col - 1)

            if curr_row - 1 == prev_row and curr_col == prev_col:
                res.codes.appendleft(Code.deletion)
            elif curr_row == prev_row and curr_col - 1 == prev_col:
                res.codes.appendleft(Code.insertion)
            else:
                # Diagonal move: both a reference and a hypothesis token
                # were consumed, so compare their labels.
                ref_str = refs[curr_row - 1].label
                hyp_str = hyps[curr_col - 1].label
                if ref_str == hyp_str:
                    res.codes.appendleft(Code.match)
                else:
                    res.codes.appendleft(Code.substitution)
                    confusion_pair = "%s -> %s" % (ref_str, hyp_str)
                    self.confusion_pairs_[confusion_pair] = (
                        self.confusion_pairs_.get(confusion_pair, 0) + 1
                    )
            curr_offset = prev_offset
        return res

    def align(self, refs, hyps):
        """Align ``refs`` against ``hyps`` and return the alignment.

        Args:
            refs: sequence of reference tokens.
            hyps: sequence of hypothesis tokens.

        Returns:
            ``np.nan`` when both sequences are empty, otherwise the
            ``AlignmentResult`` produced by ``get_result``.
        """
        if len(refs) == 0 and len(hyps) == 0:
            return np.nan

        # Fresh zero-filled tables for every call. Backtraces hold flat cell
        # offsets, so they are stored as integers rather than floats.
        shape = (len(refs) + 1, len(hyps) + 1)
        self.scores_ = np.zeros(shape)
        self.backtraces_ = np.zeros(shape, dtype=np.int64)
        num_rows, num_cols = shape

        for i in range(num_rows):
            for j in range(num_cols):
                if i == 0 and j == 0:
                    # Origin: score 0 and backtrace 0 are already in place.
                    continue
                if i == 0:
                    # First row: only insertions can lead here.
                    self.scores_[i, j] = self.scores_[i, j - 1] + self.cost(
                        None, hyps[j - 1], Code.insertion
                    )
                    self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols)
                    continue
                if j == 0:
                    # First column: only deletions can lead here.
                    self.scores_[i, j] = self.scores_[i - 1, j] + self.cost(
                        refs[i - 1], None, Code.deletion
                    )
                    self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols)
                    continue

                # Interior cell: both i and j are greater than 0.
                ref = refs[i - 1]
                hyp = hyps[j - 1]
                diagonal_code = (
                    Code.match if ref.label == hyp.label else Code.substitution
                )
                best_score = self.scores_[i - 1, j - 1] + self.cost(
                    ref, hyp, diagonal_code
                )
                prev_row, prev_col = i - 1, j - 1

                # Strict comparisons: on ties the diagonal wins over an
                # insertion, and an insertion wins over a deletion.
                ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion)
                if ins < best_score:
                    best_score = ins
                    prev_row, prev_col = i, j - 1

                delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion)
                if delt < best_score:
                    best_score = delt
                    prev_row, prev_col = i - 1, j

                self.scores_[i, j] = best_score
                self.backtraces_[i, j] = coordinate_to_offset(
                    prev_row, prev_col, num_cols
                )

        return self.get_result(refs, hyps)
class WERTransformer(object):
    """Compute word-error-rate statistics for a hypothesis/reference pair.

    The constructor aligns ``hyp_str`` against ``ref_str`` immediately and,
    when ``verbose`` is truthy, prints the pair and a one-line summary.
    Counts are kept on the instance:

    * ``utts_``, ``words_``, ``insertions_``, ``deletions_``,
      ``substitutions_`` - running totals, updated the first time an
      utterance id is seen.
    * ``id2oracle_errs_`` - per id, the ``[ins, dels, subs]`` triple with
      the fewest total errors seen so far.
    """

    def __init__(self, hyp_str, ref_str, verbose=True):
        self.ed_ = EditDistance(False)
        self.id2oracle_errs_ = {}
        self.utts_ = 0
        self.words_ = 0
        self.insertions_ = 0
        self.deletions_ = 0
        self.substitutions_ = 0

        self.process(["dummy_str", hyp_str, ref_str])

        if verbose:
            print("'%s' vs '%s'" % (hyp_str, ref_str))
            self.report_result()

    def process(self, input):
        """Align one ``[<id>, ..., <hypo>, <ref>]`` record and accumulate counts.

        The id is the part of ``input[0]`` before the first ``"/"``. The
        ``input`` list is not modified.

        Returns:
            ``None`` if ``input`` has fewer than three items, ``np.nan`` if
            the aligner produced no alignment, ``0`` on success.
        """
        if len(input) < 3:
            print(
                "Input must be of the form <id> ... <hypo> <ref> , got ",
                len(input),
                " inputs:",
            )
            return None

        # Align
        hyps = str2toks(input[-2])
        refs = str2toks(input[-1])
        alignment = self.ed_.align(refs, hyps)
        # The aligner reports "nothing to align" as NaN rather than None, so
        # both have to be treated as a missing alignment.
        if alignment is None or isinstance(alignment, float):
            print("Alignment is null")
            return np.nan

        # Tally errors
        ins = 0
        dels = 0
        subs = 0
        for code in alignment.codes:
            if code == Code.substitution:
                subs += 1
            elif code == Code.insertion:
                ins += 1
            elif code == Code.deletion:
                dels += 1

        # Accumulate. str.split with an explicit separator always yields at
        # least one piece, so the first element is always available.
        kIdIndex = 0
        kNBestSep = "/"
        utt_id = input[kIdIndex].split(kNBestSep)[0]

        if utt_id not in self.id2oracle_errs_:
            self.utts_ += 1
            self.words_ += len(refs)
            self.insertions_ += ins
            self.deletions_ += dels
            self.substitutions_ += subs
            self.id2oracle_errs_[utt_id] = [ins, dels, subs]
        elif ins + dels + subs < sum(self.id2oracle_errs_[utt_id]):
            # Same id seen again: keep only the lowest-error triple.
            self.id2oracle_errs_[utt_id] = [ins, dels, subs]
        return 0

    def _percent_of_words(self, count):
        """Return ``count`` as a percentage of ``words_`` (which must be non-zero)."""
        return 100.0 * count / self.words_

    def report_result(self):
        """Print a one-line WER summary, or a notice if no words were counted."""
        if self.words_ == 0:
            print("No words counted")
            return

        print(
            "\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, "
            "%0.2f%% dels, %0.2f%% subs)"
            % (
                self.wer(),
                self.utts_,
                self.words_,
                self._percent_of_words(self.insertions_),
                self._percent_of_words(self.deletions_),
                self._percent_of_words(self.substitutions_),
            )
        )

    def wer(self):
        """Return the WER as a percentage, or ``np.nan`` if no words were counted."""
        if self.words_ == 0:
            return np.nan
        return self._percent_of_words(
            self.insertions_ + self.deletions_ + self.substitutions_
        )

    def stats(self):
        """Return the accumulated counts as a dict (empty if no words were counted).

        ``"confusion_pairs"`` is the aligner's own dictionary, not a copy.
        """
        if self.words_ == 0:
            return {}
        return {
            "wer": self.wer(),
            "utts": self.utts_,
            "numwords": self.words_,
            "ins": self.insertions_,
            "dels": self.deletions_,
            "subs": self.substitutions_,
            "confusion_pairs": self.ed_.confusion_pairs_,
        }
def calc_wer(hyp_str, ref_str):
    """Return the word error rate (in percent) of ``hyp_str`` against ``ref_str``."""
    return WERTransformer(hyp_str, ref_str, verbose=0).wer()
def calc_wer_stats(hyp_str, ref_str):
    """Return the WER statistics dict of ``hyp_str`` against ``ref_str``."""
    return WERTransformer(hyp_str, ref_str, verbose=0).stats()
def get_wer_alignment_codes(hyp_str, ref_str):
    """
    INPUT: hypothesis string, reference string
    OUTPUT: List of alignment codes (intermediate results from WER computation)
    """
    # Align once with a plain (non-time-mediated) aligner. Building a full
    # WERTransformer first would run the same alignment a second time only
    # to throw its statistics away.
    aligner = EditDistance(False)
    return aligner.align(str2toks(ref_str), str2toks(hyp_str)).codes
def merge_counts(x, y):
    """Add the counts in ``y`` into ``x`` in place and return ``x``.

    Both arguments map keys to counts; keys missing from ``x`` start at 0.
    Useful for merging confusion-pair tallies, e.g.
    ``conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs'])``.
    """
    for key, count in y.items():
        x[key] = x.get(key, 0) + count
    return x