Spaces:
Build error
Build error
| from collections import defaultdict | |
| import torch | |
| import numpy as np | |
| from scipy.special import logsumexp # log(p1 + p2) = logsumexp([log_p1, log_p2]) | |
# log(0.0): the log-probability of an impossible event.
NINF = float('-inf')
# Default minimum emission probability (linear scale). The beam decoders
# compare emissions against its natural log unless the caller passes an
# explicit 'emission_threshold'.
DEFAULT_EMISSION_THRESHOLD = 0.01
| def _reconstruct(labels, blank=0): | |
| new_labels = [] | |
| # merge same labels | |
| previous = None | |
| for l in labels: | |
| if l != previous: | |
| new_labels.append(l) | |
| previous = l | |
| # delete blank | |
| new_labels = [l for l in new_labels if l != blank] | |
| return new_labels | |
def greedy_decode(emission_log_prob, blank=0, **kwargs):
    """Decode by following the single best class at every step.

    Args:
        emission_log_prob: Array of scores whose last axis indexes the
            classes (ctc_decode passes one (length, class) array per sample).
        blank: Label value treated as the CTC blank.
        **kwargs: Accepted so all decoders share one call signature; unused.

    Returns:
        list: Best-path labels with repeated labels merged and blanks removed.
    """
    best_path = np.argmax(emission_log_prob, axis=-1)
    return _reconstruct(best_path, blank=blank)
def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    """Decode with a beam search over raw (un-collapsed) label paths.

    At every step each kept path is extended with every class whose
    log-probability is not below the emission threshold, and only the
    ``beam_size`` highest-scoring paths survive. At the end, paths that
    collapse to the same labeling have their probabilities summed and the
    most probable labeling is returned.

    Args:
        emission_log_prob: 2-D array of log-probabilities, (length, class).
        blank: Label value treated as the CTC blank.
        **kwargs:
            beam_size (required): number of paths kept after each step.
            emission_threshold (optional): log-probability below which a
                class is not expanded; defaults to
                ``np.log(DEFAULT_EMISSION_THRESHOLD)``.

    Returns:
        list: Labels of the most probable collapsed labeling. Empty when
        the input has no steps or no classes.

    Raises:
        KeyError: if ``beam_size`` is not supplied.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape
    if class_count == 0:
        # Nothing can be emitted at all.
        return []

    beams = [([], 0)]  # (prefix, accumulated_log_prob)

    for t in range(length):
        step_log_probs = emission_log_prob[t]

        # The threshold test does not depend on the prefix, so do it once
        # per step instead of once per (prefix, class) pair.
        candidates = [c for c in range(class_count)
                      if not step_log_probs[c] < emission_threshold]
        if not candidates:
            # Every class was pruned (e.g. a flat distribution over many
            # classes). Keep the best one so the beam never becomes empty,
            # which previously ended in an IndexError.
            candidates = [int(np.argmax(step_log_probs))]

        # log(p1 * p2) = log_p1 + log_p2
        new_beams = [
            (prefix + [c], accumulated_log_prob + step_log_probs[c])
            for prefix, accumulated_log_prob in beams
            for c in candidates
        ]

        # sorted by accumulated_log_prob
        new_beams.sort(key=lambda x: x[1], reverse=True)
        beams = new_beams[:beam_size]

    # sum up beams to produce labels
    total_accu_log_prob = {}
    for prefix, accu_log_prob in beams:
        labels = tuple(_reconstruct(prefix, blank))
        # log(p1 + p2) = logsumexp([log_p1, log_p2])
        total_accu_log_prob[labels] = \
            logsumexp([accu_log_prob, total_accu_log_prob.get(labels, NINF)])

    labels_beams = [(list(labels), accu_log_prob)
                    for labels, accu_log_prob in total_accu_log_prob.items()]
    labels_beams.sort(key=lambda x: x[1], reverse=True)
    labels = labels_beams[0][0]

    return labels
def prefix_beam_decode(emission_log_prob, blank=0, **kwargs):
    """Decode with CTC prefix beam search.

    Beams are collapsed label prefixes. For each prefix two
    log-probabilities are tracked: that of all paths ending in blank and
    that of all paths ending in a non-blank. After every step only the
    ``beam_size`` prefixes with the highest total probability are kept.

    Args:
        emission_log_prob: 2-D array of log-probabilities, (length, class).
        blank: Label value treated as the CTC blank.
        **kwargs:
            beam_size (required): number of prefixes kept after each step.
            emission_threshold (optional): log-probability below which a
                class is not expanded; defaults to
                ``np.log(DEFAULT_EMISSION_THRESHOLD)``.

    Returns:
        list: Labels of the most probable prefix. Empty when the input has
        no steps or no classes.

    Raises:
        KeyError: if ``beam_size`` is not supplied.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape
    if class_count == 0:
        # Nothing can be emitted at all.
        return []

    beams = [(tuple(), (0, NINF))]  # (prefix, (blank_log_prob, non_blank_log_prob))
    # initial of beams: (empty_str, (log(1.0), log(0.0)))

    for t in range(length):
        new_beams_dict = defaultdict(lambda: (NINF, NINF))  # log(0.0) = NINF

        step_log_probs = emission_log_prob[t]
        # The threshold test does not depend on the prefix, so do it once
        # per step instead of once per (prefix, class) pair.
        candidates = [c for c in range(class_count)
                      if not step_log_probs[c] < emission_threshold]
        if not candidates:
            # Every class was pruned. Keep the best one so the beam never
            # becomes empty, which previously ended in an IndexError.
            candidates = [int(np.argmax(step_log_probs))]

        for prefix, (lp_b, lp_nb) in beams:
            end_t = prefix[-1] if prefix else None

            for c in candidates:
                log_prob = step_log_probs[c]

                # if new_prefix == prefix
                new_lp_b, new_lp_nb = new_beams_dict[prefix]

                if c == blank:
                    # A blank keeps the prefix and moves all mass to the
                    # "ends in blank" slot.
                    new_beams_dict[prefix] = (
                        logsumexp([new_lp_b, lp_b + log_prob, lp_nb + log_prob]),
                        new_lp_nb
                    )
                    continue
                if c == end_t:
                    # Repeating the last label directly after a non-blank
                    # merges into the same prefix.
                    new_beams_dict[prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_nb + log_prob])
                    )

                # if new_prefix == prefix + (c,)
                new_prefix = prefix + (c,)
                new_lp_b, new_lp_nb = new_beams_dict[new_prefix]

                if c != end_t:
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob, lp_nb + log_prob])
                    )
                else:
                    # A repeated label extends the prefix only when a blank
                    # separated the two occurrences.
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob])
                    )

        # sorted by log(blank_prob + non_blank_prob)
        beams = sorted(new_beams_dict.items(), key=lambda x: logsumexp(x[1]), reverse=True)
        beams = beams[:beam_size]

    labels = list(beams[0][0])
    return labels
