| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import print_function |
|
|
| import argparse |
| import copy |
| import logging |
| import os |
| import sys |
| import tarfile |
| import urllib.request |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import yaml |
| from wenet.transformer.ctc import CTC |
| from wenet.transformer.decoder import TransformerDecoder |
| from wenet.transformer.encoder import BaseEncoder |
| from wenet.utils.init_model import init_model |
| from wenet.utils.mask import make_pad_mask |
| from typing import List, Tuple |
|
|
| try: |
| import onnx |
| import onnxruntime |
| from onnx import helper, numpy_helper |
| from onnxsim import simplify |
| except ImportError: |
| print("Please install onnxruntime!") |
| sys.exit(1) |
|
|
# Module-level logger. It is keyed by this file's path (``__file__``) rather
# than the module name, and always emits INFO and above.
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


# Default location of the pretrained model archive and the directory it is
# expected to unpack into.
# NOTE(review): presumably these are the defaults for the
# ``pretrained_model_url`` / ``pretrained_model_dir`` arguments read by
# ``prepare_pretrained_model``; the argument parser is outside this view.
DEFAULT_PRETRAINED_MODEL_URL = (
    "https://huggingface.co/openspeech/wenet-models/resolve/main/"
    "aishell_u2pp_conformer_exp.tar.gz")
DEFAULT_PRETRAINED_MODEL_DIR = "pretrained/aishell_u2pp_conformer_exp"
|
|
|
|
def safe_extract_tar(tar, output_dir):
    """Extract ``tar`` into ``output_dir``, refusing entries that escape it.

    Every member is validated before anything is written to disk, so a
    rejected archive leaves ``output_dir`` untouched.

    Args:
        tar: An open ``tarfile.TarFile``.
        output_dir: Destination directory; resolved with ``os.path.abspath``.

    Raises:
        RuntimeError: If a member path, or the target of a symbolic/hard
            link member, resolves outside ``output_dir``.
    """
    output_dir = os.path.abspath(output_dir)

    def _is_inside(path):
        # commonpath compares whole path components: "/out-evil" is not
        # mistaken for a child of "/out", the archive root entry ("./",
        # which resolves to output_dir itself) is accepted, and an
        # output_dir of "/" works.
        try:
            return os.path.commonpath([output_dir, path]) == output_dir
        except ValueError:
            # Raised on Windows when the two paths are on different drives.
            return False

    for member in tar.getmembers():
        member_path = os.path.abspath(os.path.join(output_dir, member.name))
        if not _is_inside(member_path):
            raise RuntimeError(f"Unsafe tar member path: {member.name}")

        # A link that points outside output_dir lets a later member be
        # written through it, so link targets are validated as well.
        if member.issym():
            # Symlink targets are relative to the link's own directory.
            link_base = os.path.dirname(member_path)
        elif member.islnk():
            # Hard-link targets are relative to the archive root.
            link_base = output_dir
        else:
            continue
        link_target = os.path.abspath(
            os.path.join(link_base, member.linkname))
        if not _is_inside(link_target):
            raise RuntimeError(
                f"Unsafe tar link target: {member.name} -> {member.linkname}")

    if hasattr(tarfile, "data_filter"):
        # Interpreters that ship extraction filters get the stricter "data"
        # policy on top of the checks above (and no DeprecationWarning).
        tar.extractall(output_dir, filter="data")
    else:
        tar.extractall(output_dir)
|
|
|
|
def download_file(url, output_path):
    """Download ``url`` to ``output_path``, creating parent directories.

    The payload is first written to ``output_path + ".part"`` and moved into
    place only once the transfer has finished, so an interrupted download
    never leaves a truncated file at ``output_path`` for a later run to
    mistake for a complete archive.

    Args:
        url: Source URL, passed to ``urllib.request.urlretrieve``.
        output_path: Destination file path. A bare filename (no directory
            component) is written to the current working directory.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises; the partial file is
        removed before the exception propagates.
    """
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        # os.makedirs("") raises, so skip it for bare filenames.
        os.makedirs(parent_dir, exist_ok=True)
    print(f"Downloading pretrained model from {url}")
    print(f"Saving to {output_path}")

    partial_path = output_path + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
    except BaseException:
        # Also covers KeyboardInterrupt: never keep a half-written file.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    os.replace(partial_path, output_path)
|
|
|
|
def prepare_pretrained_model(args):
    """Make sure the pretrained model is on disk and point ``args`` at it.

    If ``args.pretrained_model_dir`` does not exist, the archive
    ``<model_dir>.tar.gz`` next to it is downloaded from
    ``args.pretrained_model_url`` (unless already present) and unpacked into
    the parent directory. ``args.config``, ``args.checkpoint`` and
    ``args.cmvn_file`` are then set to ``train.yaml``, ``final.pt`` and
    ``global_cmvn`` inside the model directory.

    Raises:
        FileNotFoundError: If ``train.yaml`` or ``final.pt`` is missing.
            ``global_cmvn`` is optional and only reported when present.
    """
    model_dir = args.pretrained_model_dir
    trimmed_dir = model_dir.rstrip(os.sep)
    archive_dir = os.path.dirname(trimmed_dir) or "."
    archive_name = os.path.basename(trimmed_dir) + ".tar.gz"
    archive_path = os.path.join(archive_dir, archive_name)

    if not os.path.exists(model_dir):
        if not os.path.exists(archive_path):
            download_file(args.pretrained_model_url, archive_path)
        print(f"Extracting pretrained model to {archive_dir}")
        with tarfile.open(archive_path, "r:gz") as tar:
            safe_extract_tar(tar, archive_dir)

    args.config = os.path.join(model_dir, "train.yaml")
    args.checkpoint = os.path.join(model_dir, "final.pt")
    args.cmvn_file = os.path.join(model_dir, "global_cmvn")

    missing = []
    for required_path in (args.config, args.checkpoint):
        if not os.path.exists(required_path):
            missing.append(required_path)
    if missing:
        raise FileNotFoundError(
            "Missing pretrained model files: " + ", ".join(missing))

    print(f"Using config: {args.config}")
    print(f"Using checkpoint: {args.checkpoint}")
    if os.path.exists(args.cmvn_file):
        print(f"Using CMVN: {args.cmvn_file}")
|
|
|
|
| def _constant_node_value(node): |
| if node is None or node.op_type != "Constant": |
| return None |
| for attr in node.attribute: |
| if attr.name == "value": |
| return numpy_helper.to_array(attr.t) |
| return None |
|
|
|
|
def _attribute_value(attr):
    """Decode an ONNX attribute via ``onnx.helper.get_attribute_value``."""
    decoded = helper.get_attribute_value(attr)
    return decoded
|
|
|
|
| def _get_attr(node, name, default=None): |
| for attr in node.attribute: |
| if attr.name == name: |
| return _attribute_value(attr) |
| return default |
|
|
|
|
def _cast_to_onnx_dtype(value, to_dtype):
    """Cast a NumPy array to the dtype named by an ONNX ``DataType`` enum.

    Args:
        value: Array to convert (anything exposing ``astype``).
        to_dtype: Integer ``onnx.TensorProto.DataType`` value.

    Returns:
        The converted array, or ``None`` when the ONNX type has no entry in
        the table below.
    """
    numpy_types = {
        "float": np.float32,
        "double": np.float64,
        "float16": np.float16,
        "int64": np.int64,
        "int32": np.int32,
        "int16": np.int16,
        "int8": np.int8,
        "uint64": np.uint64,
        "uint32": np.uint32,
        "uint16": np.uint16,
        "uint8": np.uint8,
        "bool": np.bool_,
    }
    onnx_type_name = onnx.TensorProto.DataType.Name(to_dtype).lower()
    target_dtype = numpy_types.get(onnx_type_name)
    if target_dtype is None:
        return None
    return value.astype(target_dtype)
|
|
|
|
| def _shape_from_value_info(value_info): |
| if not value_info.type.HasField("tensor_type"): |
| return None |
| if not value_info.type.tensor_type.shape.dim: |
| return None |
| shape = [] |
| for dim in value_info.type.tensor_type.shape.dim: |
| if dim.HasField("dim_value") and dim.dim_value > 0: |
| shape.append(dim.dim_value) |
| else: |
| return None |
| return tuple(shape) |
|
|
|
|
def _collect_static_shapes(model):
    """Map tensor names to their fully static shapes.

    Shape inference is run on ``model`` first. Graph inputs, intermediate
    values and outputs contribute an entry only when every dimension is
    static; initializers always contribute their stored dims and override
    any earlier entry with the same name.
    """
    inferred_graph = onnx.shape_inference.infer_shapes(model).graph
    shapes = {}
    value_groups = (
        inferred_graph.input,
        inferred_graph.value_info,
        inferred_graph.output,
    )
    for group in value_groups:
        for value_info in group:
            static_shape = _shape_from_value_info(value_info)
            if static_shape is not None:
                shapes[value_info.name] = static_shape
    for initializer in inferred_graph.initializer:
        shapes[initializer.name] = tuple(initializer.dims)
    return shapes
