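"""Streaming decoding methods for transducer models.

This file provides three batch-mode decoding methods that all update the
given DecodeStream objects in place:

  - greedy_search(): greedy search with --max-sym-per-frame=1 hardcoded.
  - modified_beam_search(): beam search keeping up to `num_active_paths`
    hypotheses per stream.
  - fast_beam_search_one_best(): FSA-based beam search; the shortest path
    of the resulting lattice gives the recognition result.
"""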
import warnings
from typing import List

import k2
import torch
import torch.nn as nn
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from decode_stream import DecodeStream

from icefall.decode import one_best_decoding
from icefall.utils import get_texts


def greedy_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[DecodeStream],
    blank_penalty: float = 0.0,
) -> None:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C), where N >= 1.
      streams:
        A list of DecodeStream objects.
      blank_penalty:
        The penalty applied to the blank symbol's logit during decoding.
    """
    assert len(streams) == encoder_out.size(0)
    assert encoder_out.ndim == 3

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = model.device
    T = encoder_out.size(1)

    decoder_input = torch.tensor(
        [stream.hyp[-context_size:] for stream in streams],
        device=device,
        dtype=torch.int64,
    )
    # decoder_out is of shape (N, 1, decoder_out_dim)
    decoder_out = model.decoder(decoder_input, need_pad=False)
    decoder_out = model.joiner.decoder_proj(decoder_out)

    for t in range(T):
        # current_encoder_out's shape: (N, 1, encoder_out_dim)
        current_encoder_out = encoder_out[:, t : t + 1, :]

        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
            project_input=False,
        )
        # logits' shape: (N, vocab_size)
        logits = logits.squeeze(1).squeeze(1)

        if blank_penalty != 0.0:
            logits[:, 0] -= blank_penalty

        assert logits.ndim == 2, logits.shape
        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                streams[i].hyp.append(v)
                emitted = True
        if emitted:
            # At least one stream emitted a token, so recompute the decoder
            # output for all streams with the updated contexts.
            decoder_input = torch.tensor(
                [stream.hyp[-context_size:] for stream in streams],
                device=device,
                dtype=torch.int64,
            )
            decoder_out = model.decoder(
                decoder_input,
                need_pad=False,
            )
            decoder_out = model.joiner.decoder_proj(decoder_out)
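
# A minimal usage sketch for greedy_search() (hedged: `model`, `features`,
# and the way `streams` are constructed are assumptions, not defined in this
# file). The function returns None and appends decoded tokens to each
# stream's `hyp` in place:
#
#     encoder_out = model.encoder(features)  # (N, T, C), hypothetical call
#     greedy_search(model, encoder_out, streams)
#     for s in streams:
#         # s.hyp starts with the initial decoder context; decoded tokens
#         # follow it.
#         print(s.hyp[model.decoder.context_size:])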


def modified_beam_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[DecodeStream],
    num_active_paths: int = 4,
    blank_penalty: float = 0.0,
) -> None:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

    Args:
      model:
        The RNN-T model.
      encoder_out:
        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
        the encoder model.
      streams:
        A list of stream objects.
      num_active_paths:
        Number of active paths during the beam search.
      blank_penalty:
        The penalty applied to the blank symbol's logit during decoding.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert len(streams) == encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = next(model.parameters()).device
    batch_size = len(streams)
    T = encoder_out.size(1)

    B = [stream.hyps for stream in streams]

    for t in range(T):
        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
        current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)

        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        # ys_log_probs holds the accumulated log-prob of each active
        # hypothesis; its shape is (num_hyps, 1).
        ys_log_probs = torch.stack(
            [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
        )

        # decoder_input's shape: (num_hyps, context_size)
        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )

        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)

        # Duplicate each stream's encoder output so that every hypothesis of
        # that stream gets a copy. current_encoder_out becomes of shape
        # (num_hyps, 1, 1, encoder_out_dim).
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )

        # logits is of shape (num_hyps, 1, 1, vocab_size)
        logits = model.joiner(current_encoder_out, decoder_out, project_input=False)

        logits = logits.squeeze(1).squeeze(1)

        if blank_penalty != 0.0:
            logits[:, 0] -= blank_penalty

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)

        # Add each hypothesis' accumulated log-prob so that we rank whole
        # paths rather than single tokens.
        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        # Regroup the flattened (num_hyps * vocab_size) scores by stream so
        # that topk below is taken per stream.
        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
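
        # Worked example of the regrouping above (numbers are illustrative,
        # not from this file): with 2 streams holding 3 and 2 hypotheses and
        # vocab_size = 500, hyps_shape.row_splits(1) is [0, 3, 5], so
        #
        #     row_splits = [0, 3, 5] * 500 = [0, 1500, 2500]
        #
        # and ragged_log_probs has two rows with 1500 and 2500 elements, one
        # score per (hypothesis, token) pair of the corresponding stream.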

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # topk_indexes indexes the flattened (hypothesis, token) grid;
                # recover the two indices separately.
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                if new_token != blank_id:
                    new_ys.append(new_token)

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                B[i].add(new_hyp)

    for i in range(batch_size):
        streams[i].hyps = B[i]
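
# A minimal usage sketch for modified_beam_search() (hedged: obtaining
# `model`, `encoder_out`, and `streams` is outside this file). Each stream's
# `hyps` is a HypothesisList; get_most_probable() from beam_search.py can
# pick the winning hypothesis:
#
#     modified_beam_search(model, encoder_out, streams, num_active_paths=4)
#     for s in streams:
#         best = s.hyps.get_most_probable(length_norm=True)
#         print(best.ys[model.decoder.context_size:])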


def fast_beam_search_one_best(
    model: nn.Module,
    encoder_out: torch.Tensor,
    processed_lens: torch.Tensor,
    streams: List[DecodeStream],
    beam: float,
    max_states: int,
    max_contexts: int,
    blank_penalty: float = 0.0,
) -> None:
    """It limits the maximum number of symbols per frame to 1.

    A lattice is first generated by Fsa-based beam search, then we get the
    recognition by applying shortest path on the lattice.

    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      processed_lens:
        A tensor of shape (N,) containing the number of processed frames
        in `encoder_out` before padding.
      streams:
        A list of stream objects.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
      blank_penalty:
        The penalty applied to the blank symbol's logit during decoding.
    """
    assert encoder_out.ndim == 3
    B, T, C = encoder_out.shape
    assert B == len(streams)

    context_size = model.decoder.context_size
    vocab_size = model.decoder.vocab_size

    config = k2.RnntDecodingConfig(
        vocab_size=vocab_size,
        decoder_history_len=context_size,
        beam=beam,
        max_contexts=max_contexts,
        max_states=max_states,
    )
    individual_streams = []
    for i in range(B):
        individual_streams.append(streams[i].rnnt_decoding_stream)
    decoding_streams = k2.RnntDecodingStreams(individual_streams, config)

    for t in range(T):
        # shape describes the active decoder contexts of each stream;
        # contexts is a tensor of shape (shape.NumElements(), context_size).
        shape, contexts = decoding_streams.get_contexts()
        # The decoder embedding expects int64 indices.
        contexts = contexts.to(torch.int64)

        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
        decoder_out = model.decoder(contexts, need_pad=False)
        decoder_out = model.joiner.decoder_proj(decoder_out)

        # Duplicate each stream's current frame so that every active context
        # of that stream gets a copy of the encoder output.
        current_encoder_out = torch.index_select(
            encoder_out[:, t : t + 1, :], 0, shape.row_ids(1).to(torch.int64)
        )

        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
            project_input=False,
        )
        logits = logits.squeeze(1).squeeze(1)

        if blank_penalty != 0.0:
            logits[:, 0] -= blank_penalty

        log_probs = logits.log_softmax(dim=-1)
        decoding_streams.advance(log_probs)

    decoding_streams.terminate_and_flush_to_streams()

    # Generate the lattice and pick the best path for each stream.
    lattice = decoding_streams.format_output(processed_lens.tolist())
    best_path = one_best_decoding(lattice)
    hyp_tokens = get_texts(best_path)

    for i in range(B):
        streams[i].hyp = hyp_tokens[i]
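
# A minimal usage sketch for fast_beam_search_one_best() (hedged: the values
# below are illustrative, not defaults from this file):
#
#     fast_beam_search_one_best(
#         model,
#         encoder_out,
#         processed_lens,
#         streams,
#         beam=4.0,
#         max_states=32,
#         max_contexts=8,
#     )
#     for s in streams:
#         print(s.hyp)  # token IDs of the best path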