Spaces:
Build error
Build error
File size: 5,123 Bytes
7ee7e3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
from collections import defaultdict
import torch
import numpy as np
from scipy.special import logsumexp # log(p1 + p2) = logsumexp([log_p1, log_p2])
# log(0.0); logsumexp([x, NINF]) == x, so it serves as the "no mass yet" value.
NINF = float('-inf')
# Pruning floor in linear probability space; the decoders compare per-frame
# log-probabilities against np.log(DEFAULT_EMISSION_THRESHOLD).
DEFAULT_EMISSION_THRESHOLD = 1e-2
def _reconstruct(labels, blank=0):
new_labels = []
# merge same labels
previous = None
for l in labels:
if l != previous:
new_labels.append(l)
previous = l
# delete blank
new_labels = [l for l in new_labels if l != blank]
return new_labels
def greedy_decode(emission_log_prob, blank=0, **kwargs):
    """Best-path CTC decoding: take the top class per frame, then collapse.

    Args:
        emission_log_prob: array of per-frame class scores with the class
            axis last.
        blank: index of the CTC blank class, removed by ``_reconstruct``.
        **kwargs: accepted and ignored (``ctc_decode`` passes ``beam_size``
            to every decoder).

    Returns:
        list: the decoded label indices.
    """
    best_path = np.argmax(emission_log_prob, axis=-1)
    return _reconstruct(best_path, blank=blank)