|
|
|
|
| def _eval_static_node(node, inputs, static_shapes): |
| if node.op_type == "Constant": |
| return _constant_node_value(node) |
| if node.op_type != "Shape" and any(value is None for value in inputs): |
| return None |
|
|
| try: |
| if node.op_type == "Add": |
| return np.add(inputs[0], inputs[1]) |
| if node.op_type == "Sub": |
| return np.subtract(inputs[0], inputs[1]) |
| if node.op_type == "Mul": |
| return np.multiply(inputs[0], inputs[1]) |
| if node.op_type == "Div": |
| return np.divide(inputs[0], inputs[1]) |
| if node.op_type == "Equal": |
| return np.equal(inputs[0], inputs[1]) |
| if node.op_type == "Greater": |
| return np.greater(inputs[0], inputs[1]) |
| if node.op_type == "GreaterOrEqual": |
| return np.greater_equal(inputs[0], inputs[1]) |
| if node.op_type == "Less": |
| return np.less(inputs[0], inputs[1]) |
| if node.op_type == "LessOrEqual": |
| return np.less_equal(inputs[0], inputs[1]) |
| if node.op_type == "Where": |
| return np.where(inputs[0], inputs[1], inputs[2]) |
| if node.op_type == "Concat": |
| axis = _get_attr(node, "axis", 0) |
| return np.concatenate(inputs, axis=axis) |
| if node.op_type == "Unsqueeze": |
| axes = _get_attr(node, "axes") |
| if axes is None and len(inputs) > 1: |
| axes = inputs[1] |
| axes = tuple(int(axis) for axis in np.asarray(axes).reshape(-1)) |
| return np.expand_dims(inputs[0], axis=axes) |
| if node.op_type == "Squeeze": |
| axes = _get_attr(node, "axes") |
| if axes is None and len(inputs) > 1: |
| axes = inputs[1] |
| if axes is None: |
| return np.squeeze(inputs[0]) |
| axes = tuple(int(axis) for axis in np.asarray(axes).reshape(-1)) |
| return np.squeeze(inputs[0], axis=axes) |
| if node.op_type == "Cast": |
| return _cast_to_onnx_dtype(inputs[0], _get_attr(node, "to")) |
| if node.op_type == "Reshape": |
| return np.reshape(inputs[0], tuple(int(i) for i in inputs[1])) |
| if node.op_type == "Shape": |
| if inputs[0] is not None: |
| shape = inputs[0].shape |
| else: |
| shape = static_shapes.get(node.input[0]) |
| if shape is None: |
| return None |
| return np.asarray(shape, dtype=np.int64) |
| if node.op_type == "Slice": |
| data = inputs[0] |
| starts = np.asarray(inputs[1]).reshape(-1) |
| ends = np.asarray(inputs[2]).reshape(-1) |
| axes = (np.asarray(inputs[3]).reshape(-1) |
| if len(inputs) > 3 and inputs[3] is not None else |
| np.arange(len(starts))) |
| steps = (np.asarray(inputs[4]).reshape(-1) |
| if len(inputs) > 4 and inputs[4] is not None else |
| np.ones(len(starts), dtype=np.int64)) |
| slices = [slice(None)] * data.ndim |
| for start, end, axis, step in zip(starts, ends, axes, steps): |
| axis = int(axis) |
| start = int(start) |
| end = int(end) |
| step = int(step) |
| if end >= np.iinfo(np.int32).max: |
| end = None |
| if end <= np.iinfo(np.int32).min: |
| end = None |
| slices[axis] = slice(start, end, step) |
| return data[tuple(slices)] |
| if node.op_type == "Gather": |
| axis = _get_attr(node, "axis", 0) |
| return np.take(inputs[0], inputs[1], axis=axis) |
| except Exception: |
| return None |
| return None |
|
|
|
|
def _constant_node(output_name, value, name):
    """Build a ``Constant`` node that yields ``value`` on ``output_name``.

    The embedded tensor is named ``<output_name>_value`` and the node itself
    is named ``name``.
    """
    payload = numpy_helper.from_array(np.asarray(value),
                                      name=output_name + "_value")
    node_kwargs = {
        "inputs": [],
        "outputs": [output_name],
        "name": name,
        "value": payload,
    }
    return helper.make_node("Constant", **node_kwargs)
|
|
|
|
| def _node_attributes(node): |
| return {attr.name: helper.get_attribute_value(attr) for attr in node.attribute} |
|
|
|
|
| def _copy_node(node, inputs=None, outputs=None, name=None): |
| copied = copy.deepcopy(node) |
| if inputs is not None: |
| del copied.input[:] |
| copied.input.extend(inputs) |
| if outputs is not None: |
| del copied.output[:] |
| copied.output.extend(outputs) |
| if name is not None: |
| copied.name = name |
| return copied |
|
|
|
|
| def _producer_map(model): |
| return {output: node for node in model.graph.node for output in node.output} |
|
|
|
|
| def _unsqueeze_greater_equal_pattern(producer, value_name): |
| unsqueeze = producer.get(value_name) |
| if unsqueeze is None or unsqueeze.op_type != "Unsqueeze": |
| return None, None |
| compare = producer.get(unsqueeze.input[0]) |
| if compare is None or compare.op_type != "GreaterOrEqual": |
| return None, None |
| return unsqueeze, compare |
|
|
|
|
def rewrite_pulsar2_bool_not(onnx_path):
    """Remove simple Not nodes that Pulsar2 quantization can cast to float.

    The encoder mask contains Not(Unsqueeze(GreaterOrEqual(...))) and another
    Not over a sliced version of that mask. Pulsar2 can quantize the Not input
    to FP32 and then fail because bitwise Not only accepts bool/integer
    tensors. Three patterns are rewritten so that the Not disappears:

    1. ``Not(GreaterOrEqual(a, b))``            -> ``Less(a, b)``
    2. ``Not(Unsqueeze(GreaterOrEqual(a, b)))`` -> ``Unsqueeze(Less(a, b))``
    3. ``Not(Slice(Slice(Not(Unsqueeze(GreaterOrEqual)))))``
       -> ``Slice(Slice(Unsqueeze(GreaterOrEqual)))`` (double negation)

    Any other Not node is kept unchanged. The model file is validated and
    overwritten in place only when at least one node was rewritten.

    Args:
        onnx_path: Path of the ONNX model to load and, if changed, overwrite.
    """
    model = onnx.load(onnx_path)
    producer = _producer_map(model)
    rewritten = 0
    new_nodes = []

    for node in model.graph.node:
        if node.op_type != "Not":
            new_nodes.append(node)
            continue

        # Unnamed Not nodes would otherwise all generate the same node names
        # ("_less", "_unsqueeze", ...); fall back to the unique output name.
        base_name = node.name or node.output[0]
        source = producer.get(node.input[0])

        # Pattern 1. NOTE: Less is the negation of GreaterOrEqual for
        # ordinary values; the original rewrite makes the same assumption.
        if source is not None and source.op_type == "GreaterOrEqual":
            less = helper.make_node("Less",
                                    inputs=list(source.input),
                                    outputs=list(node.output),
                                    name=base_name + "_less",
                                    **_node_attributes(source))
            new_nodes.append(less)
            rewritten += 1
            continue

        # Pattern 2.
        unsqueeze, compare = _unsqueeze_greater_equal_pattern(
            producer, node.input[0])
        if unsqueeze is not None:
            less_output = node.output[0] + "_less"
            less = helper.make_node("Less",
                                    inputs=list(compare.input),
                                    outputs=[less_output],
                                    name=base_name + "_less",
                                    **_node_attributes(compare))
            rewritten_unsqueeze = _copy_node(
                unsqueeze,
                inputs=[less_output] + list(unsqueeze.input[1:]),
                outputs=list(node.output),
                name=base_name + "_unsqueeze")
            new_nodes.extend([less, rewritten_unsqueeze])
            rewritten += 1
            continue

        # Pattern 3. The operator type is checked before an input is read:
        # the producer may be a node without inputs (e.g. Constant), where
        # indexing input[0] would raise IndexError.
        slice_1 = source
        slice_0 = None
        if slice_1 is not None and slice_1.op_type == "Slice":
            slice_0 = producer.get(slice_1.input[0])
        inner_not = None
        if slice_0 is not None and slice_0.op_type == "Slice":
            inner_not = producer.get(slice_0.input[0])
        if inner_not is not None and inner_not.op_type == "Not":
            unsqueeze, _ = _unsqueeze_greater_equal_pattern(
                producer, inner_not.input[0])
            if unsqueeze is not None:
                # Slice the un-negated mask directly; the two Nots cancel.
                slice_0_output = node.output[0] + "_slice0"
                rewritten_slice_0 = _copy_node(
                    slice_0,
                    inputs=[unsqueeze.output[0]] + list(slice_0.input[1:]),
                    outputs=[slice_0_output],
                    name=base_name + "_slice0")
                rewritten_slice_1 = _copy_node(
                    slice_1,
                    inputs=[slice_0_output] + list(slice_1.input[1:]),
                    outputs=list(node.output),
                    name=base_name + "_slice1")
                new_nodes.extend([rewritten_slice_0, rewritten_slice_1])
                rewritten += 1
                continue

        new_nodes.append(node)

    if rewritten:
        del model.graph.node[:]
        model.graph.node.extend(new_nodes)
        onnx.checker.check_model(model)
        onnx.save(model, onnx_path)
        print(f"Rewrote {rewritten} bool Not node(s) in {onnx_path}")
