| """ |
| This script loads a checkpoint and uses it to decode waves. |
| You can generate the checkpoint with the following command: |
| |
| Note: This is a example for librispeech dataset, if you are using different |
| dataset, you should change the argument values according to your dataset. |
| |
- For non-streaming model:

./zipformer/export.py \
  --exp-dir ./zipformer/exp \
  --tokens data/lang_bpe_500/tokens.txt \
  --epoch 30 \
  --avg 9

- For streaming model:

./zipformer/export.py \
  --exp-dir ./zipformer/exp \
  --causal 1 \
  --tokens data/lang_bpe_500/tokens.txt \
  --epoch 30 \
  --avg 9

Usage of this script:

- For non-streaming model:

(1) greedy search
./zipformer/pretrained.py \
  --checkpoint ./zipformer/exp/pretrained.pt \
  --tokens data/lang_bpe_500/tokens.txt \
  --method greedy_search \
  /path/to/foo.wav \
  /path/to/bar.wav

(2) modified beam search
./zipformer/pretrained.py \
  --checkpoint ./zipformer/exp/pretrained.pt \
  --tokens ./data/lang_bpe_500/tokens.txt \
  --method modified_beam_search \
  /path/to/foo.wav \
  /path/to/bar.wav

(3) fast beam search
./zipformer/pretrained.py \
  --checkpoint ./zipformer/exp/pretrained.pt \
  --tokens ./data/lang_bpe_500/tokens.txt \
  --method fast_beam_search \
  /path/to/foo.wav \
  /path/to/bar.wav

- For streaming model:

(1) greedy search
./zipformer/pretrained.py \
  --checkpoint ./zipformer/exp/pretrained.pt \
  --causal 1 \
  --chunk-size 16 \
  --left-context-frames 128 \
  --tokens ./data/lang_bpe_500/tokens.txt \
  --method greedy_search \
  /path/to/foo.wav \
  /path/to/bar.wav

(2) modified beam search
./zipformer/pretrained.py \
  --checkpoint ./zipformer/exp/pretrained.pt \
  --causal 1 \
  --chunk-size 16 \
  --left-context-frames 128 \
  --tokens ./data/lang_bpe_500/tokens.txt \
  --method modified_beam_search \
  /path/to/foo.wav \
  /path/to/bar.wav

(3) fast beam search
./zipformer/pretrained.py \
  --checkpoint ./zipformer/exp/pretrained.pt \
  --causal 1 \
  --chunk-size 16 \
  --left-context-frames 128 \
  --tokens ./data/lang_bpe_500/tokens.txt \
  --method fast_beam_search \
  /path/to/foo.wav \
  /path/to/bar.wav

You can also use `./zipformer/exp/epoch-xx.pt`.

Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py
"""

import argparse
import logging
import math
from typing import List

import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
    fast_beam_search_one_best,
    greedy_search_batch,
    modified_beam_search,
)
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        help="""Path to tokens.txt.""",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - modified_beam_search
          - fast_beam_search
        """,
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16 kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --method is modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=4,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=4,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=8,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means trigram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame. Used only when
        --method is greedy_search.
        """,
    )

    add_model_arguments(parser)

    return parser


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # Keep only the first channel.
        ans.append(wave[0].contiguous())
    return ans
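

# A minimal usage sketch of read_sound_files (the wav path below is a
# placeholder, not a real file):
#
#   waves = read_sound_files(
#       filenames=["/path/to/foo.wav"], expected_sample_rate=16000
#   )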


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    token_table = k2.SymbolTable.from_file(params.tokens)
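
    # The +1 below accounts for the blank symbol <blk>, which
    # export.num_tokens() does not count.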
    params.blank_id = token_table["<blk>"]
    params.unk_id = token_table["<unk>"]
    params.vocab_size = num_tokens(token_table) + 1

    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    if params.causal:
        assert (
            "," not in params.chunk_size
        ), "chunk_size should be one value in decoding."
        assert (
            "," not in params.left_context_frames
        ), "left_context_frames should be one value in decoding."

    logging.info("Creating model")
    model = get_model(params)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    checkpoint = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim
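    # A negative high_freq is interpreted as an offset from the Nyquist
    # frequency in Kaldi-style feature extraction: with 16 kHz input this
    # means 16000 / 2 - 400 = 7600 Hz.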
    opts.mel_opts.high_freq = -400

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]
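
    # Pad with log(1e-10) so that padded frames resemble silence in the
    # log-mel domain rather than spurious energy.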
    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    feature_lengths = torch.tensor(feature_lengths, device=device)

    encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
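
    # encoder_out is expected to have shape (batch, num_frames, encoder_dim);
    # encoder_out_lens gives the number of valid frames per utterance.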

    hyps = []
    msg = f"Using {params.method}"
    logging.info(msg)

    def token_ids_to_words(token_ids: List[int]) -> str:
        text = ""
        for i in token_ids:
            text += token_table[i]
        # "▁" is the SentencePiece word-boundary marker.
        return text.replace("▁", " ").strip()

    if params.method == "fast_beam_search":
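        # fast_beam_search decodes against an FSA; k2.trivial_graph builds a
        # graph that accepts any token sequence, i.e. no extra constraint.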
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        for hyp in hyp_tokens:
            hyps.append(token_ids_to_words(hyp))
    elif params.method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
        for hyp in hyp_tokens:
            hyps.append(token_ids_to_words(hyp))
    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for hyp in hyp_tokens:
            hyps.append(token_ids_to_words(hyp))
    else:
        raise ValueError(f"Unsupported method: {params.method}")

    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):
        s += f"{filename}:\n{hyp}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()