| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import print_function |
|
|
| import argparse |
| import copy |
| import logging |
| import os |
| import sys |
|
|
| import torch |
| import yaml |
| from torch.utils.data import DataLoader |
|
|
| from wenet.dataset.dataset import Dataset |
| from wenet.utils.common import IGNORE_ID |
| from wenet.utils.file_utils import read_symbol_table |
| from wenet.utils.config import override_config |
|
|
| import onnxruntime as rt |
| import multiprocessing |
| import numpy as np |
|
|
| try: |
| from swig_decoders import ( |
| map_batch, |
| ctc_beam_search_decoder_batch, |
| TrieVector, |
| PathTrie, |
| ) |
| except ImportError: |
| print( |
| "Please install ctc decoders first by refering to\n" |
| + "https://github.com/Slyne/ctc_decoder.git" |
| ) |
| sys.exit(1) |
|
|
|
|
def get_args(argv=None):
    """Parse command-line options for batched ONNX Runtime recognition.

    Args:
        argv: optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the original behaviour
            of reading ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options. The namespace is also
        printed to stdout so the run configuration appears in the logs.
    """
    parser = argparse.ArgumentParser(description="recognize with your model")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--test_data", required=True, help="test data file")
    parser.add_argument(
        "--data_type",
        default="raw",
        choices=["raw", "shard"],
        help="test data type",
    )
    parser.add_argument(
        "--gpu", type=int, default=-1, help="gpu id for this rank, -1 for cpu"
    )
    parser.add_argument("--dict", required=True, help="dict file")
    parser.add_argument("--encoder_onnx", required=True, help="encoder onnx file")
    parser.add_argument("--decoder_onnx", required=True, help="decoder onnx file")
    parser.add_argument("--result_file", required=True, help="asr result file")
    parser.add_argument(
        "--batch_size", type=int, default=32, help="batch size for decoding"
    )
    parser.add_argument(
        "--mode",
        choices=["ctc_greedy_search", "ctc_prefix_beam_search", "attention_rescoring"],
        default="attention_rescoring",
        help="decoding mode",
    )
    parser.add_argument(
        "--bpe_model", default=None, type=str, help="bpe model for english part"
    )
    # action="append" with a fresh list per parse: argparse copies the
    # default list before appending, so repeated calls do not share state.
    parser.add_argument(
        "--override_config", action="append", default=[], help="override yaml config"
    )
    # This script never exports anything; --fp16 only makes it feed float16
    # features / CTC scores, matching a model exported in fp16.
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="feed fp16 inputs to an fp16-exported onnx model, default false",
    )
    args = parser.parse_args(argv)
    print(args)
    return args
|
|
|
|
def _build_test_conf(dataset_conf, batch_size):
    """Return a deep copy of ``dataset_conf`` set up for evaluation.

    The input config is left untouched. In the copy:

    * length, token-length and output/input-ratio filters are opened up to
      extreme bounds (presumably so no test utterance is filtered out);
    * the ``speed_perturb`` / ``spec_aug`` / ``spec_sub`` / ``spec_trim``
      flags and the fbank dither are switched off;
    * shuffling and sorting are disabled;
    * batching becomes static with ``batch_size`` utterances per batch.
    """
    test_conf = copy.deepcopy(dataset_conf)
    filter_conf = test_conf["filter_conf"]
    filter_conf["max_length"] = 102400
    filter_conf["min_length"] = 0
    filter_conf["token_max_length"] = 102400
    filter_conf["token_min_length"] = 0
    filter_conf["max_output_input_ratio"] = 102400
    filter_conf["min_output_input_ratio"] = 0
    test_conf["speed_perturb"] = False
    test_conf["spec_aug"] = False
    test_conf["spec_sub"] = False
    test_conf["spec_trim"] = False
    test_conf["shuffle"] = False
    test_conf["sort"] = False
    test_conf["fbank_conf"]["dither"] = 0.0
    test_conf["batch_conf"]["batch_type"] = "static"
    test_conf["batch_conf"]["batch_size"] = batch_size
    return test_conf


def _load_vocabulary(dict_path):
    """Read a dictionary file containing one ``<token> <id>`` pair per line.

    Returns:
        ``(vocabulary, char_dict)``:

        * ``vocabulary`` lists the tokens in file order. It is handed to
          ``map_batch`` as the id -> token table, so the file is assumed to
          be sorted by contiguous id (TODO confirm).
        * ``char_dict`` maps ``int(id) -> token``.

    Raises:
        ValueError: if a line does not contain exactly two fields.
    """
    vocabulary = []
    char_dict = {}
    # Tokens may be non-ASCII; decode explicitly as UTF-8 rather than relying
    # on the platform locale encoding.
    with open(dict_path, "r", encoding="utf-8") as fin:
        for line_no, line in enumerate(fin, start=1):
            fields = line.strip().split()
            # Explicit check instead of `assert`, which `python -O` strips.
            if len(fields) != 2:
                raise ValueError(
                    "{}:{}: expected '<token> <id>', got {!r}".format(
                        dict_path, line_no, line
                    )
                )
            token, token_id = fields
            char_dict[int(token_id)] = token
            vocabulary.append(token)
    return vocabulary, char_dict