def beam_search_decode(emission_log_prob, blank=0, **kwargs):
    """Decode one sequence with path-level beam search, then merge paths.

    Each step extends every kept path by one class and keeps the
    ``beam_size`` highest-scoring paths. A path is the raw per-frame class
    sequence, blanks and repeats included. After the last frame, paths that
    collapse to the same label sequence under ``_reconstruct`` have their
    probabilities summed, and the most probable label sequence is returned.

    Args:
        emission_log_prob: 2-D array of log-probabilities indexed
            ``[time, class]``.
        blank: index of the CTC blank class.
        **kwargs:
            beam_size (required): number of paths kept after each frame.
            emission_threshold (optional): log-probability below which a
                class is not used to extend paths at that frame. Defaults
                to ``np.log(DEFAULT_EMISSION_THRESHOLD)``. If no class at a
                frame reaches it, every class is used for that frame rather
                than emptying the beam.

    Returns:
        list: the decoded label indices (repeats merged, blanks removed).

    Raises:
        KeyError: if ``beam_size`` is missing from ``kwargs``.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))
    length, class_count = emission_log_prob.shape

    beams = [([], 0)]  # (path, accumulated_log_prob)
    for t in range(length):
        frame = emission_log_prob[t]
        # The surviving classes depend only on the frame, so filter once per frame.
        candidates = [c for c in range(class_count) if frame[c] >= emission_threshold]
        if not candidates:
            # With many classes a flat frame can leave every class below the
            # threshold; pruning all of them would empty the beam for good.
            candidates = range(class_count)
        # log(p1 * p2) = log_p1 + log_p2
        new_beams = [
            (path + [c], accu_log_prob + frame[c])
            for path, accu_log_prob in beams
            for c in candidates
        ]
        new_beams.sort(key=lambda beam: beam[1], reverse=True)
        beams = new_beams[:beam_size]

    # Distinct paths can collapse to the same labelling; sum their probabilities.
    total_accu_log_prob = {}
    for path, accu_log_prob in beams:
        labels = tuple(_reconstruct(path, blank))
        # log(p1 + p2) = logsumexp([log_p1, log_p2])
        total_accu_log_prob[labels] = logsumexp(
            [accu_log_prob, total_accu_log_prob.get(labels, NINF)])
    # max() returns the first of equally scored labellings, like a stable sort.
    best_labels = max(total_accu_log_prob, key=total_accu_log_prob.get)
    return list(best_labels)
def prefix_beam_decode(emission_log_prob, blank=0, **kwargs):
    """Decode one sequence with CTC prefix beam search.

    Beams are collapsed label prefixes, not raw paths. Each prefix carries
    two log-probabilities:

    * ``lp_b``: the summed probability of the paths producing it whose last
      frame is blank.
    * ``lp_nb``: the same, for paths whose last frame is a non-blank label.

    Keeping the two apart lets a repeated label either merge (``a a`` ->
    ``a``) or stay distinct (``a _ a`` -> ``a a``).

    Args:
        emission_log_prob: 2-D array of log-probabilities indexed
            ``[time, class]``.
        blank: index of the CTC blank class.
        **kwargs:
            beam_size (required): number of prefixes kept after each frame.
            emission_threshold (optional): log-probability below which a
                class is not used at that frame. Defaults to
                ``np.log(DEFAULT_EMISSION_THRESHOLD)``. If no class at a
                frame reaches it, every class is used for that frame rather
                than emptying the beam.

    Returns:
        list: the labels of the highest-scoring prefix.

    Raises:
        KeyError: if ``beam_size`` is missing from ``kwargs``.
    """
    beam_size = kwargs['beam_size']
    emission_threshold = kwargs.get('emission_threshold', np.log(DEFAULT_EMISSION_THRESHOLD))
    length, class_count = emission_log_prob.shape

    # (prefix, (blank_log_prob, non_blank_log_prob)); start from the empty
    # prefix with log(1.0) = 0 for blank and log(0.0) = NINF for non-blank.
    beams = [(tuple(), (0, NINF))]
    for t in range(length):
        frame = emission_log_prob[t]
        candidates = [c for c in range(class_count) if frame[c] >= emission_threshold]
        if not candidates:
            # Pruning every class would leave no beam at all; use the whole frame.
            candidates = range(class_count)

        new_beams_dict = defaultdict(lambda: (NINF, NINF))  # log(0.0) = NINF
        for prefix, (lp_b, lp_nb) in beams:
            end_t = prefix[-1] if prefix else None
            for c in candidates:
                log_prob = frame[c]
                if c == blank:
                    # Blank leaves the prefix unchanged; the path now ends in blank.
                    new_lp_b, new_lp_nb = new_beams_dict[prefix]
                    new_beams_dict[prefix] = (
                        logsumexp([new_lp_b, lp_b + log_prob, lp_nb + log_prob]),
                        new_lp_nb,
                    )
                    continue

                if c == end_t:
                    # Repeating the last label directly after itself collapses
                    # into the same prefix...
                    new_lp_b, new_lp_nb = new_beams_dict[prefix]
                    new_beams_dict[prefix] = (
                        new_lp_b,
                        logsumexp([new_lp_nb, lp_nb + log_prob]),
                    )
                    # ...so only paths that ended in blank extend it with c.
                    extend_log_probs = [lp_b + log_prob]
                else:
                    extend_log_probs = [lp_b + log_prob, lp_nb + log_prob]

                new_prefix = prefix + (c,)
                new_lp_b, new_lp_nb = new_beams_dict[new_prefix]
                new_beams_dict[new_prefix] = (
                    new_lp_b,
                    logsumexp([new_lp_nb] + extend_log_probs),
                )

        # Rank prefixes by log(blank_prob + non_blank_prob).
        beams = sorted(new_beams_dict.items(), key=lambda item: logsumexp(item[1]),
                       reverse=True)[:beam_size]

    return list(beams[0][0])
def ctc_decode(log_probs, label2char=None, blank=0, method='beam_search', beam_size=10,
               emission_threshold=None):
    """Decode a batch of CTC network outputs into label sequences.

    Args:
        log_probs: per-frame class log-probabilities laid out as
            ``(length, batch, class)``; the first two axes are swapped below.
            May be a ``torch.Tensor`` on any device, with or without
            ``requires_grad``, or anything ``np.asarray`` accepts.
        label2char: optional mapping from label index to output symbol. When
            truthy, each decoded index ``l`` is replaced by ``label2char[l]``.
        blank: index of the CTC blank class.
        method: ``'greedy'``, ``'beam_search'`` or ``'prefix_beam_search'``.
        beam_size: beam width for the beam-search decoders (ignored by
            ``'greedy'``).
        emission_threshold: optional log-space pruning threshold forwarded
            to the decoders. When ``None``, each decoder uses its own default.

    Returns:
        list: one decoded sequence per batch element (label indices, or
        ``label2char`` values when a mapping is given).

    Raises:
        KeyError: if ``method`` is not a supported name, or a decoded label
            is missing from ``label2char``.
    """
    if torch.is_tensor(log_probs):
        # numpy() refuses tensors that require grad or live on a GPU;
        # detach().cpu() handles both cases.
        log_probs = log_probs.detach().cpu().numpy()
    # (length, batch, class) -> (batch, length, class)
    emission_log_probs = np.transpose(np.asarray(log_probs), (1, 0, 2))

    decoders = {
        'greedy': greedy_decode,
        'beam_search': beam_search_decode,
        'prefix_beam_search': prefix_beam_decode,
    }
    decoder = decoders[method]

    decoder_kwargs = {'beam_size': beam_size}
    if emission_threshold is not None:
        decoder_kwargs['emission_threshold'] = emission_threshold

    decoded_list = []
    for emission_log_prob in emission_log_probs:
        decoded = decoder(emission_log_prob, blank=blank, **decoder_kwargs)
        if label2char:
            decoded = [label2char[l] for l in decoded]
        decoded_list.append(decoded)
    return decoded_list
|