from collections import defaultdict
from itertools import groupby

import torch
import numpy as np
from scipy.special import logsumexp  # log(p1 + p2) = logsumexp([log_p1, log_p2])

NINF = -1 * float('inf')  # log(0.0)
DEFAULT_EMISSION_THRESHOLD = 0.01  # probability; compared in log space below


def _reconstruct(labels, blank=0):
    """Collapse a frame-level label path into a label sequence.

    Consecutive duplicates are merged first, then every ``blank`` is removed,
    so ``[1, 1, 0, 1, 2, 2]`` becomes ``[1, 1, 2]``.
    """
    return [label for label, _ in groupby(labels) if label != blank]


def _candidate_classes(frame_log_prob, emission_threshold):
    """Return the class indices of one frame that survive threshold pruning.

    Classes whose log-probability is below ``emission_threshold`` are not
    expanded. If that would discard every class (e.g. a flat distribution
    over many classes), fall back to the single most likely class so the
    beam can never become empty.
    """
    frame = np.asarray(frame_log_prob)
    # ``~(x < t)`` rather than ``x >= t`` keeps exactly the classes the
    # original per-class ``if log_prob < threshold: continue`` kept.
    kept = np.flatnonzero(~(frame < emission_threshold)).tolist()
    if not kept:
        kept = [int(np.argmax(frame))]
    return kept


def greedy_decode(emission_log_prob, blank=0, **kwargs):
    """Best-path decoding: take the arg-max class of every frame, then collapse.

    ``kwargs`` is accepted and ignored so that all decoders share one
    call signature.
    """
    labels = np.argmax(emission_log_prob, axis=-1)
    return _reconstruct(labels, blank=blank)


def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    """Decode with a beam search over un-collapsed frame-level paths.

    The ``beam_size`` most likely paths are kept at each frame. At the end,
    surviving paths that collapse to the same label sequence are merged by
    summing their probabilities, and the most likely sequence is returned.

    Args:
        emission_log_prob: 2-D array ``(length, class_count)`` of per-frame
            class log-probabilities.
        blank: index of the CTC blank class.
        beam_size (kwarg, required): number of paths kept per frame.
        emission_threshold (kwarg, optional): log-probability below which a
            class is not expanded; defaults to
            ``log(DEFAULT_EMISSION_THRESHOLD)``.

    Returns:
        The decoded label sequence as a list.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold',
                                    np.log(DEFAULT_EMISSION_THRESHOLD))

    length, _ = emission_log_prob.shape

    beams = [([], 0)]  # (path, accumulated_log_prob)
    for t in range(length):
        frame = emission_log_prob[t]
        candidates = _candidate_classes(frame, emission_threshold)
        # log(p1 * p2) = log_p1 + log_p2
        new_beams = [(path + [c], accumulated + frame[c])
                     for path, accumulated in beams
                     for c in candidates]
        new_beams.sort(key=lambda beam: beam[1], reverse=True)
        beams = new_beams[:beam_size]

    # Merge paths that collapse to the same label sequence.
    total_log_prob = {}
    for path, accumulated in beams:
        labels = tuple(_reconstruct(path, blank))
        total_log_prob[labels] = logsumexp(
            [accumulated, total_log_prob.get(labels, NINF)])

    # max() returns the first maximum in insertion order, matching the
    # stable descending sort previously used.
    best_labels = max(total_log_prob, key=total_log_prob.get)
    return list(best_labels)


def prefix_beam_decode(emission_log_prob, blank=0, **kwargs):
    """Decode with CTC prefix beam search.

    Beams are collapsed label prefixes rather than raw paths. Each prefix
    tracks two log-probabilities:
      * the mass of its paths ending in a blank;
      * the mass of its paths ending in a non-blank.
    Keeping the two apart lets a repeated label start a new symbol only when
    a blank separates the two occurrences.

    Args:
        emission_log_prob: 2-D array ``(length, class_count)`` of per-frame
            class log-probabilities.
        blank: index of the CTC blank class.
        beam_size (kwarg, required): number of prefixes kept per frame.
        emission_threshold (kwarg, optional): log-probability below which a
            class is not expanded; defaults to
            ``log(DEFAULT_EMISSION_THRESHOLD)``.

    Returns:
        The most likely label prefix as a list.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold',
                                    np.log(DEFAULT_EMISSION_THRESHOLD))

    length, _ = emission_log_prob.shape

    # (prefix, (blank_log_prob, non_blank_log_prob)); the empty prefix starts
    # at (log(1.0), log(0.0)).
    beams = [((), (0, NINF))]

    for t in range(length):
        frame = emission_log_prob[t]
        candidates = _candidate_classes(frame, emission_threshold)
        # Prefixes not seen yet in this frame start at log(0.0) for both parts.
        new_beams_dict = defaultdict(lambda: (NINF, NINF))

        for prefix, (lp_b, lp_nb) in beams:
            last = prefix[-1] if prefix else None
            for c in candidates:
                log_prob = frame[c]

                if c == blank:
                    # A blank keeps the prefix unchanged; it now ends in blank.
                    b, nb = new_beams_dict[prefix]
                    new_beams_dict[prefix] = (
                        logsumexp([b, lp_b + log_prob, lp_nb + log_prob]),
                        nb,
                    )
                    continue

                extended = prefix + (c,)
                if c == last:
                    # Repeating the last label with no blank in between
                    # collapses into the same prefix ...
                    b, nb = new_beams_dict[prefix]
                    new_beams_dict[prefix] = (
                        b, logsumexp([nb, lp_nb + log_prob]))
                    # ... so only blank-ending paths can emit it again.
                    ext_b, ext_nb = new_beams_dict[extended]
                    new_beams_dict[extended] = (
                        ext_b, logsumexp([ext_nb, lp_b + log_prob]))
                else:
                    ext_b, ext_nb = new_beams_dict[extended]
                    new_beams_dict[extended] = (
                        ext_b,
                        logsumexp([ext_nb, lp_b + log_prob, lp_nb + log_prob]),
                    )

        # Rank prefixes by log(blank_prob + non_blank_prob).
        beams = sorted(new_beams_dict.items(),
                       key=lambda item: logsumexp(item[1]),
                       reverse=True)[:beam_size]

    return list(beams[0][0])