def _ctc_greedy_search(beam_log_probs_idx, encoder_out_lens, vocabulary, num_processes):
    """Greedy CTC decoding from the encoder's top-k token ids.

    Args:
        beam_log_probs_idx: token-id array indexed ``[batch, frame, beam]``.
            Beam column 0 is taken as each frame's best token (assumes the ids
            are ordered best-first -- TODO confirm against the exporter).
        encoder_out_lens: valid frame count per utterance. Frames past it are
            padding and are dropped.
        vocabulary: id -> token table passed to ``map_batch``.
        num_processes: worker count passed to ``map_batch``.

    Returns:
        The ``map_batch`` output, one transcript per utterance.
    """
    # Always select beam column 0. The old code only did this when
    # beam_size != 1 and crashed with UnboundLocalError for beam_size == 1.
    # This assumes the id array keeps its beam axis at size 1, as
    # beam_log_probs does (its last dim is read as beam_size).
    best_ids = beam_log_probs_idx[:, :, 0]
    batch_sents = [
        seq[0 : encoder_out_lens[i]].tolist() for i, seq in enumerate(best_ids)
    ]
    return map_batch(batch_sents, vocabulary, num_processes, True, 0)


def _ctc_prefix_beam_search(
    beam_log_probs, beam_log_probs_idx, encoder_out_lens, beam_size, num_processes
):
    """Run the swig CTC prefix beam search over one batch.

    Each utterance is trimmed to its valid frames and given a fresh
    ``PathTrie`` root before ``ctc_beam_search_decoder_batch`` is called.

    Returns:
        The decoder's per-utterance candidate lists. Callers read
        ``cands[0][1]`` as the best id sequence and ``cand[0]`` as a CTC
        score, i.e. entries are ``(score, ids)``-like.
    """
    log_probs_list = beam_log_probs.tolist()
    ids_list = beam_log_probs_idx.tolist()
    batch_log_probs_seq = []
    batch_log_probs_ids = []
    batch_start = []
    batch_root = TrieVector()
    # Hold Python references to every PathTrie for the duration of the call.
    # Presumably TrieVector does not own them -- verify against swig_decoders
    # before removing this list.
    roots = []
    for i, num_frames in enumerate(encoder_out_lens.tolist()):
        batch_log_probs_seq.append(log_probs_list[i][0:num_frames])
        batch_log_probs_ids.append(ids_list[i][0:num_frames])
        root = PathTrie()
        roots.append(root)
        batch_root.append(root)
        batch_start.append(True)
    # The trailing positional arguments (0, -2, 0.99999) are kept unchanged.
    # They are believed to be blank id, space id and cutoff probability --
    # confirm against the swig_decoders signature.
    return ctc_beam_search_decoder_batch(
        batch_log_probs_seq,
        batch_log_probs_ids,
        batch_root,
        batch_start,
        beam_size,
        num_processes,
        0,
        -2,
        0.99999,
    )


def _attention_rescoring(
    decoder_ort_session,
    score_hyps,
    encoder_out,
    encoder_out_lens,
    batch_size,
    beam_size,
    sos,
    eos,
    reverse_weight,
    fp16,
):
    """Pick one CTC candidate per utterance using the ONNX attention decoder.

    Steps:

    1. Pad each utterance's candidate list to ``beam_size`` entries with a
       ``(-inf, (0,))`` dummy.
    2. Pack the candidates into ``sos ... eos`` id tensors padded with
       ``IGNORE_ID``, plus reversed copies that are fed only when
       ``reverse_weight > 0``.
    3. Run the decoder session and read its first output as the index of the
       winning candidate for each utterance.

    Returns:
        A list with the chosen candidate's id list for each utterance.
    """
    ctc_score, all_hyps = [], []
    max_len = 0
    pad_hyp = (-float("INF"), (0,))
    for cands in score_hyps:
        # Build a new list rather than `+=` on the decoder's result. This
        # leaves score_hyps unmodified and also works if the swig layer
        # returns tuples.
        cands = list(cands) + [pad_hyp] * (beam_size - len(cands))
        cur_ctc_score = []
        for hyp in cands:
            cur_ctc_score.append(hyp[0])
            all_hyps.append(list(hyp[1]))
            max_len = max(max_len, len(hyp[1]))
        ctc_score.append(cur_ctc_score)
    ctc_score = np.array(ctc_score, dtype=np.float16 if fp16 else np.float32)

    # Room for the candidate plus sos and eos.
    pad_shape = (batch_size, beam_size, max_len + 2)
    hyps_pad_sos_eos = np.full(pad_shape, IGNORE_ID, dtype=np.int64)
    r_hyps_pad_sos_eos = np.full(pad_shape, IGNORE_ID, dtype=np.int64)
    hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32)
    for i in range(batch_size):
        for j in range(beam_size):
            cand = all_hyps[i * beam_size + j]
            seq_len = len(cand) + 2  # candidate + sos + eos
            hyps_pad_sos_eos[i, j, 0:seq_len] = [sos] + cand + [eos]
            r_hyps_pad_sos_eos[i, j, 0:seq_len] = [sos] + cand[::-1] + [eos]
            hyps_lens_sos[i, j] = len(cand) + 1  # candidate + sos

    decoder_inputs = decoder_ort_session.get_inputs()
    decoder_ort_inputs = {
        decoder_inputs[0].name: encoder_out,
        decoder_inputs[1].name: encoder_out_lens,
        decoder_inputs[2].name: hyps_pad_sos_eos,
        decoder_inputs[3].name: hyps_lens_sos,
        decoder_inputs[-1].name: ctc_score,
    }
    # Input 4 (reversed hypotheses) is only fed when the config enables a
    # reverse decoder weight.
    if reverse_weight > 0:
        decoder_ort_inputs[decoder_inputs[4].name] = r_hyps_pad_sos_eos
    best_index = decoder_ort_session.run(None, decoder_ort_inputs)[0]
    return [
        all_hyps[b * beam_size : (b + 1) * beam_size][idx]
        for b, idx in enumerate(best_index)
    ]


