Alimustoofaa's picture
first commit
7ee7e3a
from collections import defaultdict
import torch
import numpy as np
from scipy.special import logsumexp # log(p1 + p2) = logsumexp([log_p1, log_p2])
# Log-domain representation of probability zero: log(0.0) = -inf.
NINF = -1 * float('inf')
# Default pruning threshold, expressed as a plain probability. The beam
# decoders take np.log() of it and skip any class whose emission log-prob
# at a frame is below the result.
DEFAULT_EMISSION_THRESHOLD = 0.01
def _reconstruct(labels, blank=0):
new_labels = []
# merge same labels
previous = None
for l in labels:
if l != previous:
new_labels.append(l)
previous = l
# delete blank
new_labels = [l for l in new_labels if l != blank]
return new_labels
def greedy_decode(emission_log_prob, blank=0, **kwargs):
    """Best-path CTC decoding: pick the top class per frame, then collapse.

    Args:
        emission_log_prob: array of per-frame class scores; the arg-max is
            taken over the last axis.
        blank: id of the CTC blank label.
        **kwargs: accepted so all decoders share one call signature; unused.

    Returns:
        List of decoded label ids with repeats merged and blanks removed.
    """
    best_path = np.argmax(emission_log_prob, axis=-1)
    return _reconstruct(best_path, blank=blank)
def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    """Decode one sequence with beam search over raw per-frame paths.

    Each beam is a full frame-by-frame path. After the last frame, paths that
    collapse to the same label sequence have their probabilities summed and
    the most probable label sequence is returned.

    Args:
        emission_log_prob: 2-D array of shape (length, class_count) holding
            per-frame log-probabilities.
        blank: id of the CTC blank label.
        **kwargs:
            beam_size (required): number of paths kept after every frame;
                expected to be at least 1.
            emission_threshold (optional): log-probability below which a
                class is not expanded at a frame. Defaults to
                np.log(DEFAULT_EMISSION_THRESHOLD).

    Returns:
        List of decoded label ids.

    Raises:
        KeyError: if ``beam_size`` is not supplied.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    beams = [([], 0)]  # (path, accumulated_log_prob)
    for t in range(length):
        frame = emission_log_prob[t]

        # Pruning depends only on the frame, so do it once instead of once
        # per beam.
        candidates = [c for c in range(class_count)
                      if not frame[c] < emission_threshold]
        if not candidates:
            # Every class fell below the threshold (easy to hit with a large
            # alphabet). Keep the best one so the beam set never empties,
            # which previously ended in an IndexError after the loop.
            candidates = [int(np.argmax(frame))]

        # log(p1 * p2) = log_p1 + log_p2
        new_beams = [
            (path + [c], accumulated_log_prob + frame[c])
            for path, accumulated_log_prob in beams
            for c in candidates
        ]

        # keep the beam_size most probable paths
        new_beams.sort(key=lambda x: x[1], reverse=True)
        beams = new_beams[:beam_size]

    # Sum the probability of all surviving paths that collapse to the same
    # label sequence: log(p1 + p2) = logsumexp([log_p1, log_p2]).
    total_accu_log_prob = {}
    for path, accu_log_prob in beams:
        labels = tuple(_reconstruct(path, blank))
        total_accu_log_prob[labels] = logsumexp(
            [accu_log_prob, total_accu_log_prob.get(labels, NINF)]
        )

    labels_beams = [(list(labels), accu_log_prob)
                    for labels, accu_log_prob in total_accu_log_prob.items()]
    labels_beams.sort(key=lambda x: x[1], reverse=True)
    return labels_beams[0][0]
def prefix_beam_decode(emission_log_prob, blank=0, **kwargs):
    """Decode one sequence with CTC prefix beam search.

    Beams are collapsed label prefixes rather than raw paths. Each prefix
    carries two log-probabilities: that of all paths producing it which end
    in blank, and that of all paths producing it which end in a non-blank.

    Args:
        emission_log_prob: 2-D array of shape (length, class_count) holding
            per-frame log-probabilities.
        blank: id of the CTC blank label.
        **kwargs:
            beam_size (required): number of prefixes kept after every frame;
                expected to be at least 1.
            emission_threshold (optional): log-probability below which a
                class is not expanded at a frame. Defaults to
                np.log(DEFAULT_EMISSION_THRESHOLD).

    Returns:
        List of label ids of the most probable prefix.

    Raises:
        KeyError: if ``beam_size`` is not supplied.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))

    length, class_count = emission_log_prob.shape

    # (prefix, (blank_log_prob, non_blank_log_prob))
    # initial beam: (empty prefix, (log(1.0), log(0.0)))
    beams = [(tuple(), (0, NINF))]

    for t in range(length):
        frame = emission_log_prob[t]

        # Pruning depends only on the frame, so do it once instead of once
        # per beam.
        candidates = [c for c in range(class_count)
                      if not frame[c] < emission_threshold]
        if not candidates:
            # Every class fell below the threshold. Keep the best one so the
            # beam set never empties, which previously ended in an IndexError
            # at the final beams[0] lookup.
            candidates = [int(np.argmax(frame))]

        new_beams_dict = defaultdict(lambda: (NINF, NINF))  # log(0.0) = NINF

        for prefix, (lp_b, lp_nb) in beams:
            # last label of the prefix, or None for the empty prefix
            end_t = prefix[-1] if prefix else None

            for c in candidates:
                log_prob = frame[c]

                # Case: the prefix stays unchanged. The defaultdict lookup
                # also registers the prefix with (NINF, NINF) if it is new.
                new_lp_b, new_lp_nb = new_beams_dict[prefix]

                if c == blank:
                    # A blank extends both kinds of path and the result
                    # ends in blank.
                    new_beams_dict[prefix] = (
                        logsumexp([new_lp_b, lp_b + log_prob, lp_nb + log_prob]),
                        new_lp_nb
                    )
                    continue

                if c == end_t:
                    # Repeating the last label directly after a non-blank
                    # merges into the same prefix.
                    new_beams_dict[prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_nb + log_prob])
                    )

                # Case: the prefix grows by c.
                new_prefix = prefix + (c,)
                new_lp_b, new_lp_nb = new_beams_dict[new_prefix]

                if c != end_t:
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob, lp_nb + log_prob])
                    )
                else:
                    # A repeated label only starts a new occurrence when the
                    # previous path ended in blank.
                    new_beams_dict[new_prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_b + log_prob])
                    )

        # sorted by log(blank_prob + non_blank_prob)
        beams = sorted(new_beams_dict.items(), key=lambda x: logsumexp(x[1]), reverse=True)
        beams = beams[:beam_size]

    return list(beams[0][0])