def ctc_decode(log_probs, label2char=None, blank=0, method='beam_search', beam_size=10):
    """Decode a batch of CTC emissions into label (or character) sequences.

    Args:
        log_probs: Tensor of log-probabilities laid out as
            (length, batch, class); it is converted to NumPy and transposed
            to (batch, length, class) before decoding.
        label2char: Optional mapping from label to character. When given
            (and non-empty), every decoded label is looked up in it.
        blank: Label value treated as the CTC blank.
        method: One of 'greedy', 'beam_search' or 'prefix_beam_search'.
        beam_size: Beam width forwarded to the selected decoder.

    Returns:
        list: One decoded sequence (a list) per batch element.

    Raises:
        KeyError: if ``method`` is not a known decoder name.
    """
    # detach() first so tensors that require grad can be converted, then
    # cpu() so CUDA tensors can be too. The previous try/except fallback
    # (`detach().numpy()`) still failed for CUDA tensors requiring grad.
    emission_log_probs = np.transpose(log_probs.detach().cpu().numpy(), (1, 0, 2))
    # size of emission_log_probs: (batch, length, class)

    decoders = {
        'greedy': greedy_decode,
        'beam_search': beam_search_decode,
        'prefix_beam_search': prefix_beam_decode,
    }
    decoder = decoders[method]

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        decoded = decoder(emission_log_prob, blank=blank, beam_size=beam_size)
        if label2char:
            decoded = [label2char[l] for l in decoded]
        decoded_list.append(decoded)
    return decoded_list