| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Usage: |
| |
| (1) ctc-greedy-search (with cr-ctc) |
| ./zipformer/ctc_decode.py \ |
| --epoch 60 \ |
| --avg 28 \ |
| --exp-dir ./zipformer/exp \ |
| --use-cr-ctc 1 \ |
| --use-ctc 1 \ |
| --use-transducer 0 \ |
| --max-duration 600 \ |
| --decoding-method ctc-greedy-search |
| (2) ctc-prefix-beam-search (with cr-ctc) |
| ./zipformer/ctc_decode.py \ |
| --epoch 60 \ |
| --avg 21 \ |
| --exp-dir zipformer/exp \ |
| --use-cr-ctc 1 \ |
| --use-ctc 1 \ |
| --use-transducer 0 \ |
| --max-duration 600 \ |
| --decoding-method ctc-prefix-beam-search |
| """ |
|
|
|
|
| import argparse |
| import logging |
| import math |
| import os |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import k2 |
| import torch |
| import torch.nn as nn |
| from asr_datamodule import AishellAsrDataModule |
| from lhotse.cut import Cut |
| from train import add_model_arguments, get_model, get_params |
|
|
| from icefall.checkpoint import ( |
| average_checkpoints, |
| average_checkpoints_with_averaged_model, |
| find_checkpoints, |
| load_checkpoint, |
| ) |
| from icefall.decode import ( |
| ctc_greedy_search, |
| ctc_prefix_beam_search, |
| ) |
| from icefall.lexicon import Lexicon |
| from icefall.utils import ( |
| AttributeDict, |
| make_pad_mask, |
| setup_logger, |
| store_transcripts, |
| str2bool, |
| write_error_stats, |
| ) |
|
|
| LOG_EPS = math.log(1e-10) |
|
|
|
|
def get_parser():
    """Build and return the argument parser for this decoding script.

    Model-specific options are appended via ``add_model_arguments``;
    dataset options are added later by ``AishellAsrDataModule.add_arguments``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default="data/lang_char",
        help="The lang dir containing word table and LG graph",
    )

    # NOTE: the help text previously claimed a sentence-piece model
    # (lang_dir/bpe.model) is used; this recipe is char-based and decodes
    # with the token table in --lang-dir instead.
    parser.add_argument(
        "--decoding-method",
        type=str,
        default="ctc-greedy-search",
        help="""Decoding method.
        Supported values are:
        - (1) ctc-greedy-search. Use CTC greedy search. It uses the
          char-based token table in --lang-dir to convert token IDs to
          characters. It needs neither a lexicon nor an n-gram LM.
        - (2) ctc-prefix-beam-search. Extract n paths with the given beam,
          the best path of the n paths is the decoding result.
        """,
    )

    add_model_arguments(parser)

    return parser
|
|
|
|
def get_decoding_params() -> AttributeDict:
    """Return decoding-only parameters (currently just the beam size)."""
    return AttributeDict({"beam": 4})
|
|
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    lexicon: Lexicon,
    batch: dict,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: the decoding method used, e.g. "ctc-greedy-search" or
      "ctc-prefix-beam-search".
    - value: It contains the decoding result. `len(value)` equals to
      batch size. `value[i]` is the decoding result (a list of token
      strings) for the i-th utterance in the given batch.

    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      lexicon:
        Used to map token IDs to token strings.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    Raises:
      ValueError: if ``params.decoding_method`` is not supported.
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3  # (N, T, C)

    feature = feature.to(device)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    if params.causal:
        # Pad 30 frames of (near) silence at the end so a streaming model
        # has enough right context for the tail of the utterance.
        pad_len = 30
        feature_lens += pad_len
        feature = torch.nn.functional.pad(
            feature,
            pad=(0, 0, 0, pad_len),
            value=LOG_EPS,
        )

    x, x_lens = model.encoder_embed(feature, feature_lens)

    src_key_padding_mask = make_pad_mask(x_lens)
    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

    encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

    ctc_output = model.ctc_output(encoder_out)  # (N, T, vocab_size)

    if params.decoding_method == "ctc-greedy-search":
        hyp_tokens = ctc_greedy_search(
            ctc_output=ctc_output,
            encoder_out_lens=encoder_out_lens,
        )
    elif params.decoding_method == "ctc-prefix-beam-search":
        # Fix: previously params.beam was never forwarded, so the configured
        # beam (and the `_beam-` tag in the result-file suffix) was ignored.
        hyp_tokens = ctc_prefix_beam_search(
            ctc_output=ctc_output,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
        )
    else:
        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")

    # Map token IDs to token strings for every utterance in the batch.
    hyps = [
        [lexicon.token_table[idx] for idx in hyp_tokens[i]]
        for i in range(encoder_out.size(0))
    ]

    return {params.decoding_method: hyps}