def ctc_decode(log_probs, label2char=None, blank=0, method='beam_search', beam_size=10):
    """Decode a batch of CTC outputs into label (or character) sequences.

    Args:
        log_probs: 3-D torch tensor whose first two axes are swapped before
            decoding so that the result is (batch, length, class).
        label2char: optional mapping from label id to output symbol; when
            given (and non-empty) every decoded id is translated through it.
        blank: id of the CTC blank label.
        method: one of 'greedy', 'beam_search' or 'prefix_beam_search'.
        beam_size: beam width forwarded to the selected decoder.

    Returns:
        A list with one decoded sequence (a list) per batch element.

    Raises:
        KeyError: if ``method`` is not one of the supported names.
    """
    # detach() drops autograd tracking and cpu() copies off any accelerator,
    # so numpy() works for every tensor. The previous fallback omitted cpu()
    # and failed for tensors that both require grad and live on a GPU.
    emission_log_probs = np.transpose(log_probs.detach().cpu().numpy(), (1, 0, 2))
    # size of emission_log_probs: (batch, length, class)

    decoders = {
        'greedy': greedy_decode,
        'beam_search': beam_search_decode,
        'prefix_beam_search': prefix_beam_decode,
    }
    decoder = decoders[method]

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        decoded = decoder(emission_log_prob, blank=blank, beam_size=beam_size)
        if label2char:
            decoded = [label2char[l] for l in decoded]
        decoded_list.append(decoded)
    return decoded_list