|
|
|
|
def rewrite_pulsar2_bool_and(onnx_path):
    """Replace boolean And with arithmetic comparison for Pulsar2 quantization.

    Each two-input, single-output ``And`` node is expanded, in this order,
    into: Cast(left -> int32), Cast(right -> int32), Add, a Constant
    threshold of 1, and Greater(sum, 1), which writes to the original And
    output. The model is validated and saved in place only when at least one
    node was replaced.
    """
    model = onnx.load(onnx_path)
    new_nodes = []
    rewritten = 0

    for node in model.graph.node:
        is_binary_and = (node.op_type == "And" and len(node.input) == 2
                         and len(node.output) == 1)
        if not is_binary_and:
            new_nodes.append(node)
            continue

        result_name = node.output[0]
        cast_outputs = []
        for side, source_name in zip(("left", "right"), node.input):
            cast_output = result_name + "_" + side + "_i32"
            new_nodes.append(
                helper.make_node("Cast",
                                 inputs=[source_name],
                                 outputs=[cast_output],
                                 name=node.name + "_cast_" + side,
                                 to=onnx.TensorProto.INT32))
            cast_outputs.append(cast_output)

        sum_name = result_name + "_sum"
        threshold_name = result_name + "_threshold"
        new_nodes.append(
            helper.make_node("Add",
                             inputs=cast_outputs,
                             outputs=[sum_name],
                             name=node.name + "_add"))
        new_nodes.append(
            _constant_node(threshold_name, np.asarray(1, dtype=np.int32),
                           node.name + "_threshold"))
        # Both operands are 0/1 after the cast, so sum > 1 iff both are true.
        new_nodes.append(
            helper.make_node("Greater",
                             inputs=[sum_name, threshold_name],
                             outputs=list(node.output),
                             name=node.name + "_greater"))
        rewritten += 1

    if not rewritten:
        return
    del model.graph.node[:]
    model.graph.node.extend(new_nodes)
    onnx.checker.check_model(model)
    onnx.save(model, onnx_path)
    print(f"Rewrote {rewritten} bool And node(s) in {onnx_path}")
|
|
|
|
def simplify_pulsar2_onnx(onnx_path):
    """Run onnxsim on the model at ``onnx_path`` and overwrite it in place.

    Raises:
        RuntimeError: If onnxsim reports that the simplified model failed
            its validation.
    """
    simplified, validated = simplify(onnx.load(onnx_path))
    if not validated:
        raise RuntimeError(f"onnxsim failed to validate {onnx_path}")
    onnx.checker.check_model(simplified)
    onnx.save(simplified, onnx_path)
    print(f"Simplified {onnx_path} for Pulsar2")
|
|
|
|
def _constant_of_shape_value(node, shape_value):
    """Evaluate a ``ConstantOfShape`` node whose shape input is constant.

    Returns ``None`` when the shape is not a known constant. Without a
    ``value`` attribute the fill value is float32 zero.
    """
    if shape_value is None:
        return None
    fill_value = np.array(0, dtype=np.float32)
    for attr in node.attribute:
        if attr.name == "value":
            fill_value = numpy_helper.to_array(attr.t)
            break
    shape = tuple(int(dim) for dim in np.asarray(shape_value).reshape(-1))
    return np.full(shape, fill_value.reshape(-1)[0], dtype=fill_value.dtype)


def fold_static_pulsar2_subgraphs(onnx_path):
    """Fold static ONNX patterns that Pulsar2 5.0 cannot infer reliably.

    Pulsar2 5.0 can fail shape inference on ConstantOfShape when its input is
    a constant tensor value instead of an initializer. The legacy exporter
    emits this pattern for masks/padding in the encoder graphs. It can also
    fail when an Expand shape is produced by a constant-only subgraph such as
    Mul/Equal/Where. Fold those static pieces before handing the model to
    Pulsar2.

    Nodes are visited in graph order. A single-output node whose value can be
    computed from initializers and previously folded values is replaced by a
    ``Constant`` node. Existing ``Constant`` nodes only feed their value to
    later folds: they are kept untouched and are not counted, so the reported
    number reflects real folds and an unchanged model is not rewritten.

    Args:
        onnx_path: Path of the ONNX model to load and, if changed, overwrite.
    """
    model = onnx.load(onnx_path)
    static_shapes = _collect_static_shapes(model)
    constants = {
        initializer.name: numpy_helper.to_array(initializer)
        for initializer in model.graph.initializer
    }
    folded = 0
    new_nodes = []
    for node in model.graph.node:
        if node.op_type == "Constant":
            value = _constant_node_value(node)
            if value is not None and len(node.output) == 1:
                constants[node.output[0]] = value
            new_nodes.append(node)
            continue

        inputs = [constants.get(name) for name in node.input]
        if node.op_type == "ConstantOfShape" and node.input:
            value = _constant_of_shape_value(node, inputs[0])
        else:
            value = _eval_static_node(node, inputs, static_shapes)

        if value is None or len(node.output) != 1:
            new_nodes.append(node)
            continue

        constants[node.output[0]] = value
        new_nodes.append(_constant_node(node.output[0], value, node.name))
        folded += 1

    if folded:
        del model.graph.node[:]
        model.graph.node.extend(new_nodes)
        onnx.checker.check_model(model)
        onnx.save(model, onnx_path)
        print(f"Folded {folded} static node(s) in {onnx_path}")