def main():
    """Decode a test set with ONNX Runtime and write ``<key> <text>`` lines.

    Steps:

    1. Parse the arguments and load the YAML config (plus any
       ``--override_config`` entries).
    2. Build an evaluation Dataset/DataLoader.
    3. Create ONNX Runtime sessions. The CUDA provider is used when
       ``--gpu >= 0`` and CUDA is available, otherwise CPU only. The decoder
       session is created only for ``attention_rescoring``.
    4. For each batch, run the encoder and decode with ``--mode``.
    5. Write one ``"<key> <text>"`` line per utterance to ``--result_file``.
    """
    args = get_args()
    logging.basicConfig(
        level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s"
    )
    # Set before torch.cuda.is_available() below so the restriction applies.
    # Per --gpu's help text, -1 means CPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    with open(args.config, "r", encoding="utf-8") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
    symbol_table = read_symbol_table(args.dict)
    test_conf = _build_test_conf(configs["dataset_conf"], args.batch_size)

    test_dataset = Dataset(
        args.data_type,
        args.test_data,
        symbol_table,
        test_conf,
        args.bpe_model,
        partition=False,
    )
    # batch_size=None: the Dataset already yields whole batches (static batching).
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        EP_list = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        EP_list = ["CPUExecutionProvider"]

    encoder_ort_session = rt.InferenceSession(args.encoder_onnx, providers=EP_list)
    decoder_ort_session = None
    if args.mode == "attention_rescoring":
        decoder_ort_session = rt.InferenceSession(args.decoder_onnx, providers=EP_list)

    vocabulary, char_dict = _load_vocabulary(args.dict)
    # The last dictionary id doubles as <sos> and <eos>.
    eos = sos = len(char_dict) - 1

    # Loop invariants: input names and CPU count do not change per batch.
    encoder_inputs = encoder_ort_session.get_inputs()
    cpu_count = multiprocessing.cpu_count()

    with torch.no_grad(), open(args.result_file, "w", encoding="utf-8") as fout:
        for batch in test_data_loader:
            keys, feats, _, feats_lengths, _ = batch
            feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
            if args.fp16:
                feats = feats.astype(np.float16)
            ort_inputs = {
                encoder_inputs[0].name: feats,
                encoder_inputs[1].name: feats_lengths,
            }
            # The third output (full CTC log-probs) is not used here.
            (
                encoder_out,
                encoder_out_lens,
                _,
                beam_log_probs,
                beam_log_probs_idx,
            ) = encoder_ort_session.run(None, ort_inputs)
            beam_size = beam_log_probs.shape[-1]
            batch_size = beam_log_probs.shape[0]
            num_processes = min(cpu_count, batch_size)

            if args.mode == "ctc_greedy_search":
                hyps = _ctc_greedy_search(
                    beam_log_probs_idx, encoder_out_lens, vocabulary, num_processes
                )
            elif args.mode in ("ctc_prefix_beam_search", "attention_rescoring"):
                score_hyps = _ctc_prefix_beam_search(
                    beam_log_probs,
                    beam_log_probs_idx,
                    encoder_out_lens,
                    beam_size,
                    num_processes,
                )
                if args.mode == "ctc_prefix_beam_search":
                    best_sents = [cands[0][1] for cands in score_hyps]
                    hyps = map_batch(best_sents, vocabulary, num_processes, False, 0)
                else:
                    best_sents = _attention_rescoring(
                        decoder_ort_session,
                        score_hyps,
                        encoder_out,
                        encoder_out_lens,
                        batch_size,
                        beam_size,
                        sos,
                        eos,
                        reverse_weight,
                        args.fp16,
                    )
                    hyps = map_batch(best_sents, vocabulary, num_processes)

            # Index (not zip) so a length mismatch still raises instead of
            # silently dropping utterances.
            for i, key in enumerate(keys):
                content = hyps[i]
                logging.info("%s %s", key, content)
                fout.write("{} {}\n".format(key, content))
|
|
|
|
if __name__ == "__main__":
    # Run recognition only when executed as a script, not when imported.
    main()
|
|