| """ |
| Usage: |
| (1) greedy search |
| ./zipformer/decode.py \ |
| --epoch 28 \ |
| --avg 15 \ |
| --exp-dir ./zipformer/exp \ |
| --max-duration 600 \ |
| --decoding-method greedy_search |
| |
| (2) beam search (not recommended) |
| ./zipformer/decode.py \ |
| --epoch 28 \ |
| --avg 15 \ |
| --exp-dir ./zipformer/exp \ |
| --max-duration 600 \ |
| --decoding-method beam_search \ |
| --beam-size 4 |
| |
| (3) modified beam search |
| ./zipformer/decode.py \ |
| --epoch 28 \ |
| --avg 15 \ |
| --exp-dir ./zipformer/exp \ |
| --max-duration 600 \ |
| --decoding-method modified_beam_search \ |
| --beam-size 4 |
| |
| (4) fast beam search (one best) |
| ./zipformer/decode.py \ |
| --epoch 28 \ |
| --avg 15 \ |
| --exp-dir ./zipformer/exp \ |
| --max-duration 600 \ |
| --decoding-method fast_beam_search \ |
| --beam 20.0 \ |
| --max-contexts 8 \ |
| --max-states 64 |
| |
| (5) fast beam search (nbest) |
| ./zipformer/decode.py \ |
| --epoch 28 \ |
| --avg 15 \ |
| --exp-dir ./zipformer/exp \ |
| --max-duration 600 \ |
| --decoding-method fast_beam_search_nbest \ |
| --beam 20.0 \ |
| --max-contexts 8 \ |
| --max-states 64 \ |
| --num-paths 200 \ |
| --nbest-scale 0.5 |
| |
| (6) fast beam search (nbest oracle WER) |
| ./zipformer/decode.py \ |
| --epoch 28 \ |
| --avg 15 \ |
| --exp-dir ./zipformer/exp \ |
| --max-duration 600 \ |
| --decoding-method fast_beam_search_nbest_oracle \ |
| --beam 20.0 \ |
| --max-contexts 8 \ |
| --max-states 64 \ |
| --num-paths 200 \ |
| --nbest-scale 0.5 |
| |
| (7) fast beam search (with LG) |
| ./zipformer/decode.py \ |
| --epoch 28 \ |
| --avg 15 \ |
| --exp-dir ./zipformer/exp \ |
| --max-duration 600 \ |
| --decoding-method fast_beam_search_nbest_LG \ |
| --beam 20.0 \ |
| --max-contexts 8 \ |
| --max-states 64 |
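
(8) modified beam search with NN LM shallow fusion (illustrative; the flags
    below are all defined in this file, and LmScorer.add_arguments() registers
    further LM options such as where to find the LM checkpoint)
./zipformer/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./zipformer/exp \
    --max-duration 600 \
    --decoding-method modified_beam_search_lm_shallow_fusion \
    --beam-size 4 \
    --use-shallow-fusion 1 \
    --lm-type rnn \
    --lm-scale 0.3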
| """ |
|
|
|
|
| import argparse |
| import logging |
| import math |
| import os |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import k2 |
| import sentencepiece as spm |
| import torch |
| import torch.nn as nn |
| from asr_datamodule import LibriSpeechAsrDataModule |
| from beam_search import ( |
| beam_search, |
| fast_beam_search_nbest, |
| fast_beam_search_nbest_LG, |
| fast_beam_search_nbest_oracle, |
| fast_beam_search_one_best, |
| greedy_search, |
| greedy_search_batch, |
| modified_beam_search, |
| modified_beam_search_lm_rescore, |
| modified_beam_search_lm_rescore_LODR, |
| modified_beam_search_lm_shallow_fusion, |
| modified_beam_search_LODR, |
| ) |
| from train import add_model_arguments, get_model, get_params |
|
|
| from icefall import ContextGraph, LmScorer, NgramLm |
| from icefall.checkpoint import ( |
| average_checkpoints, |
| average_checkpoints_with_averaged_model, |
| find_checkpoints, |
| load_checkpoint, |
| ) |
| from icefall.lexicon import Lexicon |
| from icefall.utils import ( |
| AttributeDict, |
| make_pad_mask, |
| setup_logger, |
| store_transcripts, |
| str2bool, |
| write_error_stats, |
| ) |
|
|
| LOG_EPS = math.log(1e-10) |
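# LOG_EPS is the log-domain padding value used to right-pad features for
# causal (streaming) models in decode_one_batch() below.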


conversational_filler = [
    "UH",
    "UHH",
    "UM",
    "EH",
    "MM",
    "HM",
    "AH",
    "HUH",
    "HA",
    "ER",
    "OOF",
    "HEE",
    "ACH",
    "EEE",
    "EW",
]
unk_tags = ["<UNK>", "<unk>"]
gigaspeech_punctuations = [
    "<COMMA>",
    "<PERIOD>",
    "<QUESTIONMARK>",
    "<EXCLAMATIONPOINT>",
]
gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
non_scoring_words = (
    conversational_filler
    + unk_tags
    + gigaspeech_punctuations
    + gigaspeech_garbage_utterance_tags
)


def asr_text_post_processing(text: str) -> str:
    """Normalize a transcript for GigaSpeech-style scoring: uppercase it,
    split hyphenated words, and drop fillers, <UNK> tags, punctuation tags,
    and garbage-utterance tags."""
    # 1. Convert to uppercase.
    text = text.upper()

    # 2. Remove hyphens, e.g. "STATE-OF-THE-ART" -> "STATE OF THE ART".
    text = text.replace("-", " ")

    # 3. Remove non-scoring words (fillers, tags, punctuation marks).
    remaining_words = []
    for word in text.split():
        if word in non_scoring_words:
            continue
        remaining_words.append(word)

    return " ".join(remaining_words)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load the averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default="data/lang_bpe_500",
        help="The lang dir containing the word table and LG graph",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
        - greedy_search
        - beam_search
        - modified_beam_search
        - modified_beam_search_LODR
        - modified_beam_search_lm_shallow_fusion
        - modified_beam_search_lm_rescore
        - modified_beam_search_lm_rescore_LODR
        - fast_beam_search
        - fast_beam_search_nbest
        - fast_beam_search_nbest_oracle
        - fast_beam_search_nbest_LG
        If you use fast_beam_search_nbest_LG, you have to specify
        `--lang-dir`, which should contain `LG.pt`.
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=20.0,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --decoding-method is fast_beam_search,
        fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle.
        """,
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=0.01,
        help="""
        Used only when --decoding-method is fast_beam_search_nbest_LG.
        It specifies the scale for n-gram LM scores.
        """,
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=64,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means trigram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding-method is greedy_search""",
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=200,
        help="""Number of paths for nbest decoding.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""Scale applied to lattice scores when computing nbest paths.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--use-shallow-fusion",
        type=str2bool,
        default=False,
        help="""Use a neural network LM for shallow fusion.
        If you want to use LODR, you will also need to set this to True.
        """,
    )

    parser.add_argument(
        "--lm-type",
        type=str,
        default="rnn",
        help="Type of NN LM",
        choices=["rnn", "transformer"],
    )

    parser.add_argument(
        "--lm-scale",
        type=float,
        default=0.3,
        help="""The scale of the neural network LM.
        Used only when `--use-shallow-fusion` is set to True.
        """,
    )

    parser.add_argument(
        "--tokens-ngram",
        type=int,
        default=2,
        help="""The order of the n-gram LM.
        """,
    )

    parser.add_argument(
        "--backoff-id",
        type=int,
        default=500,
        help="ID of the backoff symbol in the n-gram LM",
    )

    parser.add_argument(
        "--context-score",
        type=float,
        default=2,
        help="""
        The bonus score for each token of the context biasing words/phrases.
        Used only when --decoding-method is modified_beam_search or
        modified_beam_search_LODR.
        """,
    )

    parser.add_argument(
        "--context-file",
        type=str,
        default="",
        help="""
        The path of the context biasing list, one word/phrase per line.
        Used only when --decoding-method is modified_beam_search or
        modified_beam_search_LODR.
        """,
    )

    add_model_arguments(parser)

    return parser


def post_processing(
    results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
    new_results = []
    for key, ref, hyp in results:
        new_ref = asr_text_post_processing(" ".join(ref)).split()
        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
        new_results.append((key, new_ref, new_hyp))
    return new_results
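
# Illustrative example: ("cut-1", ["I", "UH", "SEE"], ["I", "SEE", "<COMMA>"])
# becomes ("cut-1", ["I", "SEE"], ["I", "SEE"]) after post-processing.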


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
    LM: Optional[LmScorer] = None,
    ngram_lm=None,
    ngram_lm_scale: float = 0.0,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: It indicates the setting used for decoding. For example,
           if greedy_search is used, it would be "greedy_search".
           If beam search with a beam size of 7 is used, it would be
           "beam_7".
    - value: It contains the decoding result. `len(value)` equals the
             batch size. `value[i]` is the decoding result for the i-th
             utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
      LM:
        A neural network language model.
      ngram_lm:
        An n-gram language model.
      ngram_lm_scale:
        The scale for the n-gram language model.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3
    # At this point, feature is (N, T, C).

    feature = feature.to(device)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    if params.causal:
        # For causal (streaming) models, right-pad the features by 30 frames
        # with LOG_EPS.
        pad_len = 30
        feature_lens += pad_len
        feature = torch.nn.functional.pad(
            feature,
            pad=(0, 0, 0, pad_len),
            value=LOG_EPS,
        )

    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)

    hyps = []

    if params.decoding_method == "fast_beam_search":
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "fast_beam_search_nbest_LG":
        hyp_tokens = fast_beam_search_nbest_LG(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )
        for hyp in hyp_tokens:
            hyps.append([word_table[i] for i in hyp])
    elif params.decoding_method == "fast_beam_search_nbest":
        hyp_tokens = fast_beam_search_nbest(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "fast_beam_search_nbest_oracle":
        hyp_tokens = fast_beam_search_nbest_oracle(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            ref_texts=sp.encode(supervisions["text"]),
            nbest_scale=params.nbest_scale,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            context_graph=context_graph,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
        hyp_tokens = modified_beam_search_lm_shallow_fusion(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            LM=LM,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search_LODR":
        hyp_tokens = modified_beam_search_LODR(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            LODR_lm=ngram_lm,
            LODR_lm_scale=ngram_lm_scale,
            LM=LM,
            context_graph=context_graph,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search_lm_rescore":
        lm_scale_list = [0.01 * i for i in range(10, 50)]
        ans_dict = modified_beam_search_lm_rescore(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            LM=LM,
            lm_scale_list=lm_scale_list,
        )
    elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
        lm_scale_list = [0.02 * i for i in range(2, 30)]
        ans_dict = modified_beam_search_lm_rescore_LODR(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
            LM=LM,
            LODR_lm=ngram_lm,
            sp=sp,
            lm_scale_list=lm_scale_list,
        )
    else:
        batch_size = encoder_out.size(0)

        for i in range(batch_size):
            encoder_out_i = encoder_out[i : i + 1, : encoder_out_lens[i]]
            if params.decoding_method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.decoding_method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyps.append(sp.decode(hyp).split())

    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
    elif "fast_beam_search" in params.decoding_method:
        key = f"beam_{params.beam}_"
        key += f"max_contexts_{params.max_contexts}_"
        key += f"max_states_{params.max_states}"
        if "nbest" in params.decoding_method:
            key += f"_num_paths_{params.num_paths}_"
            key += f"nbest_scale_{params.nbest_scale}"
            if "LG" in params.decoding_method:
                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

        return {key: hyps}
    elif "modified_beam_search" in params.decoding_method:
        prefix = f"beam_size_{params.beam_size}"
        if params.decoding_method in (
            "modified_beam_search_lm_rescore",
            "modified_beam_search_lm_rescore_LODR",
        ):
            ans = dict()
            assert ans_dict is not None
            for key, hyps in ans_dict.items():
                hyps = [sp.decode(hyp).split() for hyp in hyps]
                ans[f"{prefix}_{key}"] = hyps
            return ans
        else:
            if params.has_contexts:
                prefix += f"-context-score-{params.context_score}"
            return {prefix: hyps}


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    context_graph: Optional[ContextGraph] = None,
    LM: Optional[LmScorer] = None,
    ngram_lm=None,
    ngram_lm_scale: float = 0.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if a beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains three elements:
      the cut id, the reference transcript, and the predicted result.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    if params.decoding_method == "greedy_search":
        log_interval = 50
    else:
        log_interval = 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
            context_graph=context_graph,
            word_table=word_table,
            batch=batch,
            LM=LM,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
        )

        for name, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                ref_words = ref_text.split()
                this_batch.append((cut_id, ref_words, hyp_words))

            results[name].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"
            logging.info(f"batch {batch_str}, cuts processed until now: {num_cuts}")
    return results
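
# decode_dataset() returns one entry per decoding setting, e.g.
# {"greedy_search": [("cut-1", ["THE", "CAT"], ["THE", "CAT"]), ...]},
# where each triple is (cut_id, reference words, hypothesis words).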


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = post_processing(results)
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following writes WERs, per-word error statistics, and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WERs of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)
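
# The wer-summary file written above is a two-column, tab-separated table
# sorted by WER (best first), e.g.:
#     settings        WER
#     beam_size_4     <wer>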


@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    LmScorer.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "fast_beam_search_nbest",
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
        "modified_beam_search_LODR",
        "modified_beam_search_lm_shallow_fusion",
        "modified_beam_search_lm_rescore",
        "modified_beam_search_lm_rescore_LODR",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    if os.path.exists(params.context_file):
        params.has_contexts = True
    else:
        params.has_contexts = False

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.causal:
        assert (
            "," not in params.chunk_size
        ), "chunk_size should be one value in decoding."
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."
        params.suffix += f"-chunk-{params.chunk_size}"
        params.suffix += f"-left-context-{params.left_context_frames}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
        if "nbest" in params.decoding_method:
            params.suffix += f"-nbest-scale-{params.nbest_scale}"
            params.suffix += f"-num-paths-{params.num_paths}"
            if "LG" in params.decoding_method:
                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
        if params.decoding_method in (
            "modified_beam_search",
            "modified_beam_search_LODR",
        ):
            if params.has_contexts:
                params.suffix += f"-context-score-{params.context_score}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_shallow_fusion:
        params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"

    if "LODR" in params.decoding_method:
        params.suffix += (
            f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
        )

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in the BPE model; look up their ids.
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
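    # Load the model weights, optionally averaging several checkpoints.
    # For example, --epoch 28 --avg 15 with --use-averaged-model decodes with
    # the model averaged over the epoch range 13 (excluded) to 28.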

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    # Only load the neural network LM if it is required.
    if params.use_shallow_fusion or params.decoding_method in (
        "modified_beam_search_lm_rescore",
        "modified_beam_search_lm_rescore_LODR",
        "modified_beam_search_lm_shallow_fusion",
        "modified_beam_search_LODR",
    ):
        LM = LmScorer(
            lm_type=params.lm_type,
            params=params,
            device=device,
            lm_scale=params.lm_scale,
        )
        LM.to(device)
        LM.eval()
    else:
        LM = None

    # Only load the n-gram LM when it is needed.
    if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
        try:
            import kenlm
        except ImportError:
            print("Please install kenlm first. You can use")
            print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
            print("to install it")
            import sys

            sys.exit(-1)
        ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
        logging.info(f"lm filename: {ngram_file_name}")
        ngram_lm = kenlm.Model(ngram_file_name)
        ngram_lm_scale = None
    elif params.decoding_method == "modified_beam_search_LODR":
        lm_filename = f"{params.tokens_ngram}gram.fst.txt"
        logging.info(f"Loading token level lm: {lm_filename}")
        ngram_lm = NgramLm(
            str(params.lang_dir / lm_filename),
            backoff_id=params.backoff_id,
            is_binary=False,
        )
        logging.info(f"num states: {ngram_lm.lm.num_states}")
        ngram_lm_scale = params.ngram_lm_scale
    else:
        ngram_lm = None
        ngram_lm_scale = None

    if "fast_beam_search" in params.decoding_method:
        if params.decoding_method == "fast_beam_search_nbest_LG":
            lexicon = Lexicon(params.lang_dir)
            word_table = lexicon.word_table
            lg_filename = params.lang_dir / "LG.pt"
            logging.info(f"Loading {lg_filename}")
            decoding_graph = k2.Fsa.from_dict(
                torch.load(lg_filename, map_location=device, weights_only=False)
            )
            decoding_graph.scores *= params.ngram_lm_scale
        else:
            word_table = None
            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None
        word_table = None

    if "modified_beam_search" in params.decoding_method:
        if os.path.exists(params.context_file):
            contexts = []
            with open(params.context_file) as f:
                for line in f:
                    contexts.append((sp.encode(line.strip()), 0.0))
            context_graph = ContextGraph(params.context_score)
            context_graph.build(contexts)
        else:
            context_graph = None
    else:
        context_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # We need cut ids to display the recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
    gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts()

    dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts)
    test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts)

    test_sets = ["dev", "test"]
    test_dls = [dev_dl, test_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
            context_graph=context_graph,
            LM=LM,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()