def ctc_decode(log_probs, label2char=None, blank=0, method='beam_search',
               beam_size=10, emission_threshold=None):
    """Decode a batch of CTC outputs.

    Args:
        log_probs: torch tensor laid out as ``(length, batch, class)`` with
            per-frame class log-probabilities. It is transposed to
            ``(batch, length, class)`` before decoding. It may be on any
            device and may require grad.
        label2char: optional mapping from label index to output symbol. When
            truthy, every decoded index is looked up in it.
        blank: index of the CTC blank class.
        method: ``'greedy'``, ``'beam_search'`` or ``'prefix_beam_search'``.
            Any other name raises ``KeyError``.
        beam_size: beam width for the beam-search decoders (ignored by greedy).
        emission_threshold: optional log-probability pruning threshold that is
            forwarded to the decoders. ``None`` keeps each decoder's default.

    Returns:
        A list with one decoded sequence (a list) per batch element.
    """
    # numpy() needs a host tensor without an autograd graph. detach() + cpu()
    # covers every combination of device and requires_grad.
    emission_log_probs = np.transpose(log_probs.detach().cpu().numpy(),
                                      (1, 0, 2))
    # size of emission_log_probs: (batch, length, class)

    decoders = {
        'greedy': greedy_decode,
        'beam_search': beam_search_decode,
        'prefix_beam_search': prefix_beam_decode,
    }
    decoder = decoders[method]

    decoder_kwargs = {'beam_size': beam_size}
    if emission_threshold is not None:
        decoder_kwargs['emission_threshold'] = emission_threshold

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        decoded = decoder(emission_log_prob, blank=blank, **decoder_kwargs)
        if label2char:
            decoded = [label2char[label] for label in decoded]
        decoded_list.append(decoded)
    return decoded_list