| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Usage: |
| | (1) greedy search |
| | ./zipformer/decode.py \ |
| | --epoch 28 \ |
| | --avg 15 \ |
| | --exp-dir ./zipformer/exp \ |
| | --max-duration 600 \ |
| | --decoding-method greedy_search |
| | |
| | (2) beam search (not recommended) |
| | ./zipformer/decode.py \ |
| | --epoch 28 \ |
| | --avg 15 \ |
| | --exp-dir ./zipformer/exp \ |
| | --max-duration 600 \ |
| | --decoding-method beam_search \ |
| | --beam-size 4 |
| | |
| | (3) modified beam search |
| | ./zipformer/decode.py \ |
| | --epoch 28 \ |
| | --avg 15 \ |
| | --exp-dir ./zipformer/exp \ |
| | --max-duration 600 \ |
| | --decoding-method modified_beam_search \ |
| | --beam-size 4 |
| | |
| | (4) fast beam search (one best) |
| | ./zipformer/decode.py \ |
| | --epoch 28 \ |
| | --avg 15 \ |
| | --exp-dir ./zipformer/exp \ |
| | --max-duration 600 \ |
| | --decoding-method fast_beam_search \ |
| | --beam 20.0 \ |
| | --max-contexts 8 \ |
| | --max-states 64 |
| | |
| | (5) fast beam search (nbest) |
| | ./zipformer/decode.py \ |
| | --epoch 28 \ |
| | --avg 15 \ |
| | --exp-dir ./zipformer/exp \ |
| | --max-duration 600 \ |
| | --decoding-method fast_beam_search_nbest \ |
| | --beam 20.0 \ |
| | --max-contexts 8 \ |
| | --max-states 64 \ |
| | --num-paths 200 \ |
| | --nbest-scale 0.5 |
| | |
| | (6) fast beam search (nbest oracle WER) |
| | ./zipformer/decode.py \ |
| | --epoch 28 \ |
| | --avg 15 \ |
| | --exp-dir ./zipformer/exp \ |
| | --max-duration 600 \ |
| | --decoding-method fast_beam_search_nbest_oracle \ |
| | --beam 20.0 \ |
| | --max-contexts 8 \ |
| | --max-states 64 \ |
| | --num-paths 200 \ |
| | --nbest-scale 0.5 |
| | |
| | (7) fast beam search (with LG) |
| | ./zipformer/decode.py \ |
| | --epoch 28 \ |
| | --avg 15 \ |
| | --exp-dir ./zipformer/exp \ |
| | --max-duration 600 \ |
| | --decoding-method fast_beam_search_nbest_LG \ |
| | --beam 20.0 \ |
| | --max-contexts 8 \ |
| | --max-states 64 |
| | """ |
| |
|
| |
|
| | import argparse |
| | import logging |
| | import math |
| | import os |
| | from collections import defaultdict |
| | from pathlib import Path |
| | from typing import Dict, List, Optional, Tuple |
| |
|
| | import k2 |
| | import sentencepiece as spm |
| | import torch |
| | import torch.nn as nn |
| | from asr_datamodule import LibriSpeechAsrDataModule |
| | from beam_search import ( |
| | beam_search, |
| | fast_beam_search_nbest, |
| | fast_beam_search_nbest_LG, |
| | fast_beam_search_nbest_oracle, |
| | fast_beam_search_one_best, |
| | greedy_search, |
| | greedy_search_batch, |
| | modified_beam_search, |
| | modified_beam_search_lm_rescore, |
| | modified_beam_search_lm_rescore_LODR, |
| | modified_beam_search_lm_shallow_fusion, |
| | modified_beam_search_LODR, |
| | ) |
| | from lhotse import set_caching_enabled |
| | from train import add_model_arguments, get_model, get_params |
| |
|
| | from icefall import ContextGraph, LmScorer, NgramLm |
| | from icefall.checkpoint import ( |
| | average_checkpoints, |
| | average_checkpoints_with_averaged_model, |
| | find_checkpoints, |
| | load_checkpoint, |
| | ) |
| | from icefall.lexicon import Lexicon |
| | from icefall.utils import ( |
| | AttributeDict, |
| | setup_logger, |
| | store_transcripts, |
| | str2bool, |
| | write_error_stats, |
| | ) |
| |
|
| | LOG_EPS = math.log(1e-10) |
| |
|
| |
|
| | def get_parser(): |
| | parser = argparse.ArgumentParser( |
| | formatter_class=argparse.ArgumentDefaultsHelpFormatter |
| | ) |
| |
|
| | parser.add_argument( |
| | "--epoch", |
| | type=int, |
| | default=30, |
| | help="""It specifies the checkpoint to use for decoding. |
| | Note: Epoch counts from 1. |
| | You can specify --avg to use more checkpoints for model averaging.""", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--iter", |
| | type=int, |
| | default=0, |
| | help="""If positive, --epoch is ignored and it |
| | will use the checkpoint exp_dir/checkpoint-iter.pt. |
| | You can specify --avg to use more checkpoints for model averaging. |
| | """, |
| | ) |
| |
|
| | parser.add_argument( |
| | "--avg", |
| | type=int, |
| | default=15, |
| | help="Number of checkpoints to average. Automatically select " |
| | "consecutive checkpoints before the checkpoint specified by " |
| | "'--epoch' and '--iter'", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--use-averaged-model", |
| | type=str2bool, |
| | default=True, |
| | help="Whether to load averaged model. Currently it only supports " |
| | "using --epoch. If True, it would decode with the averaged model " |
| | "over the epoch range from `epoch-avg` (excluded) to `epoch`." |
| | "Actually only the models with epoch number of `epoch-avg` and " |
| | "`epoch` are loaded for averaging. ", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--exp-dir", |
| | type=str, |
| | default="zipformer/exp", |
| | help="The experiment dir", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--bpe-model", |
| | type=str, |
| | default="data/lang_bpe_500/bpe.model", |
| | help="Path to the BPE model", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--lang-dir", |
| | type=Path, |
| | default="data/lang_bpe_500", |
| | help="The lang dir containing word table and LG graph", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--decoding-method", |
| | type=str, |
| | default="greedy_search", |
| | help="""Possible values are: |
| | - greedy_search |
| | - beam_search |
| | - modified_beam_search |
| | - modified_beam_search_LODR |
| | - fast_beam_search |
| | - fast_beam_search_nbest |
| | - fast_beam_search_nbest_oracle |
| | - fast_beam_search_nbest_LG |
| | If you use fast_beam_search_nbest_LG, you have to specify |
| | `--lang-dir`, which should contain `LG.pt`. |
| | """, |
| | ) |
| |
|
| | parser.add_argument( |
| | "--beam-size", |
| | type=int, |
| | default=4, |
| | help="""An integer indicating how many candidates we will keep for each |
| | frame. Used only when --decoding-method is beam_search or |
| | modified_beam_search.""", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--beam", |
| | type=float, |
| | default=20.0, |
| | help="""A floating point value to calculate the cutoff score during beam |
| | search (i.e., `cutoff = max-score - beam`), which is the same as the |
| | `beam` in Kaldi. |
| | Used only when --decoding-method is fast_beam_search, |
| | fast_beam_search_nbest, fast_beam_search_nbest_LG, |
| | and fast_beam_search_nbest_oracle |
| | """, |
| | ) |
| |
|
| | parser.add_argument( |
| | "--ngram-lm-scale", |
| | type=float, |
| | default=0.01, |
| | help=""" |
| | Used only when --decoding-method is fast_beam_search_nbest_LG. |
| | It specifies the scale for n-gram LM scores. |
| | """, |
| | ) |
| |
|
| | parser.add_argument( |
| | "--max-contexts", |
| | type=int, |
| | default=8, |
| | help="""Used only when --decoding-method is |
| | fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, |
| | and fast_beam_search_nbest_oracle""", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--max-states", |
| | type=int, |
| | default=64, |
| | help="""Used only when --decoding-method is |
| | fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, |
| | and fast_beam_search_nbest_oracle""", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--context-size", |
| | type=int, |
| | default=2, |
| | help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", |
| | ) |
| | parser.add_argument( |
| | "--max-sym-per-frame", |
| | type=int, |
| | default=1, |
| | help="""Maximum number of symbols per frame. |
| | Used only when --decoding-method is greedy_search""", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--num-paths", |
| | type=int, |
| | default=200, |
| | help="""Number of paths for nbest decoding. |
| | Used only when the decoding method is fast_beam_search_nbest, |
| | fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--nbest-scale", |
| | type=float, |
| | default=0.5, |
| | help="""Scale applied to lattice scores when computing nbest paths. |
| | Used only when the decoding method is fast_beam_search_nbest, |
| | fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--use-shallow-fusion", |
| | type=str2bool, |
| | default=False, |
| | help="""Use neural network LM for shallow fusion. |
| | If you want to use LODR, you will also need to set this to true |
| | """, |
| | ) |
| |
|
| | parser.add_argument( |
| | "--lm-type", |
| | type=str, |
| | default="rnn", |
| | help="Type of NN lm", |
| | choices=["rnn", "transformer"], |
| | ) |
| |
|
| | parser.add_argument( |
| | "--lm-scale", |
| | type=float, |
| | default=0.3, |
| | help="""The scale of the neural network LM |
| | Used only when `--use-shallow-fusion` is set to True. |
| | """, |
| | ) |
| |
|
| | parser.add_argument( |
| | "--tokens-ngram", |
| | type=int, |
| | default=2, |
| | help="""The order of the ngram lm. |
| | """, |
| | ) |
| |
|
| | parser.add_argument( |
| | "--backoff-id", |
| | type=int, |
| | default=500, |
| | help="ID of the backoff symbol in the ngram LM", |
| | ) |
| |
|
| | parser.add_argument( |
| | "--context-score", |
| | type=float, |
| | default=2, |
| | help=""" |
| | The bonus score of each token for the context biasing words/phrases. |
| | Used only when --decoding-method is modified_beam_search and |
| | modified_beam_search_LODR. |
| | """, |
| | ) |
| |
|
| | parser.add_argument( |
| | "--context-file", |
| | type=str, |
| | default="", |
| | help=""" |
| | The path of the context biasing lists, one word/phrase each line |
| | Used only when --decoding-method is modified_beam_search and |
| | modified_beam_search_LODR. |
| | """, |
| | ) |
| |
|
| | parser.add_argument( |
| | "--skip-scoring", |
| | type=str2bool, |
| | default=False, |
| | help="""Skip scoring, but still save the ASR output (for eval sets).""", |
| | ) |
| |
|
| | add_model_arguments(parser) |
| |
|
| | return parser |
| |
|
| |
|
| | def decode_one_batch( |
| | params: AttributeDict, |
| | model: nn.Module, |
| | sp: spm.SentencePieceProcessor, |
| | batch: dict, |
| | word_table: Optional[k2.SymbolTable] = None, |
| | decoding_graph: Optional[k2.Fsa] = None, |
| | context_graph: Optional[ContextGraph] = None, |
| | LM: Optional[LmScorer] = None, |
| | ngram_lm=None, |
| | ngram_lm_scale: float = 0.0, |
| | ) -> Dict[str, List[List[str]]]: |
| | """Decode one batch and return the result in a dict. The dict has the |
| | following format: |
| | |
| | - key: It indicates the setting used for decoding. For example, |
| | if greedy_search is used, it would be "greedy_search" |
| | If beam search with a beam size of 7 is used, it would be |
| | "beam_7" |
| | - value: It contains the decoding result. `len(value)` equals to |
| | batch size. `value[i]` is the decoding result for the i-th |
| | utterance in the given batch. |
| | Args: |
| | params: |
| | It's the return value of :func:`get_params`. |
| | model: |
| | The neural model. |
| | sp: |
| | The BPE model. |
| | batch: |
| | It is the return value from iterating |
| | `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation |
| | for the format of the `batch`. |
| | word_table: |
| | The word symbol table. |
| | decoding_graph: |
| | The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used |
| | only when --decoding-method is fast_beam_search, fast_beam_search_nbest, |
| | fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. |
| | LM: |
| | A neural network language model. |
| | ngram_lm: |
| | A ngram language model |
| | ngram_lm_scale: |
| | The scale for the ngram language model. |
| | Returns: |
| | Return the decoding result. See above description for the format of |
| | the returned dict. |
| | """ |
| | device = next(model.parameters()).device |
| | feature = batch["inputs"] |
| | assert feature.ndim == 3 |
| |
|
| | feature = feature.to(device) |
| | |
| |
|
| | supervisions = batch["supervisions"] |
| | feature_lens = supervisions["num_frames"].to(device) |
| |
|
| | if params.causal: |
| | |
| | pad_len = 30 |
| | feature_lens += pad_len |
| | feature = torch.nn.functional.pad( |
| | feature, |
| | pad=(0, 0, 0, pad_len), |
| | value=LOG_EPS, |
| | ) |
| |
|
| | encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) |
| |
|
| | hyps = [] |
| |
|
| | if params.decoding_method == "fast_beam_search": |
| | hyp_tokens = fast_beam_search_one_best( |
| | model=model, |
| | decoding_graph=decoding_graph, |
| | encoder_out=encoder_out, |
| | encoder_out_lens=encoder_out_lens, |
| | beam=params.beam, |
| | max_contexts=params.max_contexts, |
| | max_states=params.max_states, |
| | ) |
| | for hyp in sp.decode(hyp_tokens): |
| | hyps.append(hyp.split()) |
| | elif params.decoding_method == "fast_beam_search_nbest_LG": |
| | hyp_tokens = fast_beam_search_nbest_LG( |
| | model=model, |
| | decoding_graph=decoding_graph, |
| | encoder_out=encoder_out, |
| | encoder_out_lens=encoder_out_lens, |
| | beam=params.beam, |
| | max_contexts=params.max_contexts, |
| | max_states=params.max_states, |
| | num_paths=params.num_paths, |
| | nbest_scale=params.nbest_scale, |
| | ) |
| | for hyp in hyp_tokens: |
| | hyps.append([word_table[i] for i in hyp]) |
| | elif params.decoding_method == "fast_beam_search_nbest": |
| | hyp_tokens = fast_beam_search_nbest( |
| | model=model, |
| | decoding_graph=decoding_graph, |
| | encoder_out=encoder_out, |
| | encoder_out_lens=encoder_out_lens, |
| | beam=params.beam, |
| | max_contexts=params.max_contexts, |
| | max_states=params.max_states, |
| | num_paths=params.num_paths, |
| | nbest_scale=params.nbest_scale, |
| | ) |
| | for hyp in sp.decode(hyp_tokens): |
| | hyps.append(hyp.split()) |
| | elif params.decoding_method == "fast_beam_search_nbest_oracle": |
| | hyp_tokens = fast_beam_search_nbest_oracle( |
| | model=model, |
| | decoding_graph=decoding_graph, |
| | encoder_out=encoder_out, |
| | encoder_out_lens=encoder_out_lens, |
| | beam=params.beam, |
| | max_contexts=params.max_contexts, |
| | max_states=params.max_states, |
| | num_paths=params.num_paths, |
| | ref_texts=sp.encode(supervisions["text"]), |
| | nbest_scale=params.nbest_scale, |
| | ) |
| | for hyp in sp.decode(hyp_tokens): |
| | hyps.append(hyp.split()) |
| | elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: |
| | hyp_tokens = greedy_search_batch( |
| | model=model, |
| | encoder_out=encoder_out, |
| | encoder_out_lens=encoder_out_lens, |
| | ) |
| | for hyp in sp.decode(hyp_tokens): |
| | hyps.append(hyp.split()) |
| | elif params.decoding_method == "modified_beam_search": |
| | hyp_tokens = modified_beam_search( |
| | model=model, |
| | encoder_out=encoder_out, |
| | encoder_out_lens=encoder_out_lens, |
| | beam=params.beam_size, |
| | context_graph=context_graph, |
| | ) |
| | for hyp in sp.decode(hyp_tokens): |
| | hyps.append(hyp.split()) |
| | elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": |
| | hyp_tokens = modified_beam_search_lm_shallow_fusion( |
| | model=model, |
| | encoder_out=encoder_out, |
| | encoder_out_lens=encoder_out_lens, |
| | beam=params.beam_size, |
| | LM=LM, |
| | ) |
| | for hyp in sp.decode(hyp_tokens): |
| | hyps.append(hyp.split()) |
| | elif params.decoding_method == "modified_beam_search_LODR": |
| | hyp_tokens = modified_beam_search_LODR( |
| | model=model, |
| | encoder_out=encoder_out, |
| | encoder_out_lens=encoder_out_lens, |
| | beam=params.beam_size, |
| | LODR_lm=ngram_lm, |
| | LODR_lm_scale=ngram_lm_scale, |
| | LM=LM, |
| | context_graph=context_graph, |
| | ) |
| | for hyp in sp.decode(hyp_tokens): |
| | hyps.append(hyp.split()) |
| | elif params.decoding_method == "modified_beam_search_lm_rescore": |
| | lm_scale_list = [0.01 * i for i in range(10, 50)] |
| | ans_dict = modified_beam_search_lm_rescore( |
| | model=model, |
| | encoder_out=encoder_out, |
| | encoder_out_lens=encoder_out_lens, |
| | beam=params.beam_size, |
| | LM=LM, |
| | lm_scale_list=lm_scale_list, |
| | ) |
| | elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": |
| | lm_scale_list = [0.02 * i for i in range(2, 30)] |
| | ans_dict = modified_beam_search_lm_rescore_LODR( |
| | model=model, |
| | encoder_out=encoder_out, |
| | encoder_out_lens=encoder_out_lens, |
| | beam=params.beam_size, |
| | LM=LM, |
| | LODR_lm=ngram_lm, |
| | sp=sp, |
| | lm_scale_list=lm_scale_list, |
| | ) |
| | else: |
| | batch_size = encoder_out.size(0) |
| |
|
| | for i in range(batch_size): |
| | |
| | encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] |
| | |
| | if params.decoding_method == "greedy_search": |
| | hyp = greedy_search( |
| | model=model, |
| | encoder_out=encoder_out_i, |
| | max_sym_per_frame=params.max_sym_per_frame, |
| | ) |
| | elif params.decoding_method == "beam_search": |
| | hyp = beam_search( |
| | model=model, |
| | encoder_out=encoder_out_i, |
| | beam=params.beam_size, |
| | ) |
| | else: |
| | raise ValueError( |
| | f"Unsupported decoding method: {params.decoding_method}" |
| | ) |
| | hyps.append(sp.decode(hyp).split()) |
| |
|
| | |
| | prefix = f"{params.decoding_method}" |
| | if params.decoding_method == "greedy_search": |
| | return {"greedy_search": hyps} |
| | elif "fast_beam_search" in params.decoding_method: |
| | prefix += f"_beam-{params.beam}" |
| | prefix += f"_max-contexts-{params.max_contexts}" |
| | prefix += f"_max-states-{params.max_states}" |
| | if "nbest" in params.decoding_method: |
| | prefix += f"_num-paths-{params.num_paths}" |
| | prefix += f"_nbest-scale-{params.nbest_scale}" |
| | if "LG" in params.decoding_method: |
| | prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" |
| |
|
| | return {prefix: hyps} |
| | elif "modified_beam_search" in params.decoding_method: |
| | prefix += f"_beam-size-{params.beam_size}" |
| | if params.decoding_method in ( |
| | "modified_beam_search_lm_rescore", |
| | "modified_beam_search_lm_rescore_LODR", |
| | ): |
| | ans = dict() |
| | assert ans_dict is not None |
| | for key, hyps in ans_dict.items(): |
| | hyps = [sp.decode(hyp).split() for hyp in hyps] |
| | ans[f"{prefix}_{key}"] = hyps |
| | return ans |
| | else: |
| | if params.has_contexts: |
| | prefix += f"_context-score-{params.context_score}" |
| | return {prefix: hyps} |
| | else: |
| | prefix += f"_beam-size-{params.beam_size}" |
| | return {prefix: hyps} |
| |
|
| |
|
| | def decode_dataset( |
| | dl: torch.utils.data.DataLoader, |
| | params: AttributeDict, |
| | model: nn.Module, |
| | sp: spm.SentencePieceProcessor, |
| | word_table: Optional[k2.SymbolTable] = None, |
| | decoding_graph: Optional[k2.Fsa] = None, |
| | context_graph: Optional[ContextGraph] = None, |
| | LM: Optional[LmScorer] = None, |
| | ngram_lm=None, |
| | ngram_lm_scale: float = 0.0, |
| | ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: |
| | """Decode dataset. |
| | |
| | Args: |
| | dl: |
| | PyTorch's dataloader containing the dataset to decode. |
| | params: |
| | It is returned by :func:`get_params`. |
| | model: |
| | The neural model. |
| | sp: |
| | The BPE model. |
| | word_table: |
| | The word symbol table. |
| | decoding_graph: |
| | The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used |
| | only when --decoding-method is fast_beam_search, fast_beam_search_nbest, |
| | fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. |
| | Returns: |
| | Return a dict, whose key may be "greedy_search" if greedy search |
| | is used, or it may be "beam_7" if beam size of 7 is used. |
| | Its value is a list of tuples. Each tuple contains two elements: |
| | The first is the reference transcript, and the second is the |
| | predicted result. |
| | """ |
| | num_cuts = 0 |
| |
|
| | try: |
| | num_batches = len(dl) |
| | except TypeError: |
| | num_batches = "?" |
| |
|
| | if params.decoding_method == "greedy_search": |
| | log_interval = 50 |
| | else: |
| | log_interval = 20 |
| |
|
| | results = defaultdict(list) |
| | for batch_idx, batch in enumerate(dl): |
| | texts = batch["supervisions"]["text"] |
| | cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] |
| |
|
| | hyps_dict = decode_one_batch( |
| | params=params, |
| | model=model, |
| | sp=sp, |
| | decoding_graph=decoding_graph, |
| | context_graph=context_graph, |
| | word_table=word_table, |
| | batch=batch, |
| | LM=LM, |
| | ngram_lm=ngram_lm, |
| | ngram_lm_scale=ngram_lm_scale, |
| | ) |
| |
|
| | for name, hyps in hyps_dict.items(): |
| | this_batch = [] |
| | assert len(hyps) == len(texts) |
| | for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): |
| | ref_words = ref_text.split() |
| | this_batch.append((cut_id, ref_words, hyp_words)) |
| |
|
| | results[name].extend(this_batch) |
| |
|
| | num_cuts += len(texts) |
| |
|
| | if batch_idx % log_interval == 0: |
| | batch_str = f"{batch_idx}/{num_batches}" |
| |
|
| | logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") |
| | return results |
| |
|
| |
|
| | def save_asr_output( |
| | params: AttributeDict, |
| | test_set_name: str, |
| | results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], |
| | ): |
| | """ |
| | Save text produced by ASR. |
| | """ |
| | for key, results in results_dict.items(): |
| |
|
| | recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" |
| |
|
| | results = sorted(results) |
| | store_transcripts(filename=recogs_filename, texts=results) |
| |
|
| | logging.info(f"The transcripts are stored in {recogs_filename}") |
| |
|
| |
|
| | def save_wer_results( |
| | params: AttributeDict, |
| | test_set_name: str, |
| | results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], |
| | ): |
| | """ |
| | Save WER and per-utterance word alignments. |
| | """ |
| | test_set_wers = dict() |
| | for key, results in results_dict.items(): |
| | |
| | |
| | errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" |
| | with open(errs_filename, "w", encoding="utf8") as fd: |
| | wer = write_error_stats( |
| | fd, f"{test_set_name}-{key}", results, enable_log=True |
| | ) |
| | test_set_wers[key] = wer |
| |
|
| | logging.info(f"Wrote detailed error stats to {errs_filename}") |
| |
|
| | test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) |
| |
|
| | wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" |
| |
|
| | with open(wer_filename, "w", encoding="utf8") as fd: |
| | print("settings\tWER", file=fd) |
| | for key, val in test_set_wers: |
| | print(f"{key}\t{val}", file=fd) |
| |
|
| | s = f"\nFor {test_set_name}, WER of different settings are:\n" |
| | note = f"\tbest for {test_set_name}" |
| | for key, val in test_set_wers: |
| | s += f"{key}\t{val}{note}\n" |
| | note = "" |
| | logging.info(s) |
| |
|
| |
|
| | @torch.no_grad() |
| | def main(): |
| | parser = get_parser() |
| | LibriSpeechAsrDataModule.add_arguments(parser) |
| | LmScorer.add_arguments(parser) |
| | args = parser.parse_args() |
| | args.exp_dir = Path(args.exp_dir) |
| |
|
| | params = get_params() |
| | params.update(vars(args)) |
| |
|
| | |
| | set_caching_enabled(True) |
| |
|
| | assert params.decoding_method in ( |
| | "greedy_search", |
| | "beam_search", |
| | "fast_beam_search", |
| | "fast_beam_search_nbest", |
| | "fast_beam_search_nbest_LG", |
| | "fast_beam_search_nbest_oracle", |
| | "modified_beam_search", |
| | "modified_beam_search_LODR", |
| | "modified_beam_search_lm_shallow_fusion", |
| | "modified_beam_search_lm_rescore", |
| | "modified_beam_search_lm_rescore_LODR", |
| | ) |
| | params.res_dir = params.exp_dir / params.decoding_method |
| |
|
| | if os.path.exists(params.context_file): |
| | params.has_contexts = True |
| | else: |
| | params.has_contexts = False |
| |
|
| | if params.iter > 0: |
| | params.suffix = f"iter-{params.iter}_avg-{params.avg}" |
| | else: |
| | params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" |
| |
|
| | if params.causal: |
| | assert ( |
| | "," not in params.chunk_size |
| | ), "chunk_size should be one value in decoding." |
| | assert ( |
| | "," not in params.left_context_frames |
| | ), "left_context_frames should be one value in decoding." |
| | params.suffix += f"_chunk-{params.chunk_size}" |
| | params.suffix += f"_left-context-{params.left_context_frames}" |
| |
|
| | if "fast_beam_search" in params.decoding_method: |
| | params.suffix += f"_beam-{params.beam}" |
| | params.suffix += f"_max-contexts-{params.max_contexts}" |
| | params.suffix += f"_max-states-{params.max_states}" |
| | if "nbest" in params.decoding_method: |
| | params.suffix += f"_nbest-scale-{params.nbest_scale}" |
| | params.suffix += f"_num-paths-{params.num_paths}" |
| | if "LG" in params.decoding_method: |
| | params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" |
| | elif "beam_search" in params.decoding_method: |
| | params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" |
| | if params.decoding_method in ( |
| | "modified_beam_search", |
| | "modified_beam_search_LODR", |
| | ): |
| | if params.has_contexts: |
| | params.suffix += f"-context-score-{params.context_score}" |
| | else: |
| | params.suffix += f"_context-{params.context_size}" |
| | params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" |
| |
|
| | if params.use_shallow_fusion: |
| | params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" |
| |
|
| | if "LODR" in params.decoding_method: |
| | params.suffix += ( |
| | f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" |
| | ) |
| |
|
| | if params.use_averaged_model: |
| | params.suffix += "_use-averaged-model" |
| |
|
| | setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") |
| | logging.info("Decoding started") |
| |
|
| | device = torch.device("cpu") |
| | if torch.cuda.is_available(): |
| | device = torch.device("cuda", 0) |
| |
|
| | logging.info(f"Device: {device}") |
| |
|
| | sp = spm.SentencePieceProcessor() |
| | sp.load(params.bpe_model) |
| |
|
| | |
| | params.blank_id = sp.piece_to_id("<blk>") |
| | params.unk_id = sp.piece_to_id("<unk>") |
| | params.vocab_size = sp.get_piece_size() |
| |
|
| | logging.info(params) |
| |
|
| | logging.info("About to create model") |
| | model = get_model(params) |
| |
|
| | if not params.use_averaged_model: |
| | if params.iter > 0: |
| | filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ |
| | : params.avg |
| | ] |
| | if len(filenames) == 0: |
| | raise ValueError( |
| | f"No checkpoints found for" |
| | f" --iter {params.iter}, --avg {params.avg}" |
| | ) |
| | elif len(filenames) < params.avg: |
| | raise ValueError( |
| | f"Not enough checkpoints ({len(filenames)}) found for" |
| | f" --iter {params.iter}, --avg {params.avg}" |
| | ) |
| | logging.info(f"averaging {filenames}") |
| | model.to(device) |
| | model.load_state_dict(average_checkpoints(filenames, device=device)) |
| | elif params.avg == 1: |
| | load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) |
| | else: |
| | start = params.epoch - params.avg + 1 |
| | filenames = [] |
| | for i in range(start, params.epoch + 1): |
| | if i >= 1: |
| | filenames.append(f"{params.exp_dir}/epoch-{i}.pt") |
| | logging.info(f"averaging {filenames}") |
| | model.to(device) |
| | model.load_state_dict(average_checkpoints(filenames, device=device)) |
| | else: |
| | if params.iter > 0: |
| | filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ |
| | : params.avg + 1 |
| | ] |
| | if len(filenames) == 0: |
| | raise ValueError( |
| | f"No checkpoints found for" |
| | f" --iter {params.iter}, --avg {params.avg}" |
| | ) |
| | elif len(filenames) < params.avg + 1: |
| | raise ValueError( |
| | f"Not enough checkpoints ({len(filenames)}) found for" |
| | f" --iter {params.iter}, --avg {params.avg}" |
| | ) |
| | filename_start = filenames[-1] |
| | filename_end = filenames[0] |
| | logging.info( |
| | "Calculating the averaged model over iteration checkpoints" |
| | f" from {filename_start} (excluded) to {filename_end}" |
| | ) |
| | model.to(device) |
| | model.load_state_dict( |
| | average_checkpoints_with_averaged_model( |
| | filename_start=filename_start, |
| | filename_end=filename_end, |
| | device=device, |
| | ) |
| | ) |
| | else: |
| | assert params.avg > 0, params.avg |
| | start = params.epoch - params.avg |
| | assert start >= 1, start |
| | filename_start = f"{params.exp_dir}/epoch-{start}.pt" |
| | filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" |
| | logging.info( |
| | f"Calculating the averaged model over epoch range from " |
| | f"{start} (excluded) to {params.epoch}" |
| | ) |
| | model.to(device) |
| | model.load_state_dict( |
| | average_checkpoints_with_averaged_model( |
| | filename_start=filename_start, |
| | filename_end=filename_end, |
| | device=device, |
| | ) |
| | ) |
| |
|
| | model.to(device) |
| | model.eval() |
| |
|
| | |
| | if params.use_shallow_fusion or params.decoding_method in ( |
| | "modified_beam_search_lm_rescore", |
| | "modified_beam_search_lm_rescore_LODR", |
| | "modified_beam_search_lm_shallow_fusion", |
| | "modified_beam_search_LODR", |
| | ): |
| | LM = LmScorer( |
| | lm_type=params.lm_type, |
| | params=params, |
| | device=device, |
| | lm_scale=params.lm_scale, |
| | ) |
| | LM.to(device) |
| | LM.eval() |
| | else: |
| | LM = None |
| |
|
| | |
| | if params.decoding_method == "modified_beam_search_lm_rescore_LODR": |
| | try: |
| | import kenlm |
| | except ImportError: |
| | print("Please install kenlm first. You can use") |
| | print(" pip install https://github.com/kpu/kenlm/archive/master.zip") |
| | print("to install it") |
| | import sys |
| |
|
| | sys.exit(-1) |
| | ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") |
| | logging.info(f"lm filename: {ngram_file_name}") |
| | ngram_lm = kenlm.Model(ngram_file_name) |
| | ngram_lm_scale = None |
| |
|
| | elif params.decoding_method == "modified_beam_search_LODR": |
| | lm_filename = f"{params.tokens_ngram}gram.fst.txt" |
| | logging.info(f"Loading token level lm: {lm_filename}") |
| | ngram_lm = NgramLm( |
| | str(params.lang_dir / lm_filename), |
| | backoff_id=params.backoff_id, |
| | is_binary=False, |
| | ) |
| | logging.info(f"num states: {ngram_lm.lm.num_states}") |
| | ngram_lm_scale = params.ngram_lm_scale |
| | else: |
| | ngram_lm = None |
| | ngram_lm_scale = None |
| |
|
| | if "fast_beam_search" in params.decoding_method: |
| | if params.decoding_method == "fast_beam_search_nbest_LG": |
| | lexicon = Lexicon(params.lang_dir) |
| | word_table = lexicon.word_table |
| | lg_filename = params.lang_dir / "LG.pt" |
| | logging.info(f"Loading {lg_filename}") |
| | decoding_graph = k2.Fsa.from_dict( |
| | torch.load(lg_filename, map_location=device, weights_only=False) |
| | ) |
| | decoding_graph.scores *= params.ngram_lm_scale |
| | else: |
| | word_table = None |
| | decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) |
| | else: |
| | decoding_graph = None |
| | word_table = None |
| |
|
| | if "modified_beam_search" in params.decoding_method: |
| | if os.path.exists(params.context_file): |
| | contexts = [] |
| | for line in open(params.context_file).readlines(): |
| | contexts.append((sp.encode(line.strip()), 0.0)) |
| | context_graph = ContextGraph(params.context_score) |
| | context_graph.build(contexts) |
| | else: |
| | context_graph = None |
| | else: |
| | context_graph = None |
| |
|
| | num_param = sum([p.numel() for p in model.parameters()]) |
| | logging.info(f"Number of model parameters: {num_param}") |
| |
|
| | |
| | args.return_cuts = True |
| | librispeech = LibriSpeechAsrDataModule(args) |
| |
|
| | test_clean_cuts = librispeech.test_clean_cuts() |
| | test_other_cuts = librispeech.test_other_cuts() |
| |
|
| | test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) |
| | test_other_dl = librispeech.test_dataloaders(test_other_cuts) |
| |
|
| | test_sets = ["test-clean", "test-other"] |
| | test_dl = [test_clean_dl, test_other_dl] |
| |
|
| | for test_set, test_dl in zip(test_sets, test_dl): |
| | results_dict = decode_dataset( |
| | dl=test_dl, |
| | params=params, |
| | model=model, |
| | sp=sp, |
| | word_table=word_table, |
| | decoding_graph=decoding_graph, |
| | context_graph=context_graph, |
| | LM=LM, |
| | ngram_lm=ngram_lm, |
| | ngram_lm_scale=ngram_lm_scale, |
| | ) |
| |
|
| | save_asr_output( |
| | params=params, |
| | test_set_name=test_set, |
| | results_dict=results_dict, |
| | ) |
| |
|
| | if not params.skip_scoring: |
| | save_wer_results( |
| | params=params, |
| | test_set_name=test_set, |
| | results_dict=results_dict, |
| | ) |
| |
|
| | logging.info("Done!") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|