| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import absolute_import, division, print_function, unicode_literals |
| |
|
| | import re |
| | from collections import deque |
| | from enum import Enum |
| |
|
| | import numpy as np |
| |
|
| |
|
| | """ |
| | Utility modules for computation of Word Error Rate, |
| | Alignments, as well as more granular metrics like |
| | deletion, insersion and substitutions. |
| | """ |
| |
|
| |
|
# Edit-operation labels produced when aligning a hypothesis against a
# reference. Functional Enum API; auto-assigned values match the original
# explicit 1-4 numbering.
Code = Enum("Code", ["match", "substitution", "insertion", "deletion"])
| |
|
| |
|
class Token(object):
    """A single word token with optional start/end timestamps.

    When ``st`` is NaN (i.e. no timing information was supplied), the
    timestamps default to 0.0 but the label is preserved.
    """

    def __init__(self, lbl="", st=np.nan, en=np.nan):
        # Bug fix: the original discarded ``lbl`` whenever ``st`` was NaN,
        # so ``Token("word")`` produced an empty label. Keep the label and
        # only normalize the missing timestamps to 0.0.
        self.label = lbl
        if np.isnan(st):
            self.start, self.end = 0.0, 0.0
        else:
            self.start, self.end = st, en
| |
|
| |
|
class AlignmentResult(object):
    """Container for the outcome of an edit-distance alignment.

    Holds the aligned reference indices, hypothesis indices, the per-step
    edit codes, and the total alignment score.
    """

    def __init__(self, refs, hyps, codes, score):
        self.refs, self.hyps = refs, hyps
        self.codes, self.score = codes, score
| |
|
| |
|
def coordinate_to_offset(row, col, ncols):
    """Flatten a (row, col) cell of a row-major matrix into a scalar offset."""
    flat = row * ncols + col
    return int(flat)
| |
|
| |
|
def offset_to_row(offset, ncols):
    """Recover the row index from a flattened row-major offset.

    Uses integer floor division instead of ``int(offset / ncols)``: true
    division routes through a float, which loses precision for offsets
    beyond 2**53. ``int(offset)`` keeps the int return type even when the
    offset arrives as a (whole-valued) numpy float from the backtrace matrix.
    """
    return int(offset) // ncols
| |
|
| |
|
def offset_to_col(offset, ncols):
    """Recover the column index from a flattened row-major offset."""
    col = offset % ncols
    return int(col)
| |
|
| |
|
def trimWhitespace(str):
    """Collapse runs of spaces to one and strip leading/trailing spaces.

    Only ASCII spaces are touched (tabs/newlines are left alone), matching
    the original three-pass regex implementation, but done in a single
    substitution plus a strip. NOTE: the parameter name shadows the builtin
    ``str``; kept for backward compatibility with keyword callers.
    """
    return re.sub(" +", " ", str).strip(" ")
| |
|
| |
|
def str2toks(str):
    """Tokenize a transcript string into Tokens with zeroed timestamps."""
    words = trimWhitespace(str).split(" ")
    return [Token(word, 0.0, 0.0) for word in words]
