| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import print_function |
|
|
| import argparse |
| import copy |
| import logging |
| import os |
|
|
| import torch |
| import yaml |
| from torch.utils.data import DataLoader |
|
|
| from wenet.dataset.dataset import Dataset |
| from wenet.utils.config import override_config |
| from wenet.utils.init_model import init_model |
| from wenet.utils.init_tokenizer import init_tokenizer |
| from wenet.utils.context_graph import ContextGraph |
| from wenet.utils.ctc_utils import get_blank_id |
| from wenet.utils.common import TORCH_NPU_AVAILABLE |
|
|
|
|
| def get_args(): |
| parser = argparse.ArgumentParser(description='recognize with your model') |
| parser.add_argument('--config', required=True, help='config file') |
| parser.add_argument('--test_data', required=True, help='test data file') |
| parser.add_argument('--data_type', |
| default='raw', |
| choices=['raw', 'shard'], |
| help='train and cv data type') |
| parser.add_argument('--gpu', |
| type=int, |
| default=-1, |
| help='gpu id for this rank, -1 for cpu') |
| parser.add_argument('--device', |
| type=str, |
| default="cpu", |
| choices=["cpu", "npu", "cuda"], |
| help='accelerator to use') |
| parser.add_argument('--dtype', |
| type=str, |
| default='fp32', |
| choices=['fp16', 'fp32', 'bf16'], |
| help='model\'s dtype') |
| parser.add_argument('--num_workers', |
| default=0, |
| type=int, |
| help='num of subprocess workers for reading') |
| parser.add_argument('--checkpoint', required=True, help='checkpoint model') |
| parser.add_argument('--beam_size', |
| type=int, |
| default=10, |
| help='beam size for search') |
| parser.add_argument('--length_penalty', |
| type=float, |
| default=0.0, |
| help='length penalty') |
| parser.add_argument('--blank_penalty', |
| type=float, |
| default=0.0, |
| help='blank penalty') |
| parser.add_argument('--result_dir', required=True, help='asr result file') |
| parser.add_argument('--batch_size', |
| type=int, |
| default=16, |
| help='asr result file') |
| parser.add_argument('--modes', |
| nargs='+', |
| help="""decoding mode, support the following: |
| attention |
| ctc_greedy_search |
| ctc_prefix_beam_search |
| attention_rescoring |
| rnnt_greedy_search |
| rnnt_beam_search |
| rnnt_beam_attn_rescoring |
| ctc_beam_td_attn_rescoring |
| hlg_onebest |
| hlg_rescore |
| paraformer_greedy_search |
| paraformer_beam_search""") |
| parser.add_argument('--search_ctc_weight', |
| type=float, |
| default=1.0, |
| help='ctc weight for nbest generation') |
| parser.add_argument('--search_transducer_weight', |
| type=float, |
| default=0.0, |
| help='transducer weight for nbest generation') |
| parser.add_argument('--ctc_weight', |
| type=float, |
| default=0.0, |
| help='ctc weight for rescoring weight in \ |
| attention rescoring decode mode \ |
| ctc weight for rescoring weight in \ |
| transducer attention rescore decode mode') |
|
|
| parser.add_argument('--transducer_weight', |
| type=float, |
| default=0.0, |
| help='transducer weight for rescoring weight in ' |
| 'transducer attention rescore mode') |
| parser.add_argument('--attn_weight', |
| type=float, |
| default=0.0, |
| help='attention weight for rescoring weight in ' |
| 'transducer attention rescore mode') |
| parser.add_argument('--decoding_chunk_size', |
| type=int, |
| default=-1, |
| help='''decoding chunk size, |
| <0: for decoding, use full chunk. |
| >0: for decoding, use fixed chunk size as set. |
| 0: used for training, it's prohibited here''') |
| parser.add_argument('--num_decoding_left_chunks', |
| type=int, |
| default=-1, |
| help='number of left chunks for decoding') |
| parser.add_argument('--simulate_streaming', |
| action='store_true', |
| help='simulate streaming inference') |
| parser.add_argument('--reverse_weight', |
| type=float, |
| default=0.0, |
| help='''right to left weight for attention rescoring |
| decode mode''') |
| parser.add_argument('--override_config', |
| action='append', |
| default=[], |
| help="override yaml config") |
|
|
| parser.add_argument('--word', |
| default='', |
| type=str, |
| help='word file, only used for hlg decode') |
| parser.add_argument('--hlg', |
| default='', |
| type=str, |
| help='hlg file, only used for hlg decode') |
| parser.add_argument('--lm_scale', |
| type=float, |
| default=0.0, |
| help='lm scale for hlg attention rescore decode') |
| parser.add_argument('--decoder_scale', |
| type=float, |
| default=0.0, |
| help='lm scale for hlg attention rescore decode') |
| parser.add_argument('--r_decoder_scale', |
| type=float, |
| default=0.0, |
| help='lm scale for hlg attention rescore decode') |
|
|
| parser.add_argument( |
| '--context_bias_mode', |
| type=str, |
| default='', |
| help='''Context bias mode, selectable from the following |
| option: decoding-graph, deep-biasing''') |
| parser.add_argument('--context_list_path', |
| type=str, |
| default='', |
| help='Context list path') |
| parser.add_argument('--context_graph_score', |
| type=float, |
| default=0.0, |
| help='''The higher the score, the greater the degree of |
| bias using decoding-graph for biasing''') |
|
|
| parser.add_argument('--use_lora', |
| type=bool, |
| default=False, |
| help='''Whether to use lora for biasing''') |
| parser.add_argument("--lora_ckpt_path", |
| default=None, |
| type=str, |
| help="lora checkpoint path.") |
| args = parser.parse_args() |
| print(args) |
| return args |
|
|
|
|
| def main(): |
| args = get_args() |
| logging.basicConfig(level=logging.DEBUG, |
| format='%(asctime)s %(levelname)s %(message)s') |
| if args.gpu != -1: |
| |
| args.device = "cuda" |
| if "cuda" in args.device: |
| os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) |
|
|
| with open(args.config, 'r') as fin: |
| configs = yaml.load(fin, Loader=yaml.FullLoader) |
| if len(args.override_config) > 0: |
| configs = override_config(configs, args.override_config) |
|
|
| test_conf = copy.deepcopy(configs['dataset_conf']) |
|
|
| test_conf['filter_conf']['max_length'] = 102400 |
| test_conf['filter_conf']['min_length'] = 0 |
| test_conf['filter_conf']['token_max_length'] = 102400 |
| test_conf['filter_conf']['token_min_length'] = 0 |
| test_conf['filter_conf']['max_output_input_ratio'] = 102400 |
| test_conf['filter_conf']['min_output_input_ratio'] = 0 |
| test_conf['speed_perturb'] = False |
| test_conf['spec_aug'] = False |
| test_conf['spec_sub'] = False |
| test_conf['spec_trim'] = False |
| test_conf['shuffle'] = False |
| test_conf['sort'] = False |
| test_conf['cycle'] = 1 |
| test_conf['list_shuffle'] = False |
| if 'fbank_conf' in test_conf: |
| test_conf['fbank_conf']['dither'] = 0.0 |
| elif 'mfcc_conf' in test_conf: |
| test_conf['mfcc_conf']['dither'] = 0.0 |
| test_conf['batch_conf']['batch_type'] = "static" |
| test_conf['batch_conf']['batch_size'] = args.batch_size |
|
|
| tokenizer = init_tokenizer(configs) |
| test_dataset = Dataset(args.data_type, |
| args.test_data, |
| tokenizer, |
| test_conf, |
| partition=False) |
|
|
| test_data_loader = DataLoader(test_dataset, |
| batch_size=None, |
| num_workers=args.num_workers) |
|
|
| |
| args.jit = False |
| model, configs = init_model(args, configs) |
|
|
| device = torch.device(args.device) |
| model = model.to(device) |
| model.eval() |
| dtype = torch.float32 |
| if args.dtype == 'fp16': |
| dtype = torch.float16 |
| elif args.dtype == 'bf16': |
| dtype = torch.bfloat16 |
| logging.info("compute dtype is {}".format(dtype)) |
|
|
| context_graph = None |
| if 'decoding-graph' in args.context_bias_mode: |
| context_graph = ContextGraph(args.context_list_path, |
| tokenizer.symbol_table, |
| configs['tokenizer_conf']['bpe_path'], |
| args.context_graph_score) |
|
|
| _, blank_id = get_blank_id(configs, tokenizer.symbol_table) |
| logging.info("blank_id is {}".format(blank_id)) |
|
|
| |
| |
| |
| files = {} |
| for mode in args.modes: |
| dir_name = os.path.join(args.result_dir, mode) |
| os.makedirs(dir_name, exist_ok=True) |
| file_name = os.path.join(dir_name, 'text') |
| files[mode] = open(file_name, 'w') |
| max_format_len = max([len(mode) for mode in args.modes]) |
|
|
| with torch.cuda.amp.autocast(enabled=True, |
| dtype=dtype, |
| cache_enabled=False): |
| with torch.no_grad(): |
| for batch_idx, batch in enumerate(test_data_loader): |
| keys = batch["keys"] |
| feats = batch["feats"].to(device) |
| target = batch["target"].to(device) |
| feats_lengths = batch["feats_lengths"].to(device) |
| target_lengths = batch["target_lengths"].to(device) |
| infos = {"tasks": batch["tasks"], "langs": batch["langs"]} |
| results = model.decode( |
| args.modes, |
| feats, |
| feats_lengths, |
| args.beam_size, |
| decoding_chunk_size=args.decoding_chunk_size, |
| num_decoding_left_chunks=args.num_decoding_left_chunks, |
| ctc_weight=args.ctc_weight, |
| simulate_streaming=args.simulate_streaming, |
| reverse_weight=args.reverse_weight, |
| context_graph=context_graph, |
| blank_id=blank_id, |
| blank_penalty=args.blank_penalty, |
| length_penalty=args.length_penalty, |
| infos=infos) |
| for i, key in enumerate(keys): |
| for mode, hyps in results.items(): |
| tokens = hyps[i].tokens |
| line = '{} {}'.format(key, |
| tokenizer.detokenize(tokens)[0]) |
| logging.info('{} {}'.format(mode.ljust(max_format_len), |
| line)) |
| files[mode].write(line + '\n') |
| for mode, f in files.items(): |
| f.close() |
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|