|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Deep speech decoder."""
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import itertools
|
|
|
| from nltk.metrics import distance
|
| import numpy as np
|
|
|
|
|
class DeepSpeechDecoder(object):
    """Greedy decoder implementation for Deep Speech model.

    Turns a matrix of per-row label scores into text by taking the
    highest-scoring label of every row, collapsing consecutive repeats and
    dropping the blank label. Also provides edit-distance helpers for
    comparing a decoded sentence against its ground truth.
    """

    def __init__(self, labels, blank_index=28):
        """Decoder initialization.

        Args:
          labels: a string specifying the speech labels for the decoder to
            use; the character at position i is the label for index i.
          blank_index: an integer specifying index for the blank character.
            Defaults to 28.
        """
        self.labels = labels
        self.blank_index = blank_index
        # Lookup table from label index to its character.
        self.int_to_char = dict(enumerate(labels))

    def convert_to_string(self, sequence):
        """Convert a sequence of indexes into corresponding string.

        Args:
          sequence: an iterable of label indexes.

        Returns:
          The string formed by concatenating the label of every index.

        Raises:
          KeyError: if an index has no entry in `int_to_char`.
        """
        return ''.join([self.int_to_char[i] for i in sequence])

    def wer(self, decode, target):
        """Computes the word-level edit distance used for Word Error Rate.

        Both sentences are tokenized on whitespace and compared word by word.
        The value returned is the raw edit distance; it is NOT divided by the
        number of words in the target.

        Args:
          decode: string of the decoded output.
          target: a string for the ground truth label.

        Returns:
          The edit distance between the two word sequences.
        """
        decode_words = decode.split()
        target_words = target.split()

        # Give every distinct word its own single character, so that the
        # character-level edit distance below counts whole-word edits.
        vocabulary = set(decode_words + target_words)
        word2char = {word: chr(i) for i, word in enumerate(vocabulary)}

        new_decode = ''.join([word2char[w] for w in decode_words])
        new_target = ''.join([word2char[w] for w in target_words])

        return distance.edit_distance(new_decode, new_target)

    def cer(self, decode, target):
        """Computes the character-level edit distance used for CER.

        The value returned is the raw edit distance between the two strings;
        it is NOT divided by the length of the target.

        Args:
          decode: a string of the decoded output.
          target: a string for the ground truth label.

        Returns:
          The edit distance between decode and target.
        """
        return distance.edit_distance(decode, target)

    def decode(self, logits):
        """Decode the best guess from logits using greedy algorithm.

        Args:
          logits: a 2-D array-like of scores; the argmax is taken along
            axis 1, so each row yields one label index (presumably rows are
            time steps and columns are labels -- confirm against the caller).

        Returns:
          The decoded string, with consecutive repeated labels merged and
          blank labels removed.
        """
        # Index of the highest-scoring label for every row.
        best = np.argmax(logits, axis=1)

        # Merge runs of the same index first, THEN drop blanks: a blank
        # between two identical labels keeps them as separate characters.
        merged = [index for index, _ in itertools.groupby(best)]
        without_blank = [
            index for index in merged if index != self.blank_index
        ]

        return self.convert_to_string(without_blank)
|
|
|