| |
|
| |
|
class EditDistance(object):
    """Levenshtein-style alignment between reference and hypothesis tokens.

    When ``time_mediated`` is True, edit costs are derived from the tokens'
    start/end timestamps; otherwise classic integer costs are used
    (0 for match, 3 for insertion/deletion, 4 for substitution).
    """

    def __init__(self, time_mediated):
        self.time_mediated_ = time_mediated
        # DP score matrix; nan until align() fills it.
        self.scores_ = np.nan
        # Backpointer matrix of flattened (row-major) offsets; nan until align().
        self.backtraces_ = (
            np.nan
        )
        # "ref -> hyp" substitution counts, accumulated across get_result() calls.
        self.confusion_pairs_ = {}

    def cost(self, ref, hyp, code):
        """Return the edit cost of pairing ``ref`` with ``hyp`` under ``code``.

        ``ref`` may be None for insertions, ``hyp`` may be None for deletions.
        """
        if self.time_mediated_:
            if code == Code.match:
                # Penalty is the timing mismatch between identical labels.
                return abs(ref.start - hyp.start) + abs(ref.end - hyp.end)
            elif code == Code.insertion:
                return hyp.end - hyp.start
            elif code == Code.deletion:
                return ref.end - ref.start
            else:  # substitution: timing mismatch plus a fixed 0.1 penalty
                return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1
        else:
            if code == Code.match:
                return 0
            elif code == Code.insertion or code == Code.deletion:
                return 3
            else:  # substitution
                return 4

    def get_result(self, refs, hyps):
        """Backtrace the filled DP matrices into an AlignmentResult.

        Must be called after align() has populated scores_/backtraces_.
        Also updates confusion_pairs_ for every substitution encountered.
        """
        res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan)

        num_rows, num_cols = self.scores_.shape
        res.score = self.scores_[num_rows - 1, num_cols - 1]

        # Walk backpointers from the bottom-right cell until the origin
        # (offset 0) is reached.
        curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)

        while curr_offset != 0:
            curr_row = offset_to_row(curr_offset, num_cols)
            curr_col = offset_to_col(curr_offset, num_cols)

            prev_offset = self.backtraces_[curr_row, curr_col]

            prev_row = offset_to_row(prev_offset, num_cols)
            prev_col = offset_to_col(prev_offset, num_cols)

            # Token indices are matrix indices minus 1 (row/col 0 are the
            # empty-prefix cells).
            res.refs.appendleft(curr_row - 1)
            res.hyps.appendleft(curr_col - 1)
            if curr_row - 1 == prev_row and curr_col == prev_col:
                # Moved up one row: a reference token was consumed alone.
                res.codes.appendleft(Code.deletion)
            elif curr_row == prev_row and curr_col - 1 == prev_col:
                # Moved left one column: a hypothesis token was consumed alone.
                res.codes.appendleft(Code.insertion)
            else:
                # Diagonal move: match or substitution.
                ref_str = refs[res.refs[0]].label
                hyp_str = hyps[res.hyps[0]].label

                if ref_str == hyp_str:
                    res.codes.appendleft(Code.match)
                else:
                    res.codes.appendleft(Code.substitution)

                    confusion_pair = "%s -> %s" % (ref_str, hyp_str)
                    if confusion_pair not in self.confusion_pairs_:
                        self.confusion_pairs_[confusion_pair] = 1
                    else:
                        self.confusion_pairs_[confusion_pair] += 1

            curr_offset = prev_offset

        return res

    def align(self, refs, hyps):
        """Fill the DP matrices for ``refs`` vs ``hyps`` and return the result.

        Returns np.nan when both sequences are empty.
        """
        if len(refs) == 0 and len(hyps) == 0:
            return np.nan

        # Matrices have one extra row/column for the empty prefix.
        # NOTE(review): backtraces_ is a float matrix; offsets stored in it
        # are read back through offset_to_row/col, which coerce to int.
        self.scores_ = np.zeros((len(refs) + 1, len(hyps) + 1))
        self.backtraces_ = np.zeros((len(refs) + 1, len(hyps) + 1))

        num_rows, num_cols = self.scores_.shape

        for i in range(num_rows):
            for j in range(num_cols):
                if i == 0 and j == 0:
                    self.scores_[i, j] = 0.0
                    self.backtraces_[i, j] = 0
                    continue

                if i == 0:
                    # First row: hypothesis-only prefix -> pure insertions.
                    self.scores_[i, j] = self.scores_[i, j - 1] + self.cost(
                        None, hyps[j - 1], Code.insertion
                    )
                    self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols)
                    continue

                if j == 0:
                    # First column: reference-only prefix -> pure deletions.
                    self.scores_[i, j] = self.scores_[i - 1, j] + self.cost(
                        refs[i - 1], None, Code.deletion
                    )
                    self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols)
                    continue

                # Interior cell: best of diagonal (match/substitution),
                # left (insertion), and up (deletion).
                ref = refs[i - 1]
                hyp = hyps[j - 1]
                best_score = self.scores_[i - 1, j - 1] + (
                    self.cost(ref, hyp, Code.match)
                    if (ref.label == hyp.label)
                    else self.cost(ref, hyp, Code.substitution)
                )

                prev_row = i - 1
                prev_col = j - 1
                ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion)
                if ins < best_score:
                    best_score = ins
                    prev_row = i
                    prev_col = j - 1

                delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion)
                if delt < best_score:
                    best_score = delt
                    prev_row = i - 1
                    prev_col = j

                self.scores_[i, j] = best_score
                self.backtraces_[i, j] = coordinate_to_offset(
                    prev_row, prev_col, num_cols
                )

        return self.get_result(refs, hyps)
