import numpy as np
import torch
| |
|
| |
|
class SegmentStreamingE2E(object):
    """Segment-based streaming decoder for a uni-directional E2E ASR model.

    Buffers incoming feature frames, watches the CTC greedy (argmax) path to
    detect speech onset and trailing blanks, and runs beam-search decoding on
    each detected speech segment.

    :param E2E e2e: E2E ASR object
    :param recog_args: arguments for "recognize" method of E2E
    :param rnnlm: optional language model used during beam search
    """

    def __init__(self, e2e, recog_args, rnnlm=None):
        self._e2e = e2e
        self._recog_args = recog_args
        self._char_list = e2e.char_list
        self._rnnlm = rnnlm

        self._e2e.eval()

        # Locate the CTC blank symbol's index in the output vocabulary.
        self._blank_idx_in_char_list = -1
        for idx in range(len(self._char_list)):
            if self._char_list[idx] == self._e2e.blank:
                self._blank_idx_in_char_list = idx
                break

        self._subsampling_factor = np.prod(e2e.subsample)
        self._activates = 0  # 1 while a speech segment is being collected
        self._blank_dur = 0  # count of consecutive blank-dominated chunks

        self._previous_input = []
        self._previous_encoder_recurrent_state = None
        self._encoder_states = []
        self._ctc_posteriors = []

        assert (
            self._recog_args.batchsize <= 1
        ), "SegmentStreamingE2E works only with batch size <= 1"
        assert (
            "b" not in self._e2e.etype
        ), "SegmentStreamingE2E works only with uni-directional encoders"

    def accept_input(self, x):
        """Call this method each time a new batch of input is available."""

        self._previous_input.extend(x)
        h, ilen = self._e2e.subsample_frames(x)

        # Run the streaming encoder on the new chunk, carrying the recurrent
        # state across calls, and take the CTC greedy path to detect speech.
        h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
            h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state
        )
        z = self._e2e.ctc.argmax(h).squeeze(0)

        if self._activates == 0 and z[0] != self._blank_idx_in_char_list:
            self._activates = 1

            # Speech onset detected: re-encode a short tail of the buffered
            # raw input from a zero encoder state so the segment includes the
            # onset margin that preceded detection. Runs only once per
            # segment, at the moment of activation.
            tail_len = self._subsampling_factor * (
                self._recog_args.streaming_onset_margin + 1
            )
            h, ilen = self._e2e.subsample_frames(
                np.reshape(
                    self._previous_input[-tail_len:],
                    [-1, len(self._previous_input[0])],
                )
            )
            h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
                h.unsqueeze(0), ilen, None
            )

        hyp = None
        if self._activates == 1:
            self._encoder_states.extend(h.squeeze(0))
            self._ctc_posteriors.extend(self._e2e.ctc.log_softmax(h).squeeze(0))

            # Track how many consecutive chunks started with a blank frame.
            if z[0] == self._blank_idx_in_char_list:
                self._blank_dur += 1
            else:
                self._blank_dur = 0

            if self._blank_dur >= self._recog_args.streaming_min_blank_dur:
                # Enough trailing blanks: the segment has ended. Trim the
                # blank tail (plus the configured offset margin) and decode.
                seg_len = (
                    len(self._encoder_states)
                    - self._blank_dur
                    + self._recog_args.streaming_offset_margin
                )
                if seg_len > 0:
                    h = torch.cat(self._encoder_states[:seg_len], dim=0).view(
                        -1, self._encoder_states[0].size(0)
                    )
                    if self._recog_args.ctc_weight > 0.0:
                        lpz = torch.cat(self._ctc_posteriors[:seg_len], dim=0).view(
                            -1, self._ctc_posteriors[0].size(0)
                        )
                        if self._recog_args.batchsize > 0:
                            lpz = lpz.unsqueeze(0)
                        normalize_score = False
                    else:
                        lpz = None
                        normalize_score = True

                    if self._recog_args.batchsize == 0:
                        hyp = self._e2e.dec.recognize_beam(
                            h, lpz, self._recog_args, self._char_list, self._rnnlm
                        )
                    else:
                        hlens = torch.tensor([h.shape[0]])
                        hyp = self._e2e.dec.recognize_beam_batch(
                            h.unsqueeze(0),
                            hlens,
                            lpz,
                            self._recog_args,
                            self._char_list,
                            self._rnnlm,
                            normalize_score=normalize_score,
                        )[0]

                # Reset the segment state machine, keeping an onset margin of
                # raw input frames buffered for the next segment.
                self._activates = 0
                self._blank_dur = 0

                tail_len = (
                    self._subsampling_factor
                    * self._recog_args.streaming_onset_margin
                )
                self._previous_input = self._previous_input[-tail_len:]
                self._encoder_states = []
                self._ctc_posteriors = []

        return hyp
| |
|