| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import print_function |
| |
|
| | import argparse |
| | import copy |
| | import logging |
| | import os |
| | import sys |
| |
|
| | import torch |
| | import yaml |
| | from torch.utils.data import DataLoader |
| |
|
| | from wenet.dataset.dataset import Dataset |
| | from wenet.paraformer.search.beam_search import build_beam_search |
| | from wenet.utils.checkpoint import load_checkpoint |
| | from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols |
| | from wenet.utils.config import override_config |
| | from wenet.utils.init_model import init_model |
| |
|
| |
|
def get_args():
    """Parse the command-line options of the recognition script.

    The resulting namespace is printed to stdout before it is returned.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser(description="recognize with your model")

    # ---- inputs: config, data, model, vocabulary ----
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--test_data", required=True, help="test data file")
    parser.add_argument(
        "--data_type",
        default="raw",
        choices=["raw", "shard"],
        help="test data type",
    )
    parser.add_argument(
        "--gpu", type=int, default=-1, help="gpu id for this rank, -1 for cpu"
    )
    parser.add_argument("--checkpoint", required=True, help="checkpoint model")
    parser.add_argument("--dict", required=True, help="dict file")
    parser.add_argument(
        "--non_lang_syms", help="non-linguistic symbol file. One symbol per line."
    )

    # ---- search / output ----
    parser.add_argument(
        "--beam_size", type=int, default=10, help="beam size for search"
    )
    parser.add_argument("--penalty", type=float, default=0.0, help="length penalty")
    parser.add_argument("--result_file", required=True, help="asr result file")
    parser.add_argument(
        "--batch_size", type=int, default=16, help="batch size for decoding"
    )
    parser.add_argument(
        "--mode",
        choices=[
            "attention",
            "ctc_greedy_search",
            "ctc_prefix_beam_search",
            "attention_rescoring",
            "rnnt_greedy_search",
            "rnnt_beam_search",
            "rnnt_beam_attn_rescoring",
            "ctc_beam_td_attn_rescoring",
            "hlg_onebest",
            "hlg_rescore",
            "paraformer_greedy_search",
            "paraformer_beam_search",
        ],
        default="attention",
        help="decoding mode",
    )

    # ---- weights used while generating the n-best list ----
    parser.add_argument(
        "--search_ctc_weight",
        type=float,
        default=1.0,
        help="ctc weight for nbest generation",
    )
    parser.add_argument(
        "--search_transducer_weight",
        type=float,
        default=0.0,
        help="transducer weight for nbest generation",
    )

    # ---- weights used while rescoring ----
    parser.add_argument(
        "--ctc_weight",
        type=float,
        default=0.0,
        help="ctc weight for rescoring in attention rescoring decode mode "
        "and in transducer attention rescore decode modes",
    )
    parser.add_argument(
        "--transducer_weight",
        type=float,
        default=0.0,
        help="transducer weight for rescoring weight in "
        "transducer attention rescore mode",
    )
    parser.add_argument(
        "--attn_weight",
        type=float,
        default=0.0,
        help="attention weight for rescoring weight in "
        "transducer attention rescore mode",
    )

    # ---- streaming / chunking ----
    parser.add_argument(
        "--decoding_chunk_size",
        type=int,
        default=-1,
        help="""decoding chunk size,
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here""",
    )
    parser.add_argument(
        "--num_decoding_left_chunks",
        type=int,
        default=-1,
        help="number of left chunks for decoding",
    )
    parser.add_argument(
        "--simulate_streaming", action="store_true", help="simulate streaming inference"
    )
    parser.add_argument(
        "--reverse_weight",
        type=float,
        default=0.0,
        help="""right to left weight for attention rescoring
                decode mode""",
    )

    # ---- tokenization / config overrides / output formatting ----
    parser.add_argument(
        "--bpe_model", default=None, type=str, help="bpe model for english part"
    )
    parser.add_argument(
        "--override_config", action="append", default=[], help="override yaml config"
    )
    parser.add_argument(
        "--connect_symbol",
        default="",
        type=str,
        help="used to connect the output characters",
    )

    # ---- HLG decoding ----
    parser.add_argument(
        "--word", default="", type=str, help="word file, only used for hlg decode"
    )
    parser.add_argument(
        "--hlg", default="", type=str, help="hlg file, only used for hlg decode"
    )
    parser.add_argument(
        "--lm_scale",
        type=float,
        default=0.0,
        help="lm scale for hlg attention rescore decode",
    )
    parser.add_argument(
        "--decoder_scale",
        type=float,
        default=0.0,
        help="decoder score scale for hlg attention rescore decode",
    )
    parser.add_argument(
        "--r_decoder_scale",
        type=float,
        default=0.0,
        help="right-to-left decoder score scale for hlg attention rescore decode",
    )

    args = parser.parse_args()
    print(args)
    return args
| |
|
| |
|
def _build_test_conf(dataset_conf, batch_size):
    """Return an evaluation copy of ``dataset_conf`` for deterministic decoding.

    The input dict is deep-copied and left untouched. In the copy:

    - the length and output/input-ratio filters are opened wide so no
      utterance is dropped;
    - speed perturbation and the spec_aug / spec_sub / spec_trim stages are
      switched off;
    - shuffling and sorting are disabled;
    - feature dither is zeroed;
    - batching becomes static batches of ``batch_size``.

    Args:
        dataset_conf: the ``dataset_conf`` section of the model YAML. It must
            already contain ``filter_conf`` and ``batch_conf`` sub-dicts.
        batch_size: number of utterances per batch.

    Returns:
        The modified deep copy.
    """
    test_conf = copy.deepcopy(dataset_conf)

    # Very loose bounds so that no test utterance is filtered out.
    filter_conf = test_conf["filter_conf"]
    filter_conf["max_length"] = 102400
    filter_conf["min_length"] = 0
    filter_conf["token_max_length"] = 102400
    filter_conf["token_min_length"] = 0
    filter_conf["max_output_input_ratio"] = 102400
    filter_conf["min_output_input_ratio"] = 0

    # No augmentation and no reordering at test time.
    test_conf["speed_perturb"] = False
    test_conf["spec_aug"] = False
    test_conf["spec_sub"] = False
    test_conf["spec_trim"] = False
    test_conf["shuffle"] = False
    test_conf["sort"] = False

    # Dither injects random noise during feature extraction; disable it.
    if "fbank_conf" in test_conf:
        test_conf["fbank_conf"]["dither"] = 0.0
    elif "mfcc_conf" in test_conf:
        test_conf["mfcc_conf"]["dither"] = 0.0

    test_conf["batch_conf"]["batch_type"] = "static"
    test_conf["batch_conf"]["batch_size"] = batch_size
    return test_conf


def main():
    """Decode ``--test_data`` with a trained model and write the hypotheses.

    For every utterance, one line ``"<key> <text>"`` is logged and written
    to ``--result_file``. ``<text>`` is built as follows:

    - take the hypothesis token ids up to (not including) the first eos id;
    - map each id through ``--dict``;
    - join the symbols with ``--connect_symbol``.

    Exits with status 1 when a one-utterance-per-batch mode is requested
    with ``--batch_size > 1``.
    """
    args = get_args()
    logging.basicConfig(
        level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s"
    )
    # Expose only the requested GPU to this process (-1 is used for CPU).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # Modes that decode exactly one utterance per batch. Several of the
    # branches below `assert feats.size(0) == 1`. Rejecting larger batches
    # here gives a clear message instead of a bare AssertionError on the
    # first batch (and still works when asserts are stripped by -O).
    single_utterance_modes = (
        "ctc_prefix_beam_search",
        "attention_rescoring",
        "paraformer_beam_search",
        "rnnt_greedy_search",
        "rnnt_beam_search",
        "rnnt_beam_attn_rescoring",
        "ctc_beam_td_attn_rescoring",
    )
    if args.mode in single_utterance_modes and args.batch_size > 1:
        logging.fatal(
            "decoding mode {} must be running with batch_size == 1".format(args.mode)
        )
        sys.exit(1)

    # NOTE(review): yaml.FullLoader is used here; only pass trusted config files.
    with open(args.config, "r", encoding="utf-8") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if args.override_config:
        configs = override_config(configs, args.override_config)

    symbol_table = read_symbol_table(args.dict)
    test_conf = _build_test_conf(configs["dataset_conf"], args.batch_size)
    non_lang_syms = read_non_lang_symbols(args.non_lang_syms)

    test_dataset = Dataset(
        args.data_type,
        args.test_data,
        symbol_table,
        test_conf,
        args.bpe_model,
        non_lang_syms,
        partition=False,
    )
    # batch_size=None disables DataLoader's own batching; batches are formed
    # by the Dataset pipeline according to test_conf["batch_conf"].
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    model = init_model(configs)

    # Inverse vocabulary: token id -> symbol.
    char_dict = {v: k for k, v in symbol_table.items()}
    # Hypotheses are truncated at this id. This assumes ids are contiguous
    # and <sos/eos> is the last dictionary entry (TODO confirm for custom dicts).
    eos = len(char_dict) - 1

    load_checkpoint(model, args.checkpoint)
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)
    model.eval()

    # The paraformer beam-search object is built once and reused per batch.
    if args.mode == "paraformer_beam_search":
        paraformer_beam_search = build_beam_search(model, args, device)
    else:
        paraformer_beam_search = None

    # Explicit UTF-8: dictionary symbols are frequently non-ASCII.
    with torch.no_grad(), open(args.result_file, "w", encoding="utf-8") as fout:
        for batch in test_data_loader:
            # Targets are unused during decoding, so they are not moved to device.
            keys, feats, _target, feats_lengths, _target_lengths = batch
            feats = feats.to(device)
            feats_lengths = feats_lengths.to(device)
            if args.mode == "attention":
                hyps, _ = model.recognize(
                    feats,
                    feats_lengths,
                    beam_size=args.beam_size,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                )
                hyps = [hyp.tolist() for hyp in hyps]
            elif args.mode == "ctc_greedy_search":
                hyps, _ = model.ctc_greedy_search(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                )
            elif args.mode == "rnnt_greedy_search":
                assert feats.size(0) == 1
                assert "predictor" in configs
                hyps = model.greedy_search(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                )
            elif args.mode == "rnnt_beam_search":
                assert feats.size(0) == 1
                assert "predictor" in configs
                hyps = model.beam_search(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    beam_size=args.beam_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                    ctc_weight=args.search_ctc_weight,
                    transducer_weight=args.search_transducer_weight,
                )
            elif args.mode == "rnnt_beam_attn_rescoring":
                assert feats.size(0) == 1
                assert "predictor" in configs
                hyps = model.transducer_attention_rescoring(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    beam_size=args.beam_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                    ctc_weight=args.ctc_weight,
                    transducer_weight=args.transducer_weight,
                    attn_weight=args.attn_weight,
                    reverse_weight=args.reverse_weight,
                    search_ctc_weight=args.search_ctc_weight,
                    search_transducer_weight=args.search_transducer_weight,
                )
            elif args.mode == "ctc_beam_td_attn_rescoring":
                assert feats.size(0) == 1
                assert "predictor" in configs
                hyps = model.transducer_attention_rescoring(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    beam_size=args.beam_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                    ctc_weight=args.ctc_weight,
                    transducer_weight=args.transducer_weight,
                    attn_weight=args.attn_weight,
                    reverse_weight=args.reverse_weight,
                    search_ctc_weight=args.search_ctc_weight,
                    search_transducer_weight=args.search_transducer_weight,
                    beam_search_type="ctc",
                )
            elif args.mode == "ctc_prefix_beam_search":
                assert feats.size(0) == 1
                hyp, _ = model.ctc_prefix_beam_search(
                    feats,
                    feats_lengths,
                    args.beam_size,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                )
                hyps = [hyp]
            elif args.mode == "attention_rescoring":
                assert feats.size(0) == 1
                hyp, _ = model.attention_rescoring(
                    feats,
                    feats_lengths,
                    args.beam_size,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    ctc_weight=args.ctc_weight,
                    simulate_streaming=args.simulate_streaming,
                    reverse_weight=args.reverse_weight,
                )
                hyps = [hyp]
            elif args.mode == "hlg_onebest":
                hyps = model.hlg_onebest(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                    hlg=args.hlg,
                    word=args.word,
                    symbol_table=symbol_table,
                )
            elif args.mode == "hlg_rescore":
                hyps = model.hlg_rescore(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                    lm_scale=args.lm_scale,
                    decoder_scale=args.decoder_scale,
                    r_decoder_scale=args.r_decoder_scale,
                    hlg=args.hlg,
                    word=args.word,
                    symbol_table=symbol_table,
                )
            elif args.mode == "paraformer_beam_search":
                hyps = model.paraformer_beam_search(
                    feats,
                    feats_lengths,
                    beam_search=paraformer_beam_search,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                )
            elif args.mode == "paraformer_greedy_search":
                hyps = model.paraformer_greedy_search(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                )

            # Map each hypothesis back to symbols, stopping at the first eos.
            for i, key in enumerate(keys):
                content = []
                for w in hyps[i]:
                    if w == eos:
                        break
                    content.append(char_dict[w])
                line = "{} {}".format(key, args.connect_symbol.join(content))
                logging.info(line)
                fout.write(line + "\n")
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: decoding runs only when executed directly, not on import.
    main()
| |
|