| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import print_function |
| |
|
| | import argparse |
| | import copy |
| | import logging |
| | import os |
| | import sys |
| |
|
| | import torch |
| | import yaml |
| | from torch.utils.data import DataLoader |
| |
|
| | from wenet.dataset.dataset import Dataset |
| | from wenet.utils.common import IGNORE_ID |
| | from wenet.utils.file_utils import read_symbol_table |
| | from wenet.utils.config import override_config |
| |
|
| | import onnxruntime as rt |
| | import multiprocessing |
| | import numpy as np |
| |
|
# The swig-based CTC decoders are a hard requirement for every decoding
# mode in this script, so fail fast with install instructions if missing.
try:
    from swig_decoders import (
        map_batch,
        ctc_beam_search_decoder_batch,
        TrieVector,
        PathTrie,
    )
except ImportError:
    # Fixed typo in the user-facing message: "refering" -> "referring".
    print(
        "Please install ctc decoders first by referring to\n"
        + "https://github.com/Slyne/ctc_decoder.git"
    )
    sys.exit(1)
| |
|
| |
|
def get_args():
    """Parse and return the command-line arguments for ONNX recognition.

    Returns:
        argparse.Namespace with config/test_data/dict/model paths, the
        decoding mode, batch size, GPU id, and optional overrides.
    """
    parser = argparse.ArgumentParser(description="recognize with your model")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--test_data", required=True, help="test data file")
    parser.add_argument(
        "--data_type",
        default="raw",
        choices=["raw", "shard"],
        help="train and cv data type",
    )
    parser.add_argument(
        "--gpu", type=int, default=-1, help="gpu id for this rank, -1 for cpu"
    )
    parser.add_argument("--dict", required=True, help="dict file")
    parser.add_argument("--encoder_onnx", required=True, help="encoder onnx file")
    parser.add_argument("--decoder_onnx", required=True, help="decoder onnx file")
    parser.add_argument("--result_file", required=True, help="asr result file")
    # BUGFIX: the help text was a copy-paste of --result_file's
    # ("asr result file"); it now describes the option itself.
    parser.add_argument(
        "--batch_size", type=int, default=32, help="decoding batch size"
    )
    parser.add_argument(
        "--mode",
        choices=["ctc_greedy_search", "ctc_prefix_beam_search", "attention_rescoring"],
        default="attention_rescoring",
        help="decoding mode",
    )
    parser.add_argument(
        "--bpe_model", default=None, type=str, help="bpe model for english part"
    )
    parser.add_argument(
        "--override_config", action="append", default=[], help="override yaml config"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="whether to export fp16 model, default false",
    )
    args = parser.parse_args()
    print(args)
    return args