|
|
|
|
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    lexicon: Lexicon,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode every batch yielded by ``dl``.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      lexicon:
        Used to map token IDs to token strings.
    Returns:
      A dict keyed by decoding method name. Each value is a list of
      (cut_id, reference transcript, predicted result) tuples.
    """
    # ``len()`` is unavailable for iterable-style dataloaders.
    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    log_interval = 20
    num_cuts = 0
    results = defaultdict(list)

    for batch_idx, batch in enumerate(dl):
        sups = batch["supervisions"]
        # Split references into characters: scoring is char-level (CER).
        refs = [list("".join(text.split())) for text in sups["text"]]
        ids = [cut.id for cut in sups["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            lexicon=lexicon,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            assert len(hyps) == len(refs)
            results[name].extend(
                (cut_id, ref, hyp) for cut_id, hyp, ref in zip(ids, hyps, refs)
            )

        num_cuts += len(refs)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"
            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

    return results
|
|
|
|
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, per-method error stats, and a WER summary to disk.

    Args:
      params:
        Must contain ``res_dir`` and ``suffix`` (set in :func:`main`).
      test_set_name:
        Name of the test set, used in output filenames.
      results_dict:
        Mapping from decoding-method name to a list of
        (cut_id, reference, hypothesis) tuples.
    """
    test_set_wers = {}
    for key, results in results_dict.items():
        results = sorted(results)

        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        store_transcripts(filename=recog_path, texts=results, char_level=True)
        logging.info(f"The transcripts are stored in {recog_path}")

        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        # compute_CER=True: Aishell is scored at the character level.
        with open(errs_filename, "w") as f:
            test_set_wers[key] = write_error_stats(
                f,
                f"{test_set_name}-{key}",
                results,
                enable_log=True,
                compute_CER=True,
            )

        logging.info(f"Wrote detailed error stats to {errs_filename}")

    # NOTE: ``key`` intentionally holds the last decoding-method name from
    # the loop above; the summary filename is tagged with it.
    ranked = sorted(test_set_wers.items(), key=lambda kv: kv[1])
    errs_info = (
        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for name, wer in ranked:
            print(f"{name}\t{wer}", file=f)

    # Log a human-readable summary; only the best (first) row gets the note.
    s = f"\nFor {test_set_name}, WER of different settings are:\n"
    note = f"\tbest for {test_set_name}"
    for name, wer in ranked:
        s += f"{name}\t{wer}{note}\n"
        note = ""
    logging.info(s)
|
|
|
|
@torch.no_grad()
def main():
    """Entry point: parse args, load/average a checkpoint, decode, save results."""
    parser = get_parser()
    AishellAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)

    # Merge model params, decoding params, and CLI args; later updates win,
    # so CLI arguments override the defaults.
    params = get_params()
    params.update(get_decoding_params())
    params.update(vars(args))

    assert params.decoding_method in (
        "ctc-greedy-search",
        "ctc-prefix-beam-search",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    # Build a filename suffix that records the checkpoint-selection settings.
    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.causal:
        # Streaming models are trained with comma-separated lists of chunk
        # sizes / left-context frames; decoding needs exactly one value.
        assert (
            "," not in params.chunk_size
        ), "chunk_size should be one value in decoding."
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."
        params.suffix += f"-chunk-{params.chunk_size}"
        params.suffix += f"-left-context-{params.left_context_frames}"

    if "prefix-beam-search" in params.decoding_method:
        params.suffix += f"_beam-{params.beam}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    # Prefer GPU 0 when available; otherwise fall back to CPU.
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    params.device = device

    logging.info(f"Device: {device}")

    # Char-based lexicon; token IDs in model output map to entries of
    # lexicon.token_table.
    lexicon = Lexicon(params.lang_dir)
    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    # Checkpoint loading: four cases depending on --use-averaged-model,
    # --iter and --avg.
    if not params.use_averaged_model:
        if params.iter > 0:
            # Average the states of the last --avg iteration checkpoints.
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            # No averaging: load the single epoch checkpoint directly.
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            # Average the states of epochs [epoch-avg+1, epoch],
            # skipping epoch numbers < 1.
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            # Averaged-model mode needs avg+1 checkpoints: the range is
            # (filename_start, filename_end], start excluded.
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            # Average over epoch range (epoch-avg, epoch]; start excluded.
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # return_cuts=True so decode_dataset can read cut IDs from each batch.
    args.return_cuts = True
    aishell = AishellAsrDataModule(args)

    dev_cuts = aishell.valid_cuts()
    dev_dl = aishell.valid_dataloaders(dev_cuts)

    test_cuts = aishell.test_cuts()
    test_dl = aishell.test_dataloaders(test_cuts)

    test_sets = ["dev", "test"]
    test_dls = [dev_dl, test_dl]

    # Decode each test set and write transcripts/error stats to res_dir.
    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            lexicon=lexicon,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|