| | import torch |
| |
|
| |
|
| | |
| | |
class WindowStreamingE2E(object):
    """Window-by-window streaming wrapper around an E2E ASR model.

    Input is pushed in through :meth:`accept_input`, which runs the encoder
    and the CTC log-softmax on each new window and stores both results.
    The attention decoder is not streamed: it is run over everything that
    has been accumulated, see :meth:`decode_with_attention_offline`.

    :param E2E e2e: E2E ASR object
    :param recog_args: arguments for "recognize" method of E2E
    :param rnnlm: optional language model, forwarded unchanged to the
        decoder's ``recognize_beam``
    """

    def __init__(self, e2e, recog_args, rnnlm=None):
        self._e2e = e2e
        self._recog_args = recog_args
        self._char_list = e2e.char_list
        self._rnnlm = rnnlm

        # Put the model into evaluation mode before any input is processed.
        self._e2e.eval()

        # Streaming state; the two lists grow by one entry per accept_input().
        # _offset is the position (counted along dim 0 of the stored windows)
        # up to which windows are treated as already processed.
        self._offset = 0
        self._previous_encoder_recurrent_state = None
        self._encoder_states = []
        self._ctc_posteriors = []
        self._last_recognition = None

        assert (
            self._recog_args.ctc_weight > 0.0
        ), "WindowStreamingE2E works only with combined CTC and attention decoders."

    def accept_input(self, x):
        """Call this method each time a new batch of input is available.

        Runs the encoder on the new window, feeding in the recurrent state
        returned by the previous call, and stores the encoder output together
        with its CTC log-posteriors.

        :param x: new window of input; passed to ``e2e.subsample_frames``
        """
        h, ilen = self._e2e.subsample_frames(x)

        # Streaming encoder: a leading dim is added for the call and removed
        # again before storing; the returned recurrent state is kept so the
        # next window continues where this one stopped.
        h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
            h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state
        )
        self._encoder_states.append(h.squeeze(0))

        # CTC posteriors for the incoming window.
        self._ctc_posteriors.append(self._e2e.ctc.log_softmax(h).squeeze(0))

    def _select_unprocessed_windows(self, window_tensors):
        """Concatenate the windows that start at or after ``self._offset``.

        :param list window_tensors: per-window tensors in arrival order
        :return: the selected windows concatenated along dim 0
        :raises: whatever ``torch.cat`` raises for an empty list, when no
            window starts at or after the offset
        """
        first_unprocessed = len(window_tensors)
        position = 0
        for index, window in enumerate(window_tensors):
            # ">=" so that a window starting exactly at the offset counts as
            # unprocessed; in particular nothing is dropped while the offset
            # is still at its initial value 0.
            if position >= self._offset:
                first_unprocessed = index
                break
            # Advance along dim 0, the axis the windows are concatenated on.
            # Dim 1 is presumably the feature / vocabulary axis (TODO
            # confirm); it need not match between encoder states and CTC
            # posteriors, so counting it would desynchronise the two streams.
            position += window.size(0)
        return torch.cat(window_tensors[first_unprocessed:], dim=0)

    def _input_window_for_decoder(self, use_all=False):
        """Return the ``(encoder_states, ctc_posteriors)`` pair for the decoder.

        :param bool use_all: if True, concatenate every stored window;
            otherwise only the windows starting at or after ``self._offset``
        :return: tuple of two tensors, each a concatenation along dim 0
        """
        if use_all:
            return (
                torch.cat(self._encoder_states, dim=0),
                torch.cat(self._ctc_posteriors, dim=0),
            )

        return (
            self._select_unprocessed_windows(self._encoder_states),
            self._select_unprocessed_windows(self._ctc_posteriors),
        )

    def decode_with_attention_offline(self):
        """Run the attention decoder offline.

        Works even if the previous layers (encoder and CTC decoder) were
        being run in the online mode.
        This method should be run after all the audio has been consumed.
        This is used mostly to compare the results between offline
        and online implementation of the previous layers.

        :return: the result of the decoder's ``recognize_beam``
        """
        h, lpz = self._input_window_for_decoder(use_all=True)

        return self._e2e.dec.recognize_beam(
            h, lpz, self._recog_args, self._char_list, self._rnnlm
        )
| |
|