| |
|
| |
|
def main():
    """Recognize a test set with ONNX-exported WeNet models and dump results.

    Loads the dataset/decoding configuration, builds onnxruntime sessions
    for the exported encoder (and the decoder when attention rescoring is
    requested), decodes every batch with the selected CTC strategy, and
    writes one "<key> <hypothesis>" line per utterance to args.result_file.
    """
    args = get_args()
    logging.basicConfig(
        level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s"
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
    symbol_table = read_symbol_table(args.dict)
    # Test-time dataset config: disable all augmentation/filtering so every
    # utterance is decoded exactly once, in order, with static batches.
    test_conf = copy.deepcopy(configs["dataset_conf"])
    test_conf["filter_conf"]["max_length"] = 102400
    test_conf["filter_conf"]["min_length"] = 0
    test_conf["filter_conf"]["token_max_length"] = 102400
    test_conf["filter_conf"]["token_min_length"] = 0
    test_conf["filter_conf"]["max_output_input_ratio"] = 102400
    test_conf["filter_conf"]["min_output_input_ratio"] = 0
    test_conf["speed_perturb"] = False
    test_conf["spec_aug"] = False
    test_conf["spec_sub"] = False
    test_conf["spec_trim"] = False
    test_conf["shuffle"] = False
    test_conf["sort"] = False
    test_conf["fbank_conf"]["dither"] = 0.0
    test_conf["batch_conf"]["batch_type"] = "static"
    test_conf["batch_conf"]["batch_size"] = args.batch_size

    test_dataset = Dataset(
        args.data_type,
        args.test_data,
        symbol_table,
        test_conf,
        args.bpe_model,
        partition=False,
    )

    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        EP_list = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        EP_list = ["CPUExecutionProvider"]

    encoder_ort_session = rt.InferenceSession(args.encoder_onnx, providers=EP_list)
    decoder_ort_session = None
    # The decoder session is only needed for attention rescoring.
    if args.mode == "attention_rescoring":
        decoder_ort_session = rt.InferenceSession(args.decoder_onnx, providers=EP_list)

    # Load the vocabulary; by WeNet convention the last symbol doubles
    # as both <sos> and <eos>.
    vocabulary = []
    char_dict = {}
    with open(args.dict, "r") as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
            vocabulary.append(arr[0])
    eos = sos = len(char_dict) - 1

    with torch.no_grad(), open(args.result_file, "w") as fout:
        for _, batch in enumerate(test_data_loader):
            keys, feats, _, feats_lengths, _ = batch
            feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
            if args.fp16:
                feats = feats.astype(np.float16)
            ort_inputs = {
                encoder_ort_session.get_inputs()[0].name: feats,
                encoder_ort_session.get_inputs()[1].name: feats_lengths,
            }
            ort_outs = encoder_ort_session.run(None, ort_inputs)
            (
                encoder_out,
                encoder_out_lens,
                ctc_log_probs,
                beam_log_probs,
                beam_log_probs_idx,
            ) = ort_outs
            beam_size = beam_log_probs.shape[-1]
            batch_size = beam_log_probs.shape[0]
            num_processes = min(multiprocessing.cpu_count(), batch_size)
            if args.mode == "ctc_greedy_search":
                # Top-1 token id per frame. BUGFIX: the old code assigned
                # this only when beam_size != 1, leaving log_probs_idx
                # unbound (NameError) for beam_size == 1; the slice is
                # valid for every beam size.
                log_probs_idx = beam_log_probs_idx[:, :, 0]
                batch_sents = []
                for idx, seq in enumerate(log_probs_idx):
                    # Trim each sequence to its true encoder output length.
                    batch_sents.append(seq[0 : encoder_out_lens[idx]].tolist())
                hyps = map_batch(batch_sents, vocabulary, num_processes, True, 0)
            elif args.mode in ("ctc_prefix_beam_search", "attention_rescoring"):
                batch_log_probs_seq_list = beam_log_probs.tolist()
                batch_log_probs_idx_list = beam_log_probs_idx.tolist()
                batch_len_list = encoder_out_lens.tolist()
                batch_log_probs_seq = []
                batch_log_probs_ids = []
                batch_start = []
                batch_root = TrieVector()
                root_dict = {}
                for i in range(len(batch_len_list)):
                    num_sent = batch_len_list[i]
                    # Keep only the valid (un-padded) frames per utterance.
                    batch_log_probs_seq.append(batch_log_probs_seq_list[i][0:num_sent])
                    batch_log_probs_ids.append(batch_log_probs_idx_list[i][0:num_sent])
                    root_dict[i] = PathTrie()
                    batch_root.append(root_dict[i])
                    batch_start.append(True)
                score_hyps = ctc_beam_search_decoder_batch(
                    batch_log_probs_seq,
                    batch_log_probs_ids,
                    batch_root,
                    batch_start,
                    beam_size,
                    num_processes,
                    0,
                    -2,
                    0.99999,
                )
                if args.mode == "ctc_prefix_beam_search":
                    hyps = []
                    for cand_hyps in score_hyps:
                        # Keep only the best-scoring prefix per utterance.
                        hyps.append(cand_hyps[0][1])
                    hyps = map_batch(hyps, vocabulary, num_processes, False, 0)
                if args.mode == "attention_rescoring":
                    ctc_score, all_hyps = [], []
                    max_len = 0
                    for beam_hyps in score_hyps:
                        cur_len = len(beam_hyps)
                        if cur_len < beam_size:
                            # Pad short beams so every utterance contributes
                            # exactly beam_size candidates to the decoder.
                            beam_hyps += (beam_size - cur_len) * [(-float("INF"), (0,))]
                        cur_ctc_score = []
                        for hyp in beam_hyps:
                            cur_ctc_score.append(hyp[0])
                            all_hyps.append(list(hyp[1]))
                            if len(hyp[1]) > max_len:
                                max_len = len(hyp[1])
                        ctc_score.append(cur_ctc_score)
                    if args.fp16:
                        ctc_score = np.array(ctc_score, dtype=np.float16)
                    else:
                        ctc_score = np.array(ctc_score, dtype=np.float32)
                    # Pad candidates to (batch, beam, max_len + 2) with
                    # <sos>/<eos> attached; r_* holds the reversed tokens
                    # for the right-to-left decoder branch.
                    hyps_pad_sos_eos = (
                        np.ones((batch_size, beam_size, max_len + 2), dtype=np.int64)
                        * IGNORE_ID
                    )
                    r_hyps_pad_sos_eos = (
                        np.ones((batch_size, beam_size, max_len + 2), dtype=np.int64)
                        * IGNORE_ID
                    )
                    hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32)
                    k = 0
                    for i in range(batch_size):
                        for j in range(beam_size):
                            cand = all_hyps[k]
                            pad_len = len(cand) + 2
                            hyps_pad_sos_eos[i][j][0:pad_len] = [sos] + cand + [eos]
                            r_hyps_pad_sos_eos[i][j][0:pad_len] = (
                                [sos] + cand[::-1] + [eos]
                            )
                            hyps_lens_sos[i][j] = len(cand) + 1
                            k += 1
                    decoder_ort_inputs = {
                        decoder_ort_session.get_inputs()[0].name: encoder_out,
                        decoder_ort_session.get_inputs()[1].name: encoder_out_lens,
                        decoder_ort_session.get_inputs()[2].name: hyps_pad_sos_eos,
                        decoder_ort_session.get_inputs()[3].name: hyps_lens_sos,
                        decoder_ort_session.get_inputs()[-1].name: ctc_score,
                    }
                    # The reversed-hypothesis input only exists on models
                    # exported with a right-to-left decoder branch.
                    if reverse_weight > 0:
                        r_hyps_pad_sos_eos_name = decoder_ort_session.get_inputs()[4].name
                        decoder_ort_inputs[r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos
                    best_index = decoder_ort_session.run(None, decoder_ort_inputs)[0]
                    best_sents = []
                    k = 0
                    for idx in best_index:
                        # Pick the rescored best candidate within each beam.
                        cur_best_sent = all_hyps[k : k + beam_size][idx]
                        best_sents.append(cur_best_sent)
                        k += beam_size
                    hyps = map_batch(best_sents, vocabulary, num_processes)

            for i, key in enumerate(keys):
                content = hyps[i]
                logging.info("{} {}".format(key, content))
                fout.write("{} {}\n".format(key, content))
| |
|
| |
|
# Script entry point: run recognition only when executed directly.
if __name__ == "__main__":
    main()
| |
|