| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import math |
| | from typing import List, Optional, Tuple |
| |
|
| | import k2 |
| | import torch |
| | from beam_search import Hypothesis, HypothesisList |
| |
|
| | from icefall.utils import AttributeDict |
| |
|
| |
|
class DecodeStream(object):
    """Per-utterance state for streaming (chunk-by-chunk) decoding.

    A stream holds the padded feature matrix of one cut, a read cursor into it
    that :meth:`get_feature_frames` advances chunk by chunk, the model states
    carried between chunks, and the partial result of the selected search
    method (``greedy_search``, ``modified_beam_search`` or
    ``fast_beam_search``).
    """

    def __init__(
        self,
        params: AttributeDict,
        cut_id: str,
        initial_states: List[torch.Tensor],
        decoding_graph: Optional[k2.Fsa] = None,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """
        Args:
          params:
            Decoding configuration. This class reads ``decoding_method``,
            ``context_size`` and ``blank_id`` from it.
          cut_id:
            Identifier of the utterance; also exposed as :attr:`id`.
          initial_states:
            Initial decode states of the model, e.g. the return value of
            `get_init_state` in conformer.py
          decoding_graph:
            Decoding graph used for decoding, may be a TrivialGraph or a HLG.
            Used only when decoding_method is fast_beam_search.
          device:
            The device to run this stream.

        Raises:
          AssertionError: for fast_beam_search, if ``decoding_graph`` is
            missing or lives on a different device than ``device``.
          ValueError: if ``params.decoding_method`` is not supported.
        """
        if params.decoding_method == "fast_beam_search":
            assert decoding_graph is not None, (
                "fast_beam_search requires a decoding_graph"
            )
            # NOTE(review): torch.device("cuda") != torch.device("cuda:0"),
            # so callers should pass an explicitly indexed device.
            assert device == decoding_graph.device, (
                f"device mismatch: stream is on {device}, "
                f"decoding_graph is on {decoding_graph.device}"
            )

        self.params = params
        self.cut_id = cut_id
        # Log of a tiny probability; used as the fill value of padded frames
        # in set_features().
        self.LOG_EPS = math.log(1e-10)

        # Model states carried from one chunk to the next.
        self.states = initial_states

        # Padded 2-D feature matrix (frames along dim 0); None until
        # set_features() is called.
        self.features: Optional[torch.Tensor] = None

        # Number of rows of self.features, padding included.
        self.num_frames: int = 0
        # Read cursor into self.features, counted in input feature frames.
        # Only get_feature_frames() advances it.
        self.num_processed_frames: int = 0

        # Becomes True once the cursor reaches the end of self.features.
        self._done: bool = False

        # Reference transcript; this class only initializes it.
        self.ground_truth: str = ""

        # Partial/final decoding result. For greedy_search it is seeded below
        # with context_size - 1 placeholder (-1) tokens followed by blank_id.
        self.hyp: List = []

        # Not updated inside this class; presumably the caller accumulates the
        # number of encoder output frames produced so far -- TODO confirm.
        self.done_frames: int = 0

        # Extra frames appended after the features and added to every chunk
        # request in get_feature_frames(). NOTE(review): presumably the
        # look-ahead needed by the encoder's subsampling front-end -- confirm
        # against the model before changing.
        self.pad_length = 7 + 2 * 3

        if params.decoding_method == "greedy_search":
            self.hyp = [-1] * (params.context_size - 1) + [params.blank_id]
        elif params.decoding_method == "modified_beam_search":
            self.hyps = HypothesisList()
            self.hyps.add(
                Hypothesis(
                    ys=[-1] * (params.context_size - 1) + [params.blank_id],
                    log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                )
            )
        elif params.decoding_method == "fast_beam_search":
            # Per-stream search state over the shared decoding graph.
            self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
                decoding_graph
            )
        else:
            raise ValueError(f"Unsupported decoding method: {params.decoding_method}")

    @property
    def done(self) -> bool:
        """Return True if all the features are processed."""
        return self._done

    @property
    def id(self) -> str:
        """Return the cut id this stream was created for."""
        return self.cut_id

    def set_features(
        self,
        features: torch.Tensor,
        tail_pad_len: int = 0,
    ) -> None:
        """Set features tensor of current utterance.

        The features are padded at the end (along dim 0) with
        ``self.pad_length + tail_pad_len`` rows filled with ``self.LOG_EPS``,
        and :attr:`num_frames` is set to the padded length.

        Args:
          features:
            A 2-D tensor; dim 0 is treated as the frame axis.
          tail_pad_len:
            Additional rows of padding appended after ``pad_length``.
        """
        assert features.dim() == 2, features.dim()
        self.features = torch.nn.functional.pad(
            features,
            # (last dim: left, right, dim 0: top, bottom) -> pad only the end
            # of the frame axis.
            (0, 0, 0, self.pad_length + tail_pad_len),
            mode="constant",
            value=self.LOG_EPS,
        )
        self.num_frames = self.features.size(0)

    def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
        """Consume ``chunk_size`` frames of features.

        Returns up to ``chunk_size + self.pad_length`` rows starting at the
        read cursor, but advances the cursor by only ``chunk_size``, so
        consecutive chunks overlap by ``pad_length`` rows. The stream is
        marked done once the cursor reaches the end of the features.

        Args:
          chunk_size:
            Number of feature frames to advance the cursor by.

        Returns:
          A tuple ``(features, length)``: the selected rows and their count.
          ``length`` is never negative, even if this is called again after
          the stream is already done (the features are then empty).
        """
        chunk_length = chunk_size + self.pad_length

        # Clamp at 0: once the cursor has passed the end, the difference is
        # negative and must not leak out as a feature length.
        remaining = max(self.num_frames - self.num_processed_frames, 0)
        ret_length = min(remaining, chunk_length)

        start = self.num_processed_frames
        ret_features = self.features[start : start + ret_length]

        self.num_processed_frames += chunk_size
        if self.num_processed_frames >= self.num_frames:
            self._done = True

        return ret_features, ret_length

    def decoding_result(self) -> List[int]:
        """Obtain current decoding result.

        For greedy_search and modified_beam_search the leading
        ``context_size`` seed tokens are stripped; for modified_beam_search
        the most probable hypothesis (length-normalized) is used. For
        fast_beam_search ``self.hyp`` is returned as-is.
        """
        if self.params.decoding_method == "greedy_search":
            return self.hyp[self.params.context_size :]
        elif self.params.decoding_method == "modified_beam_search":
            best_hyp = self.hyps.get_most_probable(length_norm=True)
            return best_hyp.ys[self.params.context_size :]
        else:
            assert self.params.decoding_method == "fast_beam_search"
            return self.hyp
| |
|