# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function import argparse import copy import logging import os import sys import tarfile import urllib.request import numpy as np import torch import torch.nn.functional as F import yaml from wenet.transformer.ctc import CTC from wenet.transformer.decoder import TransformerDecoder from wenet.transformer.encoder import BaseEncoder from wenet.utils.init_model import init_model from wenet.utils.mask import make_pad_mask from typing import List, Tuple try: import onnx import onnxruntime from onnx import helper, numpy_helper from onnxsim import simplify except ImportError: print("Please install onnxruntime!") sys.exit(1) logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) DEFAULT_PRETRAINED_MODEL_URL = ( "https://huggingface.co/openspeech/wenet-models/resolve/main/" "aishell_u2pp_conformer_exp.tar.gz") DEFAULT_PRETRAINED_MODEL_DIR = "pretrained/aishell_u2pp_conformer_exp" def safe_extract_tar(tar, output_dir): output_dir = os.path.abspath(output_dir) for member in tar.getmembers(): member_path = os.path.abspath(os.path.join(output_dir, member.name)) if not member_path.startswith(output_dir + os.sep): raise RuntimeError(f"Unsafe tar member path: {member.name}") tar.extractall(output_dir) def download_file(url, output_path): os.makedirs(os.path.dirname(output_path), exist_ok=True) print(f"Downloading pretrained model from {url}") 
print(f"Saving to {output_path}") urllib.request.urlretrieve(url, output_path) def prepare_pretrained_model(args): model_dir = args.pretrained_model_dir archive_dir = os.path.dirname(model_dir.rstrip(os.sep)) or "." archive_path = os.path.join( archive_dir, os.path.basename(model_dir.rstrip(os.sep)) + ".tar.gz") if not os.path.exists(model_dir): if not os.path.exists(archive_path): download_file(args.pretrained_model_url, archive_path) print(f"Extracting pretrained model to {archive_dir}") with tarfile.open(archive_path, "r:gz") as tar: safe_extract_tar(tar, archive_dir) args.config = os.path.join(model_dir, "train.yaml") args.checkpoint = os.path.join(model_dir, "final.pt") args.cmvn_file = os.path.join(model_dir, "global_cmvn") missing = [path for path in (args.config, args.checkpoint) if not os.path.exists(path)] if missing: raise FileNotFoundError( "Missing pretrained model files: " + ", ".join(missing)) print(f"Using config: {args.config}") print(f"Using checkpoint: {args.checkpoint}") if os.path.exists(args.cmvn_file): print(f"Using CMVN: {args.cmvn_file}") def _constant_node_value(node): if node is None or node.op_type != "Constant": return None for attr in node.attribute: if attr.name == "value": return numpy_helper.to_array(attr.t) return None def _attribute_value(attr): return helper.get_attribute_value(attr) def _get_attr(node, name, default=None): for attr in node.attribute: if attr.name == name: return _attribute_value(attr) return default def _cast_to_onnx_dtype(value, to_dtype): tensor_type = onnx.TensorProto.DataType.Name(to_dtype).lower() dtype_map = { "float": np.float32, "double": np.float64, "float16": np.float16, "int64": np.int64, "int32": np.int32, "int16": np.int16, "int8": np.int8, "uint64": np.uint64, "uint32": np.uint32, "uint16": np.uint16, "uint8": np.uint8, "bool": np.bool_, } if tensor_type not in dtype_map: return None return value.astype(dtype_map[tensor_type]) def _shape_from_value_info(value_info): if not 
value_info.type.HasField("tensor_type"): return None if not value_info.type.tensor_type.shape.dim: return None shape = [] for dim in value_info.type.tensor_type.shape.dim: if dim.HasField("dim_value") and dim.dim_value > 0: shape.append(dim.dim_value) else: return None return tuple(shape) def _collect_static_shapes(model): inferred = onnx.shape_inference.infer_shapes(model) shapes = {} for value_info in list(inferred.graph.input) + list( inferred.graph.value_info) + list(inferred.graph.output): shape = _shape_from_value_info(value_info) if shape is not None: shapes[value_info.name] = shape for initializer in inferred.graph.initializer: shapes[initializer.name] = tuple(initializer.dims) return shapes def _eval_static_node(node, inputs, static_shapes): if node.op_type == "Constant": return _constant_node_value(node) if node.op_type != "Shape" and any(value is None for value in inputs): return None try: if node.op_type == "Add": return np.add(inputs[0], inputs[1]) if node.op_type == "Sub": return np.subtract(inputs[0], inputs[1]) if node.op_type == "Mul": return np.multiply(inputs[0], inputs[1]) if node.op_type == "Div": return np.divide(inputs[0], inputs[1]) if node.op_type == "Equal": return np.equal(inputs[0], inputs[1]) if node.op_type == "Greater": return np.greater(inputs[0], inputs[1]) if node.op_type == "GreaterOrEqual": return np.greater_equal(inputs[0], inputs[1]) if node.op_type == "Less": return np.less(inputs[0], inputs[1]) if node.op_type == "LessOrEqual": return np.less_equal(inputs[0], inputs[1]) if node.op_type == "Where": return np.where(inputs[0], inputs[1], inputs[2]) if node.op_type == "Concat": axis = _get_attr(node, "axis", 0) return np.concatenate(inputs, axis=axis) if node.op_type == "Unsqueeze": axes = _get_attr(node, "axes") if axes is None and len(inputs) > 1: axes = inputs[1] axes = tuple(int(axis) for axis in np.asarray(axes).reshape(-1)) return np.expand_dims(inputs[0], axis=axes) if node.op_type == "Squeeze": axes = _get_attr(node, 
"axes") if axes is None and len(inputs) > 1: axes = inputs[1] if axes is None: return np.squeeze(inputs[0]) axes = tuple(int(axis) for axis in np.asarray(axes).reshape(-1)) return np.squeeze(inputs[0], axis=axes) if node.op_type == "Cast": return _cast_to_onnx_dtype(inputs[0], _get_attr(node, "to")) if node.op_type == "Reshape": return np.reshape(inputs[0], tuple(int(i) for i in inputs[1])) if node.op_type == "Shape": if inputs[0] is not None: shape = inputs[0].shape else: shape = static_shapes.get(node.input[0]) if shape is None: return None return np.asarray(shape, dtype=np.int64) if node.op_type == "Slice": data = inputs[0] starts = np.asarray(inputs[1]).reshape(-1) ends = np.asarray(inputs[2]).reshape(-1) axes = (np.asarray(inputs[3]).reshape(-1) if len(inputs) > 3 and inputs[3] is not None else np.arange(len(starts))) steps = (np.asarray(inputs[4]).reshape(-1) if len(inputs) > 4 and inputs[4] is not None else np.ones(len(starts), dtype=np.int64)) slices = [slice(None)] * data.ndim for start, end, axis, step in zip(starts, ends, axes, steps): axis = int(axis) start = int(start) end = int(end) step = int(step) if end >= np.iinfo(np.int32).max: end = None if end <= np.iinfo(np.int32).min: end = None slices[axis] = slice(start, end, step) return data[tuple(slices)] if node.op_type == "Gather": axis = _get_attr(node, "axis", 0) return np.take(inputs[0], inputs[1], axis=axis) except Exception: return None return None def _constant_node(output_name, value, name): const_tensor = numpy_helper.from_array(np.asarray(value), name=output_name + "_value") return helper.make_node("Constant", inputs=[], outputs=[output_name], name=name, value=const_tensor) def _node_attributes(node): return {attr.name: helper.get_attribute_value(attr) for attr in node.attribute} def _copy_node(node, inputs=None, outputs=None, name=None): copied = copy.deepcopy(node) if inputs is not None: del copied.input[:] copied.input.extend(inputs) if outputs is not None: del copied.output[:] 
copied.output.extend(outputs) if name is not None: copied.name = name return copied def _producer_map(model): return {output: node for node in model.graph.node for output in node.output} def _unsqueeze_greater_equal_pattern(producer, value_name): unsqueeze = producer.get(value_name) if unsqueeze is None or unsqueeze.op_type != "Unsqueeze": return None, None compare = producer.get(unsqueeze.input[0]) if compare is None or compare.op_type != "GreaterOrEqual": return None, None return unsqueeze, compare def rewrite_pulsar2_bool_not(onnx_path): """Remove simple Not nodes that Pulsar2 quantization can cast to float. The encoder mask contains Not(Unsqueeze(GreaterOrEqual(...))) and another Not over a sliced version of that mask. Pulsar2 can quantize the Not input to FP32 and then fail because bitwise Not only accepts bool/integer tensors. Rewriting those patterns keeps the graph boolean-equivalent without Not. """ model = onnx.load(onnx_path) producer = _producer_map(model) rewritten = 0 new_nodes = [] for node in model.graph.node: if node.op_type != "Not": new_nodes.append(node) continue compare = producer.get(node.input[0]) if compare is not None and compare.op_type == "GreaterOrEqual": less = helper.make_node("Less", inputs=list(compare.input), outputs=list(node.output), name=node.name + "_less", **_node_attributes(compare)) new_nodes.append(less) rewritten += 1 continue unsqueeze, compare = _unsqueeze_greater_equal_pattern( producer, node.input[0]) if unsqueeze is not None: less_output = node.output[0] + "_less" less = helper.make_node("Less", inputs=list(compare.input), outputs=[less_output], name=node.name + "_less", **_node_attributes(compare)) rewritten_unsqueeze = _copy_node( unsqueeze, inputs=[less_output] + list(unsqueeze.input[1:]), outputs=list(node.output), name=node.name + "_unsqueeze") new_nodes.extend([less, rewritten_unsqueeze]) rewritten += 1 continue slice_1 = producer.get(node.input[0]) slice_0 = producer.get(slice_1.input[0]) if slice_1 else None 
inner_not = producer.get(slice_0.input[0]) if slice_0 else None if (slice_1 is not None and slice_1.op_type == "Slice" and slice_0 is not None and slice_0.op_type == "Slice" and inner_not is not None and inner_not.op_type == "Not"): unsqueeze, _ = _unsqueeze_greater_equal_pattern( producer, inner_not.input[0]) if unsqueeze is not None: slice_0_output = node.output[0] + "_slice0" rewritten_slice_0 = _copy_node( slice_0, inputs=[unsqueeze.output[0]] + list(slice_0.input[1:]), outputs=[slice_0_output], name=node.name + "_slice0") rewritten_slice_1 = _copy_node( slice_1, inputs=[slice_0_output] + list(slice_1.input[1:]), outputs=list(node.output), name=node.name + "_slice1") new_nodes.extend([rewritten_slice_0, rewritten_slice_1]) rewritten += 1 continue new_nodes.append(node) if rewritten: del model.graph.node[:] model.graph.node.extend(new_nodes) onnx.checker.check_model(model) onnx.save(model, onnx_path) print(f"Rewrote {rewritten} bool Not node(s) in {onnx_path}") def rewrite_pulsar2_bool_and(onnx_path): """Replace boolean And with arithmetic comparison for Pulsar2 quantization.""" model = onnx.load(onnx_path) rewritten = 0 new_nodes = [] for node in model.graph.node: if node.op_type != "And" or len(node.input) != 2 or len( node.output) != 1: new_nodes.append(node) continue left = node.output[0] + "_left_i32" right = node.output[0] + "_right_i32" added = node.output[0] + "_sum" threshold = node.output[0] + "_threshold" new_nodes.append( helper.make_node("Cast", inputs=[node.input[0]], outputs=[left], name=node.name + "_cast_left", to=onnx.TensorProto.INT32)) new_nodes.append( helper.make_node("Cast", inputs=[node.input[1]], outputs=[right], name=node.name + "_cast_right", to=onnx.TensorProto.INT32)) new_nodes.append( helper.make_node("Add", inputs=[left, right], outputs=[added], name=node.name + "_add")) new_nodes.append( _constant_node(threshold, np.asarray(1, dtype=np.int32), node.name + "_threshold")) new_nodes.append( helper.make_node("Greater", inputs=[added, 
threshold], outputs=list(node.output), name=node.name + "_greater")) rewritten += 1 if rewritten: del model.graph.node[:] model.graph.node.extend(new_nodes) onnx.checker.check_model(model) onnx.save(model, onnx_path) print(f"Rewrote {rewritten} bool And node(s) in {onnx_path}") def simplify_pulsar2_onnx(onnx_path): model = onnx.load(onnx_path) sim_model, ok = simplify(model) if not ok: raise RuntimeError(f"onnxsim failed to validate {onnx_path}") onnx.checker.check_model(sim_model) onnx.save(sim_model, onnx_path) print(f"Simplified {onnx_path} for Pulsar2") def fold_static_pulsar2_subgraphs(onnx_path): """Fold static ONNX patterns that Pulsar2 5.0 cannot infer reliably. Pulsar2 5.0 can fail shape inference on ConstantOfShape when its input is a constant tensor value instead of an initializer. The legacy exporter emits this pattern for masks/padding in the encoder graphs. It can also fail when an Expand shape is produced by a constant-only subgraph such as Mul/Equal/Where. Fold those static pieces before handing the model to Pulsar2. 
""" model = onnx.load(onnx_path) static_shapes = _collect_static_shapes(model) constants = { initializer.name: numpy_helper.to_array(initializer) for initializer in model.graph.initializer } folded = 0 new_nodes = [] for node in model.graph.node: inputs = [constants.get(name) for name in node.input] if node.op_type == "ConstantOfShape" and node.input: shape_value = inputs[0] if shape_value is not None: fill_value = np.array(0, dtype=np.float32) for attr in node.attribute: if attr.name == "value": fill_value = numpy_helper.to_array(attr.t) break shape = tuple(int(dim) for dim in np.asarray(shape_value).reshape(-1)) value = np.full(shape, fill_value.reshape(-1)[0], dtype=fill_value.dtype) else: value = None else: value = _eval_static_node(node, inputs, static_shapes) if value is None or len(node.output) != 1: new_nodes.append(node) continue constants[node.output[0]] = value new_nodes.append(_constant_node(node.output[0], value, node.name)) folded += 1 if folded: del model.graph.node[:] model.graph.node.extend(new_nodes) onnx.checker.check_model(model) onnx.save(model, onnx_path) print(f"Folded {folded} static node(s) in {onnx_path}") class Encoder(torch.nn.Module): def __init__(self, encoder: BaseEncoder, ctc: CTC, beam_size: int = 10): super().__init__() self.encoder = encoder self.ctc = ctc self.beam_size = beam_size def forward( self, speech: torch.Tensor, speech_lengths: torch.Tensor, ): """Encoder Args: speech: (Batch, Length, ...) 
speech_lengths: (Batch, ) Returns: encoder_out: B x T x F encoder_out_lens: B ctc_log_probs: B x T x V beam_log_probs: B x T x beam_size beam_log_probs_idx: B x T x beam_size """ encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1, -1) encoder_out_lens = encoder_mask.squeeze(1).sum(1) # ctc_log_probs = self.ctc.log_softmax(encoder_out) ctc_log_probs = self.ctc.linear(encoder_out) encoder_out_lens = encoder_out_lens.int() beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs, self.beam_size, dim=2) return ( encoder_out, encoder_out_lens, ctc_log_probs, beam_log_probs, beam_log_probs_idx, ) class StreamingEncoder(torch.nn.Module): def __init__( self, model, required_cache_size, beam_size, transformer=False, return_ctc_logprobs=False, ): super().__init__() self.ctc = model.ctc self.subsampling_rate = model.encoder.embed.subsampling_rate self.embed = model.encoder.embed self.global_cmvn = model.encoder.global_cmvn self.required_cache_size = required_cache_size self.beam_size = beam_size self.encoder = model.encoder self.transformer = transformer self.return_ctc_logprobs = return_ctc_logprobs def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask): """Streaming Encoder Args: xs (torch.Tensor): chunk input, with shape (b, time, mel-dim), where `time == (chunk_size - 1) * subsample_rate + \ subsample.right_context + 1` offset (torch.Tensor): offset with shape (b, 1) 1 is retained for triton deployment required_cache_size (int): cache size required for next chunk compuation > 0: actual cache size <= 0: not allowed in streaming gpu encoder ` att_cache (torch.Tensor): cache tensor for KEY & VALUE in transformer/conformer attention, with shape (b, elayers, head, cache_t1, d_k * 2), where `head * d_k == hidden-dim` and `cache_t1 == chunk_size * num_decoding_left_chunks`. 
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, (b, elayers, b, hidden-dim, cache_t2), where `cache_t2 == cnn.lorder - 1` cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size) in a batch of request, each request may have different history cache. Cache mask is used to indidate the effective cache for each request Returns: torch.Tensor: log probabilities of ctc output and cutoff by beam size with shape (b, chunk_size, beam) torch.Tensor: index of top beam size probabilities for each timestep with shape (b, chunk_size, beam) torch.Tensor: output of current input xs, with shape (b, chunk_size, hidden-dim). torch.Tensor: new attention cache required for next chunk, with same shape (b, elayers, head, cache_t1, d_k * 2) as the original att_cache torch.Tensor: new conformer cnn cache required for next chunk, with same shape as the original cnn_cache. torch.Tensor: new cache mask, with same shape as the original cache mask """ offset = offset.squeeze(1) T = chunk_xs.size(1) chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1) # B X 1 X T chunk_mask = chunk_mask.to(chunk_xs.dtype) # transpose batch & num_layers dim att_cache = torch.transpose(att_cache, 0, 1) cnn_cache = torch.transpose(cnn_cache, 0, 1) # rewrite encoder.forward_chunk # <---------forward_chunk START---------> xs = self.global_cmvn(chunk_xs) # chunk mask is important for batch inferencing since # different sequence in a batch has different length xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset) cache_size = att_cache.size(3) # required cache size masks = torch.cat((cache_mask, chunk_mask), dim=2) index = offset - cache_size pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1)) pos_emb = pos_emb.to(dtype=xs.dtype) next_cache_start = -self.required_cache_size r_cache_mask = masks[:, :, next_cache_start:] r_att_cache = [] r_cnn_cache = [] for i, layer in enumerate(self.encoder.encoders): i_kv_cache = att_cache[i] size = att_cache.size(-1) 
// 2 kv_cache = (i_kv_cache[:, :, :, :size], i_kv_cache[:, :, :, size:]) xs, _, new_kv_cache, new_cnn_cache = layer( xs, masks, pos_emb, att_cache=kv_cache, cnn_cache=cnn_cache[i], ) # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2), # shape(new_cnn_cache) is (B, hidden-dim, cache_t2) new_att_cache = torch.cat(new_kv_cache, dim=-1) r_att_cache.append( new_att_cache[:, :, next_cache_start:, :].unsqueeze(1)) if not self.transformer: r_cnn_cache.append(new_cnn_cache.unsqueeze(1)) if self.encoder.normalize_before: chunk_out = self.encoder.after_norm(xs) else: chunk_out = xs r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx if not self.transformer: r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers # <---------forward_chunk END---------> # log_ctc_probs = self.ctc.log_softmax(chunk_out) log_ctc_probs = self.ctc.linear(chunk_out) log_probs, log_probs_idx = torch.topk(log_ctc_probs, self.beam_size, dim=2) log_probs = log_probs.to(chunk_xs.dtype) r_offset = offset + chunk_out.shape[1] # the below ops not supported in Tensorrt # chunk_out_lens = torch.div(chunk_lens, subsampling_rate, # rounding_mode='floor') chunk_out_lens = chunk_lens // self.subsampling_rate r_offset = r_offset.unsqueeze(1) if self.return_ctc_logprobs: return ( log_ctc_probs, chunk_out, chunk_out_lens, r_offset, r_att_cache, r_cnn_cache, r_cache_mask, ) else: return ( log_probs, log_probs_idx, chunk_out, chunk_out_lens, r_offset, r_att_cache, r_cnn_cache, r_cache_mask, ) class StreamingSqueezeformerEncoder(torch.nn.Module): def __init__(self, model, required_cache_size, beam_size): super().__init__() self.ctc = model.ctc self.subsampling_rate = model.encoder.embed.subsampling_rate self.embed = model.encoder.embed self.global_cmvn = model.encoder.global_cmvn self.required_cache_size = required_cache_size self.beam_size = beam_size self.encoder = model.encoder self.reduce_idx = model.encoder.reduce_idx self.recover_idx = model.encoder.recover_idx if 
self.reduce_idx is None: self.time_reduce = None else: if self.recover_idx is None: self.time_reduce = "normal" # no recovery at the end else: self.time_reduce = "recover" # recovery at the end assert len(self.reduce_idx) == len(self.recover_idx) def calculate_downsampling_factor(self, i: int) -> int: if self.reduce_idx is None: return 1 else: reduce_exp, recover_exp = 0, 0 for exp, rd_idx in enumerate(self.reduce_idx): if i >= rd_idx: reduce_exp = exp + 1 if self.recover_idx is not None: for exp, rc_idx in enumerate(self.recover_idx): if i >= rc_idx: recover_exp = exp + 1 return int(2**(reduce_exp - recover_exp)) def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask): """Streaming Encoder Args: xs (torch.Tensor): chunk input, with shape (b, time, mel-dim), where `time == (chunk_size - 1) * subsample_rate + \ subsample.right_context + 1` offset (torch.Tensor): offset with shape (b, 1) 1 is retained for triton deployment required_cache_size (int): cache size required for next chunk compuation > 0: actual cache size <= 0: not allowed in streaming gpu encoder ` att_cache (torch.Tensor): cache tensor for KEY & VALUE in transformer/conformer attention, with shape (b, elayers, head, cache_t1, d_k * 2), where `head * d_k == hidden-dim` and `cache_t1 == chunk_size * num_decoding_left_chunks`. cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, (b, elayers, b, hidden-dim, cache_t2), where `cache_t2 == cnn.lorder - 1` cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size) in a batch of request, each request may have different history cache. Cache mask is used to indidate the effective cache for each request Returns: torch.Tensor: log probabilities of ctc output and cutoff by beam size with shape (b, chunk_size, beam) torch.Tensor: index of top beam size probabilities for each timestep with shape (b, chunk_size, beam) torch.Tensor: output of current input xs, with shape (b, chunk_size, hidden-dim). 
torch.Tensor: new attention cache required for next chunk, with same shape (b, elayers, head, cache_t1, d_k * 2) as the original att_cache torch.Tensor: new conformer cnn cache required for next chunk, with same shape as the original cnn_cache. torch.Tensor: new cache mask, with same shape as the original cache mask """ offset = offset.squeeze(1) T = chunk_xs.size(1) chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1) # B X 1 X T chunk_mask = chunk_mask.to(chunk_xs.dtype) # transpose batch & num_layers dim att_cache = torch.transpose(att_cache, 0, 1) cnn_cache = torch.transpose(cnn_cache, 0, 1) # rewrite encoder.forward_chunk # <---------forward_chunk START---------> xs = self.global_cmvn(chunk_xs) # chunk mask is important for batch inferencing since # different sequence in a batch has different length xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset) elayers, cache_size = att_cache.size(0), att_cache.size(3) att_mask = torch.cat((cache_mask, chunk_mask), dim=2) index = offset - cache_size pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1)) pos_emb = pos_emb.to(dtype=xs.dtype) next_cache_start = -self.required_cache_size r_cache_mask = att_mask[:, :, next_cache_start:] r_att_cache = [] r_cnn_cache = [] mask_pad = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool) mask_pad = mask_pad.unsqueeze(1) max_att_len: int = 0 recover_activations: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = [] index = 0 xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int) xs = self.encoder.preln(xs) for i, layer in enumerate(self.encoder.encoders): if self.reduce_idx is not None: if self.time_reduce is not None and i in self.reduce_idx: recover_activations.append( (xs, att_mask, pos_emb, mask_pad)) ( xs, xs_lens, att_mask, mask_pad, ) = self.encoder.time_reduction_layer( xs, xs_lens, att_mask, mask_pad) pos_emb = pos_emb[:, ::2, :] if self.encoder.pos_enc_layer_type == "rel_pos_repaired": pos_emb = 
pos_emb[:, :xs.size(1) * 2 - 1, :] index += 1 if self.recover_idx is not None: if self.time_reduce == "recover" and i in self.recover_idx: index -= 1 ( recover_tensor, recover_att_mask, recover_pos_emb, recover_mask_pad, ) = recover_activations[index] # recover output length for ctc decode xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2) xs = self.encoder.time_recover_layer(xs) recoverd_t = recover_tensor.size(1) xs = recover_tensor + xs[:, :recoverd_t, :].contiguous() att_mask = recover_att_mask pos_emb = recover_pos_emb mask_pad = recover_mask_pad factor = self.calculate_downsampling_factor(i) xs, _, new_att_cache, new_cnn_cache = layer( xs, att_mask, pos_emb, att_cache=att_cache[i][:, :, ::factor, :] [:, :, :pos_emb.size(1) - xs.size(1), :] if elayers > 0 else att_cache[:, :, ::factor, :], cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache, ) cached_att = new_att_cache[:, :, next_cache_start // factor:, :] cached_cnn = new_cnn_cache.unsqueeze(1) cached_att = (cached_att.unsqueeze(3).repeat(1, 1, 1, factor, 1).flatten(2, 3)) if i == 0: # record length for the first block as max length max_att_len = cached_att.size(2) r_att_cache.append(cached_att[:, :, :max_att_len, :].unsqueeze(1)) r_cnn_cache.append(cached_cnn) chunk_out = xs r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers # <---------forward_chunk END---------> # log_ctc_probs = self.ctc.log_softmax(chunk_out) log_ctc_probs = self.ctc.linear(chunk_out) log_probs, log_probs_idx = torch.topk(log_ctc_probs, self.beam_size, dim=2) log_probs = log_probs.to(chunk_xs.dtype) r_offset = offset + chunk_out.shape[1] # the below ops not supported in Tensorrt # chunk_out_lens = torch.div(chunk_lens, subsampling_rate, # rounding_mode='floor') chunk_out_lens = chunk_lens // self.subsampling_rate r_offset = r_offset.unsqueeze(1) return ( log_probs, log_probs_idx, chunk_out, chunk_out_lens, r_offset, r_att_cache, r_cnn_cache, 
r_cache_mask, ) class StreamingEfficientConformerEncoder(torch.nn.Module): def __init__(self, model, required_cache_size, beam_size): super().__init__() self.ctc = model.ctc self.subsampling_rate = model.encoder.embed.subsampling_rate self.embed = model.encoder.embed self.global_cmvn = model.encoder.global_cmvn self.required_cache_size = required_cache_size self.beam_size = beam_size self.encoder = model.encoder # Efficient Conformer self.stride_layer_idx = model.encoder.stride_layer_idx self.stride = model.encoder.stride self.num_blocks = model.encoder.num_blocks self.cnn_module_kernel = model.encoder.cnn_module_kernel def calculate_downsampling_factor(self, i: int) -> int: factor = 1 for idx, stride_idx in enumerate(self.stride_layer_idx): if i > stride_idx: factor *= self.stride[idx] return factor def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, cache_mask): """Streaming Encoder Args: chunk_xs (torch.Tensor): chunk input, with shape (b, time, mel-dim), where `time == (chunk_size - 1) * subsample_rate + \ subsample.right_context + 1` chunk_lens (torch.Tensor): offset (torch.Tensor): offset with shape (b, 1) 1 is retained for triton deployment att_cache (torch.Tensor): cache tensor for KEY & VALUE in transformer/conformer attention, with shape (b, elayers, head, cache_t1, d_k * 2), where `head * d_k == hidden-dim` and `cache_t1 == chunk_size * num_decoding_left_chunks`. cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, (b, elayers, hidden-dim, cache_t2), where `cache_t2 == cnn.lorder - 1` cache_mask: (torch.Tensor): cache mask with shape (b, required_cache_size) in a batch of request, each request may have different history cache. 
Cache mask is used to indidate the effective cache for each request Returns: torch.Tensor: log probabilities of ctc output and cutoff by beam size with shape (b, chunk_size, beam) torch.Tensor: index of top beam size probabilities for each timestep with shape (b, chunk_size, beam) torch.Tensor: output of current input xs, with shape (b, chunk_size, hidden-dim). torch.Tensor: new attention cache required for next chunk, with same shape (b, elayers, head, cache_t1, d_k * 2) as the original att_cache torch.Tensor: new conformer cnn cache required for next chunk, with same shape as the original cnn_cache. torch.Tensor: new cache mask, with same shape as the original cache mask """ offset = offset.squeeze(1) # (b, ) offset *= self.calculate_downsampling_factor(self.num_blocks + 1) T = chunk_xs.size(1) chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1) # (b, 1, T) # B X 1 X T chunk_mask = chunk_mask.to(chunk_xs.dtype) # transpose batch & num_layers dim # Shape(att_cache): (elayers, b, head, cache_t1, d_k * 2) # Shape(cnn_cache): (elayers, b, outsize, cnn_kernel) att_cache = torch.transpose(att_cache, 0, 1) cnn_cache = torch.transpose(cnn_cache, 0, 1) # rewrite encoder.forward_chunk # <---------forward_chunk START---------> xs = self.global_cmvn(chunk_xs) # chunk mask is important for batch inferencing since # different sequence in a batch has different length xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset) cache_size = att_cache.size(3) # required cache size masks = torch.cat((cache_mask, chunk_mask), dim=2) att_mask = torch.cat((cache_mask, chunk_mask), dim=2) index = offset - cache_size pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1)) pos_emb = pos_emb.to(dtype=xs.dtype) next_cache_start = -self.required_cache_size r_cache_mask = masks[:, :, next_cache_start:] r_att_cache = [] r_cnn_cache = [] mask_pad = chunk_mask.to(torch.bool) max_att_len, max_cnn_len = ( 0, 0, ) # for repeat_interleave of new_att_cache for i, layer in 
enumerate(self.encoder.encoders): factor = self.calculate_downsampling_factor(i) # NOTE(xcsong): Before layer.forward # shape(att_cache[i:i + 1]) is (b, head, cache_t1, d_k * 2), # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) # shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ] att_cache_trunc = 0 if xs.size(1) + att_cache.size(3) / factor > pos_emb.size(1): # The time step is not divisible by the downsampling multiple # We propose to double the chunk_size. att_cache_trunc = (xs.size(1) + att_cache.size(3) // factor - pos_emb.size(1) + 1) xs, _, new_att_cache, new_cnn_cache = layer( xs, att_mask, pos_emb, mask_pad=mask_pad, att_cache=att_cache[i][:, :, ::factor, :][:, :, att_cache_trunc:, :], cnn_cache=cnn_cache[i, :, :, :] if cnn_cache.size(0) > 0 else cnn_cache, ) if i in self.stride_layer_idx: # compute time dimension for next block efficient_index = self.stride_layer_idx.index(i) att_mask = att_mask[:, ::self.stride[efficient_index], ::self. stride[efficient_index], ] mask_pad = mask_pad[:, ::self.stride[efficient_index], ::self. 
stride[efficient_index], ] pos_emb = pos_emb[:, ::self.stride[efficient_index], :] # shape(new_att_cache) = [batch, head, time2, outdim] new_att_cache = new_att_cache[:, :, next_cache_start // factor:, :] # shape(new_cnn_cache) = [batch, 1, outdim, cache_t2] new_cnn_cache = new_cnn_cache.unsqueeze(1) # shape(1):layerID # use repeat_interleave to new_att_cache # new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2) new_att_cache = (new_att_cache.unsqueeze(3).repeat( 1, 1, 1, factor, 1).flatten(2, 3)) # padding new_cnn_cache to cnn.lorder for casual convolution new_cnn_cache = F.pad( new_cnn_cache, (self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0), ) if i == 0: # record length for the first block as max length max_att_len = new_att_cache.size(2) max_cnn_len = new_cnn_cache.size(3) # update real shape of att_cache and cnn_cache r_att_cache.append(new_att_cache[:, :, -max_att_len:, :].unsqueeze(1)) r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:]) if self.encoder.normalize_before: chunk_out = self.encoder.after_norm(xs) else: chunk_out = xs # shape of r_att_cache: (b, elayers, head, time2, outdim) r_att_cache = torch.cat(r_att_cache, dim=1) # concat on layers idx # shape of r_cnn_cache: (b, elayers, outdim, cache_t2) r_cnn_cache = torch.cat(r_cnn_cache, dim=1) # concat on layers # <---------forward_chunk END---------> # log_ctc_probs = self.ctc.log_softmax(chunk_out) log_ctc_probs = self.ctc.linear(chunk_out) log_probs, log_probs_idx = torch.topk(log_ctc_probs, self.beam_size, dim=2) log_probs = log_probs.to(chunk_xs.dtype) r_offset = offset + chunk_out.shape[1] # the below ops not supported in Tensorrt # chunk_out_lens = torch.div(chunk_lens, subsampling_rate, # rounding_mode='floor') chunk_out_lens = ( chunk_lens // self.subsampling_rate // self.calculate_downsampling_factor(self.num_blocks + 1)) chunk_out_lens += 1 r_offset = r_offset.unsqueeze(1) return ( log_probs, log_probs_idx, chunk_out, chunk_out_lens, r_offset, r_att_cache, 
r_cnn_cache, r_cache_mask, ) class Decoder(torch.nn.Module): def __init__( self, decoder: TransformerDecoder, ctc_weight: float = 0.5, reverse_weight: float = 0.0, beam_size: int = 10, decoder_fastertransformer: bool = False, ): super().__init__() self.decoder = decoder self.ctc_weight = ctc_weight self.reverse_weight = reverse_weight self.beam_size = beam_size self.decoder_fastertransformer = decoder_fastertransformer def forward( self, encoder_out: torch.Tensor, encoder_lens: torch.Tensor, hyps_pad_sos_eos: torch.Tensor, hyps_lens_sos: torch.Tensor, r_hyps_pad_sos_eos: torch.Tensor, ctc_score: torch.Tensor, ): """Encoder Args: encoder_out: B x T x F encoder_lens: B hyps_pad_sos_eos: B x beam x (T2+1), hyps with sos & eos and padded by ignore id hyps_lens_sos: B x beam, length for each hyp with sos r_hyps_pad_sos_eos: B x beam x (T2+1), reversed hyps with sos & eos and padded by ignore id ctc_score: B x beam, ctc score for each hyp Returns: decoder_out: B x beam x T2 x V r_decoder_out: B x beam x T2 x V best_index: B """ B, T, F = encoder_out.shape bz = self.beam_size B2 = B * bz encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F) encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1) encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T) T2 = hyps_pad_sos_eos.shape[2] - 1 hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1) hyps_lens = hyps_lens_sos.view(B2, ) hyps_pad_sos = hyps_pad[:, :-1].contiguous() hyps_pad_eos = hyps_pad[:, 1:].contiguous() r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1) r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous() r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous() decoder_out, r_decoder_out, _ = self.decoder( encoder_out, encoder_mask, hyps_pad_sos, hyps_lens, r_hyps_pad_sos, self.reverse_weight, ) # decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) V = decoder_out.shape[-1] decoder_out = decoder_out.view(B2, T2, V) mask = ~make_pad_mask(hyps_lens, T2) # B2 x T2 # mask index, remove ignore id index = 
def to_numpy(tensors):
    """Convert an iterable of torch tensors into a list of numpy arrays.

    Args:
        tensors: an iterable of ``torch.Tensor`` (list, tuple, ...). A bare
            tensor is accepted as well, but it is iterated along its first
            dimension, so the result holds one array per slice rather than
            a single array for the whole tensor.

    Returns:
        A list of ``numpy.ndarray``, one per iterated item, in order.
    """
    # NOTE: the former `type(tensors) == torch.tensor` guard compared a class
    # with the `torch.tensor` factory function, so it was never true and a
    # bare tensor was never wrapped in a list. The dead guard is dropped
    # instead of being "repaired" so that callers passing a bare tensor keep
    # receiving exactly the same list layout as before.
    #
    # detach() is harmless on tensors that do not require grad, so a single
    # conversion path covers both the grad and the no-grad case.
    return [tensor.detach().cpu().numpy() for tensor in tensors]
def export_online_encoder(model, configs, args, logger, encoder_onnx_path):
    """Export the streaming (chunk-wise) encoder to ONNX and compare outputs.

    The wrapped encoder is exported with fixed shapes, post-processed by the
    pulsar2 graph rewrites, then run once through onnxruntime on the same
    random inputs as the PyTorch module; both results are passed to ``test``.

    Args:
        model: model whose ``encoder.embed.subsampling_rate`` and
            ``encoder.embed.right_context`` are read; it is handed to the
            streaming encoder wrapper.
        configs: config dict; ``input_dim``, ``encoder`` and
            ``encoder_conf`` (``output_size``, ``num_blocks``,
            ``attention_heads`` and optionally ``cnn_module_kernel``) are
            read.
        args: parsed CLI arguments; ``decoding_chunk_size``,
            ``num_decoding_left_chunks``, ``beam_size`` and
            ``return_ctc_logprobs`` are read.
        logger: logger used for the final success message.
        encoder_onnx_path: path the ONNX model is written to.

    Returns:
        dict describing the streaming setup (chunk size, left chunks,
        decoding window, cache sizes, ...).

    Raises:
        AssertionError: if ``return_ctc_logprobs`` is requested for an
            encoder without a cnn module (transformer).
    """
    decoding_chunk_size = args.decoding_chunk_size
    subsampling = model.encoder.embed.subsampling_rate
    context = model.encoder.embed.right_context + 1
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    batch_size = 1
    audio_len = decoding_window
    feature_size = configs["input_dim"]
    output_size = configs["encoder_conf"]["output_size"]
    num_layers = configs["encoder_conf"]["num_blocks"]
    # in transformer the cnn module will not be available, which shows up
    # as a cnn cache length of zero
    cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1
    transformer = not cnn_module_kernel
    num_decoding_left_chunks = args.num_decoding_left_chunks
    required_cache_size = decoding_chunk_size * num_decoding_left_chunks

    if configs["encoder"] == "squeezeformer":
        encoder = StreamingSqueezeformerEncoder(model, required_cache_size,
                                                args.beam_size)
    elif configs["encoder"] == "efficientConformer":
        encoder = StreamingEfficientConformerEncoder(model,
                                                     required_cache_size,
                                                     args.beam_size)
    else:
        encoder = StreamingEncoder(
            model,
            required_cache_size,
            args.beam_size,
            transformer,
            args.return_ctc_logprobs,
        )
    encoder.eval()

    # begin to export encoder
    chunk_xs = torch.randn(batch_size,
                           audio_len,
                           feature_size,
                           dtype=torch.float32)
    chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len
    offset = torch.arange(0, batch_size, dtype=torch.int32).unsqueeze(1)
    # att_cache is built as (b, elayers, head, cache_t1, d_k * 2)
    head = configs["encoder_conf"]["attention_heads"]
    d_k = configs["encoder_conf"]["output_size"] // head
    att_cache = torch.randn(
        batch_size,
        num_layers,
        head,
        required_cache_size,
        d_k * 2,
        dtype=torch.float32,
    )
    # cnn_cache is built as (b, elayers, output_size, cnn_module_kernel - 1)
    cnn_cache = torch.randn(
        batch_size,
        num_layers,
        output_size,
        cnn_module_kernel,
        dtype=torch.float32,
    )
    cache_mask = torch.ones(batch_size,
                            1,
                            required_cache_size,
                            dtype=torch.float32)
    input_names = [
        "chunk_xs",
        "chunk_lens",
        "offset",
        "att_cache",
        "cnn_cache",
        "cache_mask",
    ]
    output_names = [
        "log_probs",
        "log_probs_idx",
        "chunk_out",
        "chunk_out_lens",
        "r_offset",
        "r_att_cache",
        "r_cnn_cache",
        "r_cache_mask",
    ]
    if args.return_ctc_logprobs:
        output_names = [
            "ctc_log_probs",
            "chunk_out",
            "chunk_out_lens",
            "r_offset",
            "r_att_cache",
            "r_cnn_cache",
            "r_cache_mask",
        ]
    input_tensors = (
        chunk_xs,
        chunk_lens,
        offset,
        att_cache,
        cnn_cache,
        cache_mask,
    )
    if transformer:
        assert (args.return_ctc_logprobs is
                False), "return_ctc_logprobs is not supported in transformer"
        # index 6 is "r_cnn_cache" in the 8-entry output list
        output_names.pop(6)

    # Only the first dimension would be dynamic, all other dimensions are
    # fixed. The mapping is kept for reference; the export below currently
    # uses fully static shapes (dynamic_axes=None).
    dynamic_axes = {name: {0: "B"} for name in input_names + output_names}

    torch.onnx.export(
        encoder,
        input_tensors,
        encoder_onnx_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        # dynamic_axes=dynamic_axes,
        dynamic_axes=None,
        verbose=False,
        dynamo=False,
    )
    fold_static_pulsar2_subgraphs(encoder_onnx_path)
    simplify_pulsar2_onnx(encoder_onnx_path)
    rewrite_pulsar2_bool_not(encoder_onnx_path)

    with torch.no_grad():
        torch_outs = encoder(chunk_xs, chunk_lens, offset, att_cache,
                             cnn_cache, cache_mask)
    if transformer:
        # Drop the same position that was removed from output_names.
        # list.pop() returns the removed element, so the shortened list has
        # to be kept separately; binding the pop() result (as the code did
        # before) compared a single tensor against all ORT outputs.
        # NOTE(review): assumes the wrapper still returns the cnn cache at
        # index 6 in transformer mode - confirm against StreamingEncoder.
        torch_outs = list(torch_outs)
        torch_outs.pop(6)

    ort_session = onnxruntime.InferenceSession(
        encoder_onnx_path, providers=["CPUExecutionProvider"])
    numpy_inputs = to_numpy(input_tensors)
    ort_inputs = dict(zip(input_names, numpy_inputs))
    if transformer:
        # the exported graph is fed without the cnn cache in this mode
        del ort_inputs["cnn_cache"]
    ort_outs = ort_session.run(None, ort_inputs)
    test(to_numpy(torch_outs), ort_outs, rtol=1e-03, atol=1e-05)
    logger.info("export to onnx streaming encoder succeed!")

    onnx_config = {
        "subsampling_rate": subsampling,
        "context": context,
        "decoding_chunk_size": decoding_chunk_size,
        "num_decoding_left_chunks": num_decoding_left_chunks,
        "beam_size": args.beam_size,
        "feat_size": feature_size,
        "decoding_window": decoding_window,
        "cnn_module_kernel_cache": cnn_module_kernel,
        "return_ctc_logprobs": args.return_ctc_logprobs,
    }
    return onnx_config
torch.randn(bz, beam_size, dtype=torch.float32) input_names = [ "encoder_out", "encoder_out_lens", "hyps_pad_sos_eos", "hyps_lens_sos", "r_hyps_pad_sos_eos", "ctc_score", ] output_names = ["best_index"] if decoder_fastertransformer: output_names.insert(0, "decoder_out") torch.onnx.export( decoder, ( encoder_out, encoder_out_lens, hyps_pad_sos_eos, hyps_lens_sos, r_hyps_pad_sos_eos, ctc_score, ), decoder_onnx_path, export_params=True, opset_version=13, do_constant_folding=True, input_names=input_names, output_names=output_names, dynamic_axes=None, # dynamic_axes={ # "encoder_out": { # 0: "B", # 1: "T" # }, # "encoder_out_lens": { # 0: "B" # }, # "hyps_pad_sos_eos": { # 0: "B", # 2: "T2" # }, # "hyps_lens_sos": { # 0: "B" # }, # "r_hyps_pad_sos_eos": { # 0: "B", # 2: "T2" # }, # "ctc_score": { # 0: "B" # }, # "best_index": { # 0: "B" # }, # }, verbose=False, dynamo=False, ) fold_static_pulsar2_subgraphs(decoder_onnx_path) simplify_pulsar2_onnx(decoder_onnx_path) rewrite_pulsar2_bool_not(decoder_onnx_path) rewrite_pulsar2_bool_and(decoder_onnx_path) with torch.no_grad(): o0 = decoder( encoder_out, encoder_out_lens, hyps_pad_sos_eos, hyps_lens_sos, r_hyps_pad_sos_eos, ctc_score, ) providers = ["CPUExecutionProvider"] ort_session = onnxruntime.InferenceSession(decoder_onnx_path, providers=providers) input_tensors = [ encoder_out, encoder_out_lens, hyps_pad_sos_eos, hyps_lens_sos, r_hyps_pad_sos_eos, ctc_score, ] ort_inputs = {} input_tensors = to_numpy(input_tensors) for idx, name in enumerate(input_names): ort_inputs[name] = input_tensors[idx] # if model.reverse weight == 0, # the r_hyps_pad will be removed # from the onnx decoder since it doen't play any role if model.reverse_weight == 0: del ort_inputs["r_hyps_pad_sos_eos"] ort_outs = ort_session.run(None, ort_inputs) # check decoder output if decoder_fastertransformer: test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05) else: test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05) logger.info("export to onnx decoder 
if __name__ == "__main__":
    # Script entry point: parse options, make sure the pretrained model is
    # available, build the model and export both encoders plus the rescoring
    # decoder into --output_onnx_dir.
    parser = argparse.ArgumentParser(description="export x86_gpu model")
    parser.add_argument(
        "--pretrained_model_dir",
        default=DEFAULT_PRETRAINED_MODEL_DIR,
        help=("pretrained model directory containing train.yaml, final.pt, "
              "and global_cmvn"),
    )
    parser.add_argument(
        "--pretrained_model_url",
        default=DEFAULT_PRETRAINED_MODEL_URL,
        help="pretrained model tar.gz URL used when pretrained_model_dir is missing",
    )
    parser.add_argument(
        "--reverse_weight",
        default=-1.0,
        type=float,
        required=False,
        help="reverse weight for bitransformer," +
        "default value is in config file",
    )
    parser.add_argument(
        "--ctc_weight",
        default=-1.0,
        type=float,
        required=False,
        help="ctc weight, default value is in config file",
    )
    parser.add_argument(
        "--beam_size",
        default=10,
        type=int,
        required=False,
        help="beam size would be ctc output size",
    )
    parser.add_argument(
        "--output_onnx_dir",
        default="onnx_model",
        help="output onnx encoder and decoder directory",
    )
    # arguments for streaming encoder
    # NOTE: there is no --streaming switch; the streaming and the offline
    # encoder are both exported on every run.
    parser.add_argument(
        "--decoding_chunk_size",
        default=16,
        type=int,
        required=False,
        help="the decoding chunk size, <=0 is not supported",
    )
    parser.add_argument(
        "--num_decoding_left_chunks",
        default=5,
        type=int,
        required=False,
        help="number of left chunks, <= 0 is not supported",
    )
    parser.add_argument(
        "--decoder_fastertransformer",
        action="store_true",
        help="return decoder_out and best_index for ft",
    )
    parser.add_argument(
        "--return_ctc_logprobs",
        action="store_true",
        help="return full ctc_log_probs for TLG streaming encoder",
    )

    args = parser.parse_args()

    # Validate up front (before any download or model construction) and
    # without `assert`, which is stripped when Python runs with -O.
    if args.decoding_chunk_size <= 0:
        parser.error("--decoding_chunk_size must be > 0")
    if args.num_decoding_left_chunks <= 0:
        parser.error("--num_decoding_left_chunks must be > 0")

    # Fills in args.config, args.checkpoint and args.cmvn_file.
    prepare_pretrained_model(args)

    torch.manual_seed(0)
    torch.set_printoptions(precision=10)

    # NOTE(security): train.yaml may come from a downloaded archive, and
    # FullLoader is less strict than yaml.safe_load. Left unchanged to avoid
    # rejecting configs that rely on FullLoader-only tags; consider
    # safe_load if the configs are known to be plain data.
    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    if os.path.exists(args.cmvn_file):
        if 'cmvn' not in configs:
            configs['cmvn'] = "global_cmvn"
            configs['cmvn_conf'] = {}
        else:
            assert configs['cmvn'] == "global_cmvn"
            assert configs['cmvn_conf'] is not None
        # Point the config at the CMVN file shipped with the model.
        configs['cmvn_conf']["cmvn_file"] = args.cmvn_file
        configs['cmvn_conf'].setdefault(
            "is_json_cmvn", configs.get("is_json_cmvn", True))
    elif configs.get('cmvn', None) == 'global_cmvn':
        raise FileNotFoundError(
            f"Expected global_cmvn in pretrained model dir: {args.cmvn_file}")

    # -1.0 is the "not given on the command line" sentinel for both weights.
    if (args.reverse_weight != -1.0
            and "reverse_weight" in configs["model_conf"]):
        configs["model_conf"]["reverse_weight"] = args.reverse_weight
        print("Update reverse weight to", args.reverse_weight)
    if args.ctc_weight != -1:
        print("Update ctc weight to ", args.ctc_weight)
        configs["model_conf"]["ctc_weight"] = args.ctc_weight
    configs["encoder_conf"]["use_dynamic_chunk"] = False

    model, configs = init_model(args, configs)
    model.eval()

    # makedirs also creates missing parent directories and does not fail
    # when the directory already exists.
    os.makedirs(args.output_onnx_dir, exist_ok=True)

    # Streaming encoder.
    encoder_onnx_path = os.path.join(args.output_onnx_dir,
                                     "encoder_online.onnx")
    onnx_config = export_online_encoder(model, configs, args, logger,
                                        encoder_onnx_path)

    # Offline encoder.
    # NOTE(review): this rebinds onnx_config, so config.yaml below contains
    # only the offline encoder settings and the streaming settings returned
    # above are discarded - confirm this is intended.
    encoder_onnx_path = os.path.join(args.output_onnx_dir,
                                     "encoder_offline.onnx")
    onnx_config = export_offline_encoder(model, configs, args, logger,
                                         encoder_onnx_path)

    decoder_onnx_path = os.path.join(args.output_onnx_dir, "decoder.onnx")
    export_rescoring_decoder(
        model,
        configs,
        args,
        logger,
        decoder_onnx_path,
        args.decoder_fastertransformer,
    )

    config_dir = os.path.join(args.output_onnx_dir, "config.yaml")
    with open(config_dir, "w") as out:
        yaml.dump(onnx_config, out)