|
|
|
|
class Encoder(torch.nn.Module):
    """Export wrapper that runs the encoder and the CTC projection together."""

    def __init__(self, encoder: BaseEncoder, ctc: CTC, beam_size: int = 10):
        super().__init__()
        self.encoder = encoder
        self.ctc = ctc
        self.beam_size = beam_size

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Encode a batch and pick the top ``beam_size`` CTC scores per frame.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        Returns:
            encoder_out: B x T x F
            encoder_out_lens: B, number of set mask entries per utterance,
                as int32
            ctc_log_probs: B x T x V, the output of ``ctc.linear`` (no
                softmax is applied in this module)
            beam_log_probs: B x T x beam_size
            beam_log_probs_idx: B x T x beam_size
        """
        # The two -1 arguments are forwarded positionally to the encoder.
        hidden, mask = self.encoder(speech, speech_lengths, -1, -1)
        valid_frames = mask.squeeze(1).sum(1).int()

        projected = self.ctc.linear(hidden)
        top_scores, top_indices = torch.topk(projected, self.beam_size, dim=2)
        return hidden, valid_frames, projected, top_scores, top_indices
|
|
|
|
class StreamingEncoder(torch.nn.Module):
    """Export wrapper that runs one streaming chunk with explicit caches.

    The wrapper reuses the sub-modules of ``model.encoder`` (CMVN, embedding,
    encoder layers, optional final norm) and ``model.ctc``, and threads the
    attention / convolution caches through ``forward`` as plain tensors.
    """

    def __init__(
        self,
        model,
        required_cache_size,
        beam_size,
        transformer=False,
        return_ctc_logprobs=False,
    ):
        """Store references to the sub-modules of ``model`` used in forward.

        Args:
            model: Object exposing ``ctc`` and ``encoder`` (with ``embed``,
                ``global_cmvn``, ``encoders``, ``normalize_before`` and
                ``after_norm``).
            required_cache_size: Number of trailing time steps kept in the
                returned attention cache and cache mask.
            beam_size: ``k`` used for the top-k over the CTC projection.
            transformer: When true, no convolution cache is collected.
            return_ctc_logprobs: Selects which tuple ``forward`` returns.
        """
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder
        self.transformer = transformer
        self.return_ctc_logprobs = return_ctc_logprobs

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
                cache_mask):
        """Streaming Encoder

        Shapes quoted below come from the original notes of this method and
        are not checked here -- TODO confirm against the export caller.

        Args:
            chunk_xs (torch.Tensor): chunk input, noted as
                (b, time, mel-dim), where
                time == (chunk_size - 1) * subsample_rate
                        + subsample.right_context + 1
            chunk_lens (torch.Tensor): per-request lengths; used to build
                the padding mask and, floor-divided by the subsampling
                rate, returned as the output lengths.
            offset (torch.Tensor): offset, noted as (b, 1); dim 1 is
                squeezed on entry and restored on the returned offset
                (1 is retained for triton deployment).
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, noted as
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
                The last dim is split into two equal halves per layer.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in
                conformer, noted as (b, elayers, b, hidden-dim, cache_t2),
                where `cache_t2 == cnn.lorder - 1`.
            cache_mask (torch.Tensor): mask marking the effective cache of
                each request in the batch. It is concatenated with the
                chunk mask along dim 2.
                NOTE(review): the original notes give the shape as
                (b, required_cache_size); the dim-2 concat suggests a
                singleton middle dim -- confirm.

        Returns:
            tuple: when ``return_ctc_logprobs`` is true,
                (log_ctc_probs, chunk_out, chunk_out_lens, r_offset,
                r_att_cache, r_cnn_cache, r_cache_mask); otherwise
                (log_probs, log_probs_idx, chunk_out, chunk_out_lens,
                r_offset, r_att_cache, r_cnn_cache, r_cache_mask).

                log_ctc_probs is the output of ``ctc.linear`` (no softmax
                is applied in this method); log_probs / log_probs_idx are
                its top ``beam_size`` values / indices along dim 2.
                r_att_cache keeps the last ``required_cache_size`` steps of
                each layer's new cache, stacked on dim 1. r_cnn_cache is
                stacked the same way, and stays an empty list when
                ``transformer`` is true.
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        # make_pad_mask output is inverted, then cast to the input dtype.
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)

        chunk_mask = chunk_mask.to(chunk_xs.dtype)

        # Swap dims 0 and 1 so that indexing with the layer index below
        # selects one layer's cache.
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        xs = self.global_cmvn(chunk_xs)

        # pos_emb returned here is replaced a few lines below.
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        cache_size = att_cache.size(3)
        masks = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size

        # Positional encoding requested for the cached steps plus the chunk.
        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        # Negative start: keep only the trailing required_cache_size steps.
        next_cache_start = -self.required_cache_size
        r_cache_mask = masks[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoder.encoders):
            i_kv_cache = att_cache[i]
            # Split the last dim in two halves and pass them as a pair.
            size = att_cache.size(-1) // 2
            kv_cache = (i_kv_cache[:, :, :, :size], i_kv_cache[:, :, :, size:])
            xs, _, new_kv_cache, new_cnn_cache = layer(
                xs,
                masks,
                pos_emb,
                att_cache=kv_cache,
                cnn_cache=cnn_cache[i],
            )

            # Re-join the returned pair on the last dim before trimming.
            new_att_cache = torch.cat(new_kv_cache, dim=-1)
            r_att_cache.append(
                new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
            if not self.transformer:
                r_cnn_cache.append(new_cnn_cache.unsqueeze(1))
        if self.encoder.normalize_before:
            chunk_out = self.encoder.after_norm(xs)
        else:
            chunk_out = xs

        # Stack the per-layer caches on dim 1.
        r_att_cache = torch.cat(r_att_cache, dim=1)
        if not self.transformer:
            r_cnn_cache = torch.cat(r_cnn_cache, dim=1)

        log_ctc_probs = self.ctc.linear(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs,
                                              self.beam_size,
                                              dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        # Advance the offset by the number of output steps of this chunk.
        r_offset = offset + chunk_out.shape[1]

        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)
        if self.return_ctc_logprobs:
            return (
                log_ctc_probs,
                chunk_out,
                chunk_out_lens,
                r_offset,
                r_att_cache,
                r_cnn_cache,
                r_cache_mask,
            )
        else:
            return (
                log_probs,
                log_probs_idx,
                chunk_out,
                chunk_out_lens,
                r_offset,
                r_att_cache,
                r_cnn_cache,
                r_cache_mask,
            )
|
|
|
|
class StreamingSqueezeformerEncoder(torch.nn.Module):
    """Export wrapper for one streaming chunk of an encoder with time reduction.

    In addition to the cache handling, ``forward`` applies the encoder's
    ``time_reduction_layer`` at the layer indices in ``reduce_idx`` and
    restores the saved activations at the indices in ``recover_idx``.
    """

    def __init__(self, model, required_cache_size, beam_size):
        """Store sub-module references and derive the time-reduction mode.

        ``time_reduce`` is ``None`` without ``reduce_idx``, ``"normal"`` when
        only ``reduce_idx`` is set, and ``"recover"`` when ``recover_idx`` is
        set as well (both lists must then have the same length).
        """
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder
        self.reduce_idx = model.encoder.reduce_idx
        self.recover_idx = model.encoder.recover_idx
        if self.reduce_idx is None:
            self.time_reduce = None
        else:
            if self.recover_idx is None:
                self.time_reduce = "normal"
            else:
                self.time_reduce = "recover"
                assert len(self.reduce_idx) == len(self.recover_idx)

    def calculate_downsampling_factor(self, i: int) -> int:
        """Return the time downsampling factor in effect at layer ``i``.

        The factor is ``2 ** (reductions - recoveries)``, counting the
        entries of ``reduce_idx`` / ``recover_idx`` up to the last one that
        is ``<= i``. Without ``reduce_idx`` the factor is 1.
        """
        if self.reduce_idx is None:
            return 1
        else:
            reduce_exp, recover_exp = 0, 0
            for exp, rd_idx in enumerate(self.reduce_idx):
                if i >= rd_idx:
                    reduce_exp = exp + 1
            if self.recover_idx is not None:
                for exp, rc_idx in enumerate(self.recover_idx):
                    if i >= rc_idx:
                        recover_exp = exp + 1
            return int(2**(reduce_exp - recover_exp))

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
                cache_mask):
        """Streaming Encoder

        Shapes quoted below come from the original notes of this method and
        are not checked here -- TODO confirm against the export caller.

        Args:
            chunk_xs (torch.Tensor): chunk input, noted as
                (b, time, mel-dim), where
                time == (chunk_size - 1) * subsample_rate
                        + subsample.right_context + 1
            chunk_lens (torch.Tensor): per-request lengths; used to build
                the padding mask and, floor-divided by the subsampling
                rate, returned as the output lengths.
            offset (torch.Tensor): offset, noted as (b, 1); dim 1 is
                squeezed on entry and restored on the returned offset
                (1 is retained for triton deployment).
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, noted as
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in
                conformer, noted as (b, elayers, b, hidden-dim, cache_t2),
                where `cache_t2 == cnn.lorder - 1`.
            cache_mask (torch.Tensor): mask marking the effective cache of
                each request in the batch. It is concatenated with the
                chunk mask along dim 2.
                NOTE(review): the original notes give the shape as
                (b, required_cache_size); the dim-2 concat suggests a
                singleton middle dim -- confirm.

        Returns:
            tuple: (log_probs, log_probs_idx, chunk_out, chunk_out_lens,
                r_offset, r_att_cache, r_cnn_cache, r_cache_mask).

                log_probs / log_probs_idx are the top ``beam_size`` values /
                indices (dim 2) of the ``ctc.linear`` output; no softmax is
                applied in this method. chunk_out is the output of the last
                encoder layer. r_att_cache / r_cnn_cache hold the per-layer
                caches stacked on dim 1.
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        # make_pad_mask output is inverted, then cast to the input dtype.
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)

        chunk_mask = chunk_mask.to(chunk_xs.dtype)

        # Swap dims 0 and 1 so that indexing with the layer index below
        # selects one layer's cache.
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        xs = self.global_cmvn(chunk_xs)

        # pos_emb returned here is replaced a few lines below.
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        elayers, cache_size = att_cache.size(0), att_cache.size(3)
        att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size

        # Positional encoding requested for the cached steps plus the chunk.
        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        # Negative start: keep only the trailing required_cache_size steps.
        next_cache_start = -self.required_cache_size
        r_cache_mask = att_mask[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        # All-true padding mask covering the embedded chunk length.
        mask_pad = torch.ones(1,
                              xs.size(1),
                              device=xs.device,
                              dtype=torch.bool)
        mask_pad = mask_pad.unsqueeze(1)
        max_att_len: int = 0
        # Activations saved before each time reduction, restored on recovery.
        recover_activations: List[Tuple[torch.Tensor, torch.Tensor,
                                        torch.Tensor, torch.Tensor]] = []
        # From here on ``index`` is reused as the position in
        # recover_activations (the tensor computed above is no longer needed).
        index = 0
        xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
        xs = self.encoder.preln(xs)
        for i, layer in enumerate(self.encoder.encoders):
            if self.reduce_idx is not None:
                if self.time_reduce is not None and i in self.reduce_idx:
                    recover_activations.append(
                        (xs, att_mask, pos_emb, mask_pad))
                    (
                        xs,
                        xs_lens,
                        att_mask,
                        mask_pad,
                    ) = self.encoder.time_reduction_layer(
                        xs, xs_lens, att_mask, mask_pad)
                    # Keep every second positional step to follow the
                    # reduced time axis.
                    pos_emb = pos_emb[:, ::2, :]
                    if self.encoder.pos_enc_layer_type == "rel_pos_repaired":
                        pos_emb = pos_emb[:, :xs.size(1) * 2 - 1, :]
                    index += 1

            if self.recover_idx is not None:
                if self.time_reduce == "recover" and i in self.recover_idx:
                    index -= 1
                    (
                        recover_tensor,
                        recover_att_mask,
                        recover_pos_emb,
                        recover_mask_pad,
                    ) = recover_activations[index]

                    # Repeat every time step twice, then trim to the saved
                    # length before adding the saved activations back.
                    xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
                    xs = self.encoder.time_recover_layer(xs)
                    recoverd_t = recover_tensor.size(1)
                    xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
                    att_mask = recover_att_mask
                    pos_emb = recover_pos_emb
                    mask_pad = recover_mask_pad

            factor = self.calculate_downsampling_factor(i)

            # The layer's attention cache is strided by ``factor`` on dim 2
            # and cut to the positional length minus the current length.
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=att_cache[i][:, :, ::factor, :]
                [:, :, :pos_emb.size(1) - xs.size(1), :]
                if elayers > 0 else att_cache[:, :, ::factor, :],
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
            )
            cached_att = new_att_cache[:, :, next_cache_start // factor:, :]
            cached_cnn = new_cnn_cache.unsqueeze(1)
            # Repeat each cached step ``factor`` times along dim 2.
            cached_att = (cached_att.unsqueeze(3).repeat(1, 1, 1, factor,
                                                         1).flatten(2, 3))
            if i == 0:
                # The first layer's cache length caps all later layers.
                max_att_len = cached_att.size(2)
            r_att_cache.append(cached_att[:, :, :max_att_len, :].unsqueeze(1))
            r_cnn_cache.append(cached_cnn)

        chunk_out = xs
        # Stack the per-layer caches on dim 1.
        r_att_cache = torch.cat(r_att_cache, dim=1)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=1)

        log_ctc_probs = self.ctc.linear(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs,
                                              self.beam_size,
                                              dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        # Advance the offset by the number of output steps of this chunk.
        r_offset = offset + chunk_out.shape[1]

        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)

        return (
            log_probs,
            log_probs_idx,
            chunk_out,
            chunk_out_lens,
            r_offset,
            r_att_cache,
            r_cnn_cache,
            r_cache_mask,
        )
|
|
|
|
class StreamingEfficientConformerEncoder(torch.nn.Module):
    """Export wrapper for one streaming chunk of an encoder with strided layers.

    Layers listed in ``stride_layer_idx`` downsample the time axis by the
    matching entry of ``stride``; ``forward`` strides the masks, positional
    embedding and caches accordingly.
    """

    def __init__(self, model, required_cache_size, beam_size):
        """Store references to the sub-modules and stride settings of ``model``.

        Args:
            model: Object exposing ``ctc`` and ``encoder`` (with ``embed``,
                ``global_cmvn``, ``encoders``, ``stride_layer_idx``,
                ``stride``, ``num_blocks`` and ``cnn_module_kernel``).
            required_cache_size: Number of trailing time steps kept in the
                returned cache mask (divided by the layer factor for the
                attention cache).
            beam_size: ``k`` used for the top-k over the CTC projection.
        """
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder

        # Stride configuration of the encoder.
        self.stride_layer_idx = model.encoder.stride_layer_idx
        self.stride = model.encoder.stride
        self.num_blocks = model.encoder.num_blocks
        self.cnn_module_kernel = model.encoder.cnn_module_kernel

    def calculate_downsampling_factor(self, i: int) -> int:
        """Return the product of the strides of all stride layers before ``i``.

        A stride layer contributes only when ``i`` is strictly greater than
        its index, i.e. the factor applies to the layers after it.
        """
        factor = 1
        for idx, stride_idx in enumerate(self.stride_layer_idx):
            if i > stride_idx:
                factor *= self.stride[idx]
        return factor

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
                cache_mask):
        """Streaming Encoder

        Shapes quoted below come from the original notes of this method and
        are not checked here -- TODO confirm against the export caller.

        Args:
            chunk_xs (torch.Tensor): chunk input, noted as
                (b, time, mel-dim), where
                time == (chunk_size - 1) * subsample_rate
                        + subsample.right_context + 1
            chunk_lens (torch.Tensor): per-request lengths; used to build
                the padding mask and to derive the output lengths.
            offset (torch.Tensor): offset, noted as (b, 1); dim 1 is
                squeezed on entry and restored on the returned offset
                (1 is retained for triton deployment). The caller's tensor
                is not modified.
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, noted as
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in
                conformer, noted as (b, elayers, hidden-dim, cache_t2),
                where `cache_t2 == cnn.lorder - 1`.
            cache_mask (torch.Tensor): mask marking the effective cache of
                each request in the batch. It is concatenated with the
                chunk mask along dim 2.
                NOTE(review): the original notes give the shape as
                (b, required_cache_size); the dim-2 concat suggests a
                singleton middle dim -- confirm.

        Returns:
            tuple: (log_probs, log_probs_idx, chunk_out, chunk_out_lens,
                r_offset, r_att_cache, r_cnn_cache, r_cache_mask).

                log_probs / log_probs_idx are the top ``beam_size`` values /
                indices (dim 2) of the ``ctc.linear`` output; no softmax is
                applied in this method. chunk_out_lens is ``chunk_lens``
                floor-divided by the subsampling rate and the total
                downsampling factor, plus one. r_att_cache / r_cnn_cache
                hold the per-layer caches stacked on dim 1.
        """
        # Out-of-place multiply: ``squeeze`` returns a view of the caller's
        # tensor, so an in-place ``*=`` would scale the caller's offset too.
        offset = offset.squeeze(1) * self.calculate_downsampling_factor(
            self.num_blocks + 1)

        T = chunk_xs.size(1)
        # make_pad_mask output is inverted, then cast to the input dtype.
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)

        chunk_mask = chunk_mask.to(chunk_xs.dtype)

        # Swap dims 0 and 1 so that indexing with the layer index below
        # selects one layer's cache.
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        xs = self.global_cmvn(chunk_xs)

        # pos_emb returned here is replaced a few lines below.
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        cache_size = att_cache.size(3)
        # ``masks`` only feeds the returned cache mask; ``att_mask`` is the
        # copy that gets strided while walking through the layers.
        masks = torch.cat((cache_mask, chunk_mask), dim=2)
        att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size

        # Positional encoding requested for the cached steps plus the chunk.
        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        # Negative start: keep only the trailing required_cache_size steps.
        next_cache_start = -self.required_cache_size
        r_cache_mask = masks[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        mask_pad = chunk_mask.to(torch.bool)
        max_att_len, max_cnn_len = (
            0,
            0,
        )
        for i, layer in enumerate(self.encoder.encoders):
            factor = self.calculate_downsampling_factor(i)

            # When the current length plus the strided cache length exceeds
            # the positional length, drop the surplus leading cache steps.
            att_cache_trunc = 0
            if xs.size(1) + att_cache.size(3) / factor > pos_emb.size(1):
                att_cache_trunc = (xs.size(1) + att_cache.size(3) // factor -
                                   pos_emb.size(1) + 1)
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                mask_pad=mask_pad,
                att_cache=att_cache[i][:, :, ::factor, :][:, :,
                                                          att_cache_trunc:, :],
                cnn_cache=cnn_cache[i, :, :, :]
                if cnn_cache.size(0) > 0 else cnn_cache,
            )

            if i in self.stride_layer_idx:
                # Stride the masks and positional embedding for the layers
                # that follow this stride layer.
                efficient_index = self.stride_layer_idx.index(i)
                att_mask = att_mask[:, ::self.stride[efficient_index], ::self.
                                    stride[efficient_index], ]
                mask_pad = mask_pad[:, ::self.stride[efficient_index], ::self.
                                    stride[efficient_index], ]
                pos_emb = pos_emb[:, ::self.stride[efficient_index], :]

            new_att_cache = new_att_cache[:, :, next_cache_start // factor:, :]
            new_cnn_cache = new_cnn_cache.unsqueeze(1)

            # Repeat each cached step ``factor`` times along dim 2.
            new_att_cache = (new_att_cache.unsqueeze(3).repeat(
                1, 1, 1, factor, 1).flatten(2, 3))
            # Left-pad the last dim up to cnn_module_kernel - 1 entries.
            new_cnn_cache = F.pad(
                new_cnn_cache,
                (self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0),
            )

            if i == 0:
                # The first layer's cache lengths cap all later layers.
                max_att_len = new_att_cache.size(2)
                max_cnn_len = new_cnn_cache.size(3)

            r_att_cache.append(new_att_cache[:, :,
                                             -max_att_len:, :].unsqueeze(1))
            r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:])

        if self.encoder.normalize_before:
            chunk_out = self.encoder.after_norm(xs)
        else:
            chunk_out = xs

        # Stack the per-layer caches on dim 1.
        r_att_cache = torch.cat(r_att_cache, dim=1)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=1)

        log_ctc_probs = self.ctc.linear(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs,
                                              self.beam_size,
                                              dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        # Advance the (scaled) offset by the number of output steps.
        r_offset = offset + chunk_out.shape[1]

        chunk_out_lens = (
            chunk_lens // self.subsampling_rate //
            self.calculate_downsampling_factor(self.num_blocks + 1))
        chunk_out_lens += 1
        r_offset = r_offset.unsqueeze(1)

        return (
            log_probs,
            log_probs_idx,
            chunk_out,
            chunk_out_lens,
            r_offset,
            r_att_cache,
            r_cnn_cache,
            r_cache_mask,
        )
|
|
|
|
class Decoder(torch.nn.Module):
    """Attention-rescoring wrapper around a WeNet decoder for ONNX export.

    Every hypothesis of every utterance is scored by the wrapped decoder,
    optionally blended with the score of the reversed hypothesis, combined
    with the weighted CTC score, and the best hypothesis per utterance is
    selected.
    """

    def __init__(
        self,
        decoder: TransformerDecoder,
        ctc_weight: float = 0.5,
        reverse_weight: float = 0.0,
        beam_size: int = 10,
        decoder_fastertransformer: bool = False,
    ):
        """Keep the wrapped decoder and the rescoring settings.

        Args:
            decoder: decoder module called from ``forward``.
            ctc_weight: factor applied to ``ctc_score`` before it is added
                to the decoder score.
            reverse_weight: weight of the reversed-hypothesis score; the
                reverse branch only runs when this is greater than zero.
            beam_size: number of hypotheses per utterance.
            decoder_fastertransformer: when True, ``forward`` also returns
                the decoder output.
        """
        super().__init__()
        self.decoder = decoder
        self.ctc_weight = ctc_weight
        self.reverse_weight = reverse_weight
        self.beam_size = beam_size
        self.decoder_fastertransformer = decoder_fastertransformer

    def forward(
        self,
        encoder_out: torch.Tensor,
        encoder_lens: torch.Tensor,
        hyps_pad_sos_eos: torch.Tensor,
        hyps_lens_sos: torch.Tensor,
        r_hyps_pad_sos_eos: torch.Tensor,
        ctc_score: torch.Tensor,
    ):
        """Rescore the beam of hypotheses and pick the best one.

        Args:
            encoder_out: B x T x F
            encoder_lens: B
            hyps_pad_sos_eos: B x beam x (T2+1),
                hyps with sos & eos and padded by ignore id
            hyps_lens_sos: B x beam, length for each hyp with sos
            r_hyps_pad_sos_eos: B x beam x (T2+1),
                reversed hyps with sos & eos and padded by ignore id
            ctc_score: B x beam, ctc score for each hyp
        Returns:
            ``best_index`` (B) on its own, or the tuple
            ``(decoder_out, best_index)`` with ``decoder_out`` shaped
            B x beam x T2 x V when ``decoder_fastertransformer`` is set.
        """
        batch, frames, feat_dim = encoder_out.shape
        beam = self.beam_size
        flat = batch * beam

        # Tile the encoder output and its mask so that every hypothesis of
        # an utterance gets its own copy.
        encoder_out = encoder_out.repeat(1, beam, 1).view(
            flat, frames, feat_dim)
        encoder_mask = ~make_pad_mask(encoder_lens, frames).unsqueeze(1)
        encoder_mask = encoder_mask.repeat(1, beam, 1).view(flat, 1, frames)

        hyp_len = hyps_pad_sos_eos.shape[2] - 1
        hyps_flat = hyps_pad_sos_eos.view(flat, hyp_len + 1)
        hyps_lens = hyps_lens_sos.view(flat, )
        hyps_in = hyps_flat[:, :-1].contiguous()
        hyps_target = hyps_flat[:, 1:].contiguous()

        r_hyps_flat = r_hyps_pad_sos_eos.view(flat, hyp_len + 1)
        r_hyps_in = r_hyps_flat[:, :-1].contiguous()
        r_hyps_target = r_hyps_flat[:, 1:].contiguous()

        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out,
            encoder_mask,
            hyps_in,
            hyps_lens,
            r_hyps_in,
            self.reverse_weight,
        )

        vocab = decoder_out.shape[-1]
        decoder_out = decoder_out.view(flat, hyp_len, vocab)
        valid = ~make_pad_mask(hyps_lens, hyp_len)
        # Multiplying by the mask turns padded targets into index 0 so that
        # gather stays in range; their scores are zeroed again just below.
        index = torch.unsqueeze(hyps_target * valid, 2).to(torch.long)
        score = decoder_out.gather(2, index).squeeze(2)
        score = score * valid
        decoder_out = decoder_out.view(batch, beam, hyp_len, vocab)

        if self.reverse_weight > 0:
            r_decoder_out = r_decoder_out.view(flat, hyp_len, vocab)
            index = torch.unsqueeze(r_hyps_target * valid, 2).to(torch.long)
            r_score = r_decoder_out.gather(2, index).squeeze(2)
            r_score = r_score * valid
            score = (score * (1 - self.reverse_weight) +
                     self.reverse_weight * r_score)
            r_decoder_out = r_decoder_out.view(batch, beam, hyp_len, vocab)

        score = torch.sum(score, axis=1)
        score = torch.reshape(score,
                              (batch, beam)) + self.ctc_weight * ctc_score
        best_index = torch.argmax(score, dim=1)

        if not self.decoder_fastertransformer:
            return best_index
        return decoder_out, best_index
|
|
|
|
def to_numpy(tensors):
    """Convert each element of ``tensors`` to a NumPy array on the CPU.

    Args:
        tensors: iterable of ``torch.Tensor``. A bare tensor is accepted too,
            but it is iterated along its first dimension, so the result then
            holds one array per leading index rather than a single array.

    Returns:
        list of ``numpy.ndarray``, in iteration order.
    """
    # detach() is harmless on tensors that do not require grad, so a single
    # code path covers both cases.
    return [item.detach().cpu().numpy() for item in tensors]
|
|
|
|
| def test(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True): |
| for a, b in zip(xlist, blist): |
| try: |
| torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol) |
| except AssertionError as error: |
| if tolerate_small_mismatch: |
| print(error) |
| else: |
| raise |
|
|
|
|
def export_offline_encoder(model, configs, args, logger, encoder_onnx_path):
    """Export the non-streaming encoder to ONNX and compare it with PyTorch.

    The graph is traced with a fixed-shape dummy batch (``dynamic_axes=None``),
    post-processed in place by the pulsar2 helpers defined elsewhere in this
    file, then executed with onnxruntime on the CPU and compared against the
    PyTorch outputs.

    Args:
        model: model providing ``encoder`` and ``ctc``.
        configs: training configuration; ``input_dim`` and
            ``model_conf`` (``reverse_weight``, ``ctc_weight``) are read.
        args: parsed command-line arguments; ``beam_size`` is read.
        logger: logger used for the success message.
        encoder_onnx_path: destination path of the ONNX file.

    Returns:
        dict with ``beam_size``, ``reverse_weight`` and ``ctc_weight``.
    """
    bz = 1
    seq_len = 1024
    beam_size = args.beam_size
    feature_size = configs["input_dim"]

    speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
    speech_lens = torch.randint(low=10,
                                high=seq_len,
                                size=(bz, ),
                                dtype=torch.int32)
    encoder = Encoder(model.encoder, model.ctc, beam_size)
    encoder.eval()

    torch.onnx.export(
        encoder,
        (speech, speech_lens),
        encoder_onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=["speech", "speech_lengths"],
        output_names=[
            "encoder_out",
            "encoder_out_lens",
            "ctc_log_probs",
            "beam_log_probs",
            "beam_log_probs_idx",
        ],
        # Static shapes only: no dynamic axes are declared.
        dynamic_axes=None,
        verbose=False,
        dynamo=False,
    )
    fold_static_pulsar2_subgraphs(encoder_onnx_path)
    simplify_pulsar2_onnx(encoder_onnx_path)
    rewrite_pulsar2_bool_not(encoder_onnx_path)

    with torch.no_grad():
        o0, o1, o2, o3, o4 = encoder(speech, speech_lens)

    providers = ["CPUExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(encoder_onnx_path,
                                               providers=providers)
    # Feed plain ndarrays. Going through to_numpy() with a bare tensor would
    # split it along the batch axis and rely on onnxruntime re-stacking the
    # resulting Python list.
    ort_inputs = {
        "speech": speech.detach().cpu().numpy(),
        "speech_lengths": speech_lens.detach().cpu().numpy(),
    }
    ort_outs = ort_session.run(None, ort_inputs)

    test(to_numpy([o0, o1, o2, o3, o4]), ort_outs)
    logger.info("export offline onnx encoder succeed!")
    onnx_config = {
        "beam_size": args.beam_size,
        "reverse_weight": configs["model_conf"]["reverse_weight"],
        "ctc_weight": configs["model_conf"]["ctc_weight"],
    }
    return onnx_config
|
|
|
|
def export_online_encoder(model, configs, args, logger, encoder_onnx_path):
    """Export the streaming (chunk-wise) encoder to ONNX and verify it.

    A streaming wrapper is chosen from ``configs["encoder"]``, traced with
    fixed-shape dummy inputs (``dynamic_axes=None``), post-processed in place
    by the pulsar2 helpers defined elsewhere in this file, and finally run
    with onnxruntime on the CPU to compare against the PyTorch outputs.

    Args:
        model: model whose ``encoder.embed`` exposes ``subsampling_rate`` and
            ``right_context``.
        configs: training configuration; ``input_dim``, ``encoder`` and
            ``encoder_conf`` are read.
        args: parsed command-line arguments; ``decoding_chunk_size``,
            ``num_decoding_left_chunks``, ``beam_size`` and
            ``return_ctc_logprobs`` are read.
        logger: logger used for the success message.
        encoder_onnx_path: destination path of the ONNX file.

    Returns:
        dict describing the streaming setup (chunk size, window, cache sizes,
        beam size, ...) for the runtime.

    Raises:
        ValueError: if ``return_ctc_logprobs`` is requested for a transformer
            encoder.
    """
    decoding_chunk_size = args.decoding_chunk_size
    subsampling = model.encoder.embed.subsampling_rate
    context = model.encoder.embed.right_context + 1
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    batch_size = 1
    audio_len = decoding_window
    feature_size = configs["input_dim"]
    output_size = configs["encoder_conf"]["output_size"]
    num_layers = configs["encoder_conf"]["num_blocks"]

    # The conv cache holds kernel - 1 frames; an absent (or size-1) kernel
    # gives 0, which is how a plain transformer encoder is recognised here.
    cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1
    transformer = not cnn_module_kernel

    num_decoding_left_chunks = args.num_decoding_left_chunks
    required_cache_size = decoding_chunk_size * num_decoding_left_chunks
    if configs["encoder"] == "squeezeformer":
        encoder = StreamingSqueezeformerEncoder(model, required_cache_size,
                                                args.beam_size)
    elif configs["encoder"] == "efficientConformer":
        encoder = StreamingEfficientConformerEncoder(model,
                                                     required_cache_size,
                                                     args.beam_size)
    else:
        encoder = StreamingEncoder(
            model,
            required_cache_size,
            args.beam_size,
            transformer,
            args.return_ctc_logprobs,
        )
    encoder.eval()

    # Dummy inputs for tracing.
    chunk_xs = torch.randn(batch_size,
                           audio_len,
                           feature_size,
                           dtype=torch.float32)
    chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len
    offset = torch.arange(0, batch_size, dtype=torch.int32).unsqueeze(1)

    head = configs["encoder_conf"]["attention_heads"]
    d_k = configs["encoder_conf"]["output_size"] // head
    att_cache = torch.randn(
        batch_size,
        num_layers,
        head,
        required_cache_size,
        d_k * 2,
        dtype=torch.float32,
    )
    cnn_cache = torch.randn(
        batch_size,
        num_layers,
        output_size,
        cnn_module_kernel,
        dtype=torch.float32,
    )
    cache_mask = torch.ones(batch_size,
                            1,
                            required_cache_size,
                            dtype=torch.float32)

    input_names = [
        "chunk_xs",
        "chunk_lens",
        "offset",
        "att_cache",
        "cnn_cache",
        "cache_mask",
    ]
    if args.return_ctc_logprobs:
        output_names = [
            "ctc_log_probs",
            "chunk_out",
            "chunk_out_lens",
            "r_offset",
            "r_att_cache",
            "r_cnn_cache",
            "r_cache_mask",
        ]
    else:
        output_names = [
            "log_probs",
            "log_probs_idx",
            "chunk_out",
            "chunk_out_lens",
            "r_offset",
            "r_att_cache",
            "r_cnn_cache",
            "r_cache_mask",
        ]
    input_tensors = (
        chunk_xs,
        chunk_lens,
        offset,
        att_cache,
        cnn_cache,
        cache_mask,
    )
    if transformer:
        if args.return_ctc_logprobs is not False:
            raise ValueError(
                "return_ctc_logprobs is not supported in transformer")
        # Index 6 is "r_cnn_cache", which a transformer encoder lacks.
        output_names.pop(6)

    torch.onnx.export(
        encoder,
        input_tensors,
        encoder_onnx_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        # Static shapes only: no dynamic axes are declared.
        dynamic_axes=None,
        verbose=False,
        dynamo=False,
    )
    fold_static_pulsar2_subgraphs(encoder_onnx_path)
    simplify_pulsar2_onnx(encoder_onnx_path)
    rewrite_pulsar2_bool_not(encoder_onnx_path)

    with torch.no_grad():
        torch_outs = list(
            encoder(chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
                    cache_mask))
    # list.pop() returns the removed element, so its result must not be
    # assigned back. The streaming wrapper lives outside this function, so
    # the cnn-cache slot is dropped only if the module still returns it.
    if transformer and len(torch_outs) > len(output_names):
        del torch_outs[6]

    ort_session = onnxruntime.InferenceSession(
        encoder_onnx_path, providers=["CPUExecutionProvider"])
    # Feed only the inputs that survived export and post-processing (an
    # unused input such as cnn_cache for a transformer may be gone).
    graph_inputs = {node.name for node in ort_session.get_inputs()}
    ort_inputs = {
        name: array
        for name, array in zip(input_names, to_numpy(input_tensors))
        if name in graph_inputs
    }
    ort_outs = ort_session.run(None, ort_inputs)
    test(to_numpy(torch_outs), ort_outs, rtol=1e-03, atol=1e-05)
    logger.info("export to onnx streaming encoder succeed!")
    onnx_config = {
        "subsampling_rate": subsampling,
        "context": context,
        "decoding_chunk_size": decoding_chunk_size,
        "num_decoding_left_chunks": num_decoding_left_chunks,
        "beam_size": args.beam_size,
        "feat_size": feature_size,
        "decoding_window": decoding_window,
        "cnn_module_kernel_cache": cnn_module_kernel,
        "return_ctc_logprobs": args.return_ctc_logprobs,
    }
    return onnx_config
|
|
|
|
def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
                             decoder_fastertransformer):
    """Export the attention-rescoring decoder to ONNX and verify it.

    The ``Decoder`` wrapper is traced with fixed-shape random inputs
    (``dynamic_axes=None``), the file is post-processed in place by the
    pulsar2 helpers defined elsewhere in this file, and the result is run
    with onnxruntime on the CPU and compared with the PyTorch output.

    Args:
        model: model providing ``decoder``, ``ctc_weight`` and
            ``reverse_weight``.
        configs: training configuration; ``encoder_conf.output_size`` is read.
        args: parsed command-line arguments; ``beam_size`` is read.
        logger: logger used for the success message.
        decoder_onnx_path: destination path of the ONNX file.
        decoder_fastertransformer: when True the graph also outputs
            ``decoder_out`` ahead of ``best_index``.
    """
    bz, seq_len = 1, 32
    beam_size = args.beam_size
    decoder = Decoder(
        model.decoder,
        model.ctc_weight,
        model.reverse_weight,
        beam_size,
        decoder_fastertransformer,
    )
    decoder.eval()

    # Random token ids and lengths used purely for tracing and verification.
    hyps_pad_sos_eos = torch.randint(low=3,
                                     high=1000,
                                     size=(bz, beam_size, seq_len),
                                     dtype=torch.int32)
    hyps_lens_sos = torch.randint(low=3,
                                  high=seq_len,
                                  size=(bz, beam_size),
                                  dtype=torch.int32)
    r_hyps_pad_sos_eos = torch.randint(low=3,
                                       high=1000,
                                       size=(bz, beam_size, seq_len),
                                       dtype=torch.int32)

    output_size = configs["encoder_conf"]["output_size"]
    encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32)
    encoder_out_lens = torch.randint(low=3,
                                     high=seq_len,
                                     size=(bz, ),
                                     dtype=torch.int32)
    ctc_score = torch.randn(bz, beam_size, dtype=torch.float32)

    input_names = [
        "encoder_out",
        "encoder_out_lens",
        "hyps_pad_sos_eos",
        "hyps_lens_sos",
        "r_hyps_pad_sos_eos",
        "ctc_score",
    ]
    output_names = ["best_index"]
    if decoder_fastertransformer:
        output_names.insert(0, "decoder_out")

    input_tensors = (
        encoder_out,
        encoder_out_lens,
        hyps_pad_sos_eos,
        hyps_lens_sos,
        r_hyps_pad_sos_eos,
        ctc_score,
    )

    torch.onnx.export(
        decoder,
        input_tensors,
        decoder_onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        # Static shapes only: no dynamic axes are declared.
        dynamic_axes=None,
        verbose=False,
        dynamo=False,
    )
    fold_static_pulsar2_subgraphs(decoder_onnx_path)
    simplify_pulsar2_onnx(decoder_onnx_path)
    rewrite_pulsar2_bool_not(decoder_onnx_path)
    rewrite_pulsar2_bool_and(decoder_onnx_path)

    with torch.no_grad():
        o0 = decoder(*input_tensors)

    providers = ["CPUExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(decoder_onnx_path,
                                               providers=providers)

    # Feed only the inputs the exported graph actually has. The reversed
    # hypotheses are unused (and therefore may be absent) when the reverse
    # branch is disabled; asking the session avoids duplicating the
    # Decoder's own ``reverse_weight > 0`` condition here.
    graph_inputs = {node.name for node in ort_session.get_inputs()}
    ort_inputs = {
        name: array
        for name, array in zip(input_names, to_numpy(input_tensors))
        if name in graph_inputs
    }
    ort_outs = ort_session.run(None, ort_inputs)

    # With decoder_fastertransformer the module returns a tuple; otherwise
    # a single tensor that must be wrapped before conversion.
    if decoder_fastertransformer:
        test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05)
    else:
        test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05)
    logger.info("export to onnx decoder succeed!")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: fetch/locate the pretrained model, export the
    # streaming encoder, the offline encoder and the rescoring decoder to
    # ONNX, then write the runtime configuration next to them.
    parser = argparse.ArgumentParser(description="export x86_gpu model")
    parser.add_argument(
        "--pretrained_model_dir",
        default=DEFAULT_PRETRAINED_MODEL_DIR,
        help=("pretrained model directory containing train.yaml, final.pt, "
              "and global_cmvn"),
    )
    parser.add_argument(
        "--pretrained_model_url",
        default=DEFAULT_PRETRAINED_MODEL_URL,
        help="pretrained model tar.gz URL used when pretrained_model_dir is missing",
    )
    parser.add_argument(
        "--reverse_weight",
        default=-1.0,
        type=float,
        required=False,
        help="reverse weight for bitransformer," +
        "default value is in config file",
    )
    parser.add_argument(
        "--ctc_weight",
        default=-1.0,
        type=float,
        required=False,
        help="ctc weight, default value is in config file",
    )
    parser.add_argument(
        "--beam_size",
        default=10,
        type=int,
        required=False,
        help="beam size would be ctc output size",
    )
    parser.add_argument(
        "--output_onnx_dir",
        default="onnx_model",
        help="output onnx encoder and decoder directory",
    )
    parser.add_argument(
        "--decoding_chunk_size",
        default=16,
        type=int,
        required=False,
        help="the decoding chunk size, <=0 is not supported",
    )
    parser.add_argument(
        "--num_decoding_left_chunks",
        default=5,
        type=int,
        required=False,
        help="number of left chunks, <= 0 is not supported",
    )
    parser.add_argument(
        "--decoder_fastertransformer",
        action="store_true",
        help="return decoder_out and best_index for ft",
    )
    parser.add_argument(
        "--return_ctc_logprobs",
        action="store_true",
        help="return full ctc_log_probs for TLG streaming encoder",
    )
    args = parser.parse_args()

    # Reject unsupported values up front (and independently of -O, which
    # strips assert statements) instead of after the model has been loaded.
    if args.decoding_chunk_size <= 0:
        parser.error("--decoding_chunk_size must be > 0")
    if args.num_decoding_left_chunks <= 0:
        parser.error("--num_decoding_left_chunks must be > 0")

    # NOTE(review): args.config and args.cmvn_file are not declared on the
    # parser; presumably prepare_pretrained_model() sets them - confirm.
    prepare_pretrained_model(args)

    torch.manual_seed(0)
    torch.set_printoptions(precision=10)

    with open(args.config, "r", encoding="utf-8") as fin:
        # NOTE(review): the config may come from a downloaded archive;
        # consider yaml.safe_load if no custom tags are required.
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    if os.path.exists(args.cmvn_file):
        if "cmvn" not in configs:
            configs["cmvn"] = "global_cmvn"
            configs["cmvn_conf"] = {}
        else:
            if configs["cmvn"] != "global_cmvn":
                raise ValueError(
                    "config 'cmvn' must be 'global_cmvn', got "
                    f"{configs['cmvn']!r}")
            if configs["cmvn_conf"] is None:
                raise ValueError("config 'cmvn_conf' must not be empty")
        configs["cmvn_conf"]["cmvn_file"] = args.cmvn_file
        configs["cmvn_conf"].setdefault(
            "is_json_cmvn", configs.get("is_json_cmvn", True))
    elif configs.get("cmvn", None) == "global_cmvn":
        raise FileNotFoundError(
            f"Expected global_cmvn in pretrained model dir: {args.cmvn_file}")

    # -1 is the "keep the value from the config file" sentinel.
    if (args.reverse_weight != -1.0
            and "reverse_weight" in configs["model_conf"]):
        configs["model_conf"]["reverse_weight"] = args.reverse_weight
        print("Update reverse weight to", args.reverse_weight)
    if args.ctc_weight != -1:
        print("Update ctc weight to ", args.ctc_weight)
        configs["model_conf"]["ctc_weight"] = args.ctc_weight
    configs["encoder_conf"]["use_dynamic_chunk"] = False

    model, configs = init_model(args, configs)
    model.eval()

    # makedirs also handles nested output paths and an existing directory.
    os.makedirs(args.output_onnx_dir, exist_ok=True)

    online_onnx_path = os.path.join(args.output_onnx_dir,
                                    "encoder_online.onnx")
    online_config = export_online_encoder(model, configs, args, logger,
                                          online_onnx_path)

    offline_onnx_path = os.path.join(args.output_onnx_dir,
                                     "encoder_offline.onnx")
    offline_config = export_offline_encoder(model, configs, args, logger,
                                            offline_onnx_path)

    # Both encoders are exported, so keep both sets of settings. Offline
    # values win on shared keys, which leaves every previously written key
    # unchanged and only adds the streaming parameters.
    onnx_config = {**online_config, **offline_config}

    decoder_onnx_path = os.path.join(args.output_onnx_dir, "decoder.onnx")
    export_rescoring_decoder(
        model,
        configs,
        args,
        logger,
        decoder_onnx_path,
        args.decoder_fastertransformer,
    )

    config_path = os.path.join(args.output_onnx_dir, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as out:
        yaml.dump(onnx_config, out)
|
|