| |
|
| |
|
class WERTransformer(object):
    """Computes WER and error breakdowns for a hypothesis/reference pair.

    Constructing an instance immediately runs the alignment via process();
    results are then available through wer(), stats() and report_result().
    """

    def __init__(self, hyp_str, ref_str, verbose=True):
        self.ed_ = EditDistance(False)  # plain (non-time-mediated) edit costs
        # Per-utterance-id best ("oracle") [ins, dels, subs] counts.
        self.id2oracle_errs_ = {}
        self.utts_ = 0
        self.words_ = 0
        self.insertions_ = 0
        self.deletions_ = 0
        self.substitutions_ = 0

        self.process(["dummy_str", hyp_str, ref_str])

        if verbose:
            print("'%s' vs '%s'" % (hyp_str, ref_str))
            self.report_result()

    def process(self, input):
        """Align one row of the form [id, ..., hyp, ref] and tally errors.

        Returns 0 on success, np.nan on failure, None on malformed input.
        NOTE: appends the per-row count strings to ``input`` as a side effect.
        """
        if len(input) < 3:
            print(
                "Input must be of the form <id> ... <hypo> <ref> , got ",
                len(input),
                " inputs:",
            )
            return None

        # Hypothesis and reference transcripts are the last two fields.
        hyps = str2toks(input[-2])
        refs = str2toks(input[-1])

        alignment = self.ed_.align(refs, hyps)
        if alignment is None:
            print("Alignment is null")
            return np.nan

        # Tally raw error counts from the alignment codes.
        ins = 0
        dels = 0
        subs = 0
        for code in alignment.codes:
            if code == Code.substitution:
                subs += 1
            elif code == Code.insertion:
                ins += 1
            elif code == Code.deletion:
                dels += 1

        # Append counts to the input row (kept for callers that inspect the
        # row after processing; ``row`` aliases ``input``).
        row = input
        row.append(str(len(refs)))
        row.append(str(ins))
        row.append(str(dels))
        row.append(str(subs))

        # The id field may encode an n-best index as "<id>/<n>"; oracle
        # bookkeeping is keyed on the part before the separator.
        kIdIndex = 0
        kNBestSep = "/"

        pieces = input[kIdIndex].split(kNBestSep)

        if len(pieces) == 0:
            print(
                "Error splitting ",
                input[kIdIndex],
                " on '",
                kNBestSep,
                "', got empty list",
            )
            return np.nan

        id = pieces[0]
        if id not in self.id2oracle_errs_:
            # First hypothesis seen for this utterance: count it in totals.
            self.utts_ += 1
            self.words_ += len(refs)
            self.insertions_ += ins
            self.deletions_ += dels
            self.substitutions_ += subs
            self.id2oracle_errs_[id] = [ins, dels, subs]
        else:
            # Repeated id (n-best): keep the lowest-error (oracle) counts.
            curr_err = ins + dels + subs
            prev_err = np.sum(self.id2oracle_errs_[id])
            if curr_err < prev_err:
                self.id2oracle_errs_[id] = [ins, dels, subs]

        return 0

    def report_result(self):
        """Print a one-line WER summary; no-op when no words were counted."""
        if self.words_ == 0:
            print("No words counted")
            return

        # Overall WER as a percentage of reference words.
        best_wer = (
            100.0
            * (self.insertions_ + self.deletions_ + self.substitutions_)
            / self.words_
        )

        print(
            "\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, "
            "%0.2f%% dels, %0.2f%% subs)"
            % (
                best_wer,
                self.utts_,
                self.words_,
                100.0 * self.insertions_ / self.words_,
                100.0 * self.deletions_ / self.words_,
                100.0 * self.substitutions_ / self.words_,
            )
        )

    def wer(self):
        """Return WER as a percentage, or np.nan if no words were counted."""
        if self.words_ == 0:
            wer = np.nan
        else:
            wer = (
                100.0
                * (self.insertions_ + self.deletions_ + self.substitutions_)
                / self.words_
            )
        return wer

    def stats(self):
        """Return a dict of WER statistics, or {} if no words were counted."""
        if self.words_ == 0:
            stats = {}
        else:
            wer = (
                100.0
                * (self.insertions_ + self.deletions_ + self.substitutions_)
                / self.words_
            )
            stats = dict(
                {
                    "wer": wer,
                    "utts": self.utts_,
                    "numwords": self.words_,
                    "ins": self.insertions_,
                    "dels": self.deletions_,
                    "subs": self.substitutions_,
                    "confusion_pairs": self.ed_.confusion_pairs_,
                }
            )
        return stats
| |
|
| |
|
def calc_wer(hyp_str, ref_str):
    """Return the word error rate (%) of ``hyp_str`` against ``ref_str``."""
    transformer = WERTransformer(hyp_str, ref_str, verbose=0)
    return transformer.wer()
| |
|
| |
|
def calc_wer_stats(hyp_str, ref_str):
    """Return a dict of WER statistics for ``hyp_str`` against ``ref_str``."""
    transformer = WERTransformer(hyp_str, ref_str, verbose=0)
    return transformer.stats()
| |
|
| |
|
def get_wer_alignment_codes(hyp_str, ref_str):
    """
    INPUT: hypothesis string, reference string
    OUTPUT: List of alignment codes (intermediate results from WER computation)
    """
    transformer = WERTransformer(hyp_str, ref_str, verbose=0)
    alignment = transformer.ed_.align(str2toks(ref_str), str2toks(hyp_str))
    return alignment.codes
| |
|
| |
|
def merge_counts(x, y):
    """Merge count dict ``y`` into ``x`` in place and return ``x``."""
    for key, count in y.items():
        x[key] = x.get(key, 0) + count
    return x
| |
|