# WeNet / export_onnx.py
# Source: Hugging Face file view, uploaded by inoryQwQ, commit 3c50954
# ("First commit"), 65.4 kB.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import copy
import logging
import os
import sys
import tarfile
import urllib.request
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
from wenet.utils.init_model import init_model
from wenet.utils.mask import make_pad_mask
from typing import List, Tuple
try:
import onnx
import onnxruntime
from onnx import helper, numpy_helper
from onnxsim import simplify
except ImportError:
print("Please install onnxruntime!")
sys.exit(1)
# Logger named after this file, set to INFO level.
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)

# Default pretrained model: the archive URL on the Hugging Face hub and the
# local directory it unpacks to (presumably used as command-line defaults;
# the argument parser is not visible here).
DEFAULT_PRETRAINED_MODEL_URL = "/".join((
    "https://huggingface.co/openspeech/wenet-models/resolve/main",
    "aishell_u2pp_conformer_exp.tar.gz",
))
DEFAULT_PRETRAINED_MODEL_DIR = "pretrained/" + "aishell_u2pp_conformer_exp"
def safe_extract_tar(tar, output_dir):
    """Extract ``tar`` into ``output_dir`` after rejecting unsafe members.

    The archive is first scanned. If any member path, or any symbolic/hard
    link target, would resolve outside ``output_dir``, a RuntimeError is
    raised and nothing is extracted. The destination root itself (a member
    named ``.`` / ``./``, as produced by ``tar -C dir .``) is allowed.
    When the running Python provides tarfile extraction filters, the
    ``"data"`` filter is applied as a second line of defence.

    Args:
        tar: an open ``tarfile.TarFile``.
        output_dir: destination directory.

    Raises:
        RuntimeError: if a member or link target escapes ``output_dir``.
    """
    output_dir = os.path.abspath(output_dir)

    def _is_within(path):
        # commonpath also accepts output_dir itself, which the old
        # "startswith(output_dir + os.sep)" test wrongly rejected.
        try:
            return os.path.commonpath([output_dir, path]) == output_dir
        except ValueError:  # e.g. paths on different drives (Windows)
            return False

    for member in tar.getmembers():
        member_path = os.path.abspath(os.path.join(output_dir, member.name))
        if not _is_within(member_path):
            raise RuntimeError(f"Unsafe tar member path: {member.name}")
        if member.issym() or member.islnk():
            # Symlink targets are relative to the link's own directory;
            # hard-link targets are archive paths relative to the root.
            base = (os.path.dirname(member_path)
                    if member.issym() else output_dir)
            target = os.path.abspath(os.path.join(base, member.linkname))
            if not _is_within(target):
                raise RuntimeError(
                    f"Unsafe tar link target: {member.name} -> "
                    f"{member.linkname}")
    if hasattr(tarfile, "data_filter"):
        tar.extractall(output_dir, filter="data")
    else:
        tar.extractall(output_dir)
def download_file(url, output_path):
    """Download ``url`` to ``output_path``, creating parent directories.

    The data is written to ``output_path + ".part"`` and renamed into place
    only after the transfer completed. An interrupted or failed download
    therefore never leaves a truncated file at ``output_path``; callers such
    as prepare_pretrained_model treat an existing archive as complete.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises (e.g. URLError); the
        partial file is removed first.
    """
    output_dir = os.path.dirname(output_path)
    if output_dir:  # a bare file name means the current directory
        os.makedirs(output_dir, exist_ok=True)
    print(f"Downloading pretrained model from {url}")
    print(f"Saving to {output_path}")
    partial_path = output_path + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, output_path)
    except BaseException:
        # Also covers KeyboardInterrupt, the common way a long download dies.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
def prepare_pretrained_model(args):
    """Make the pretrained model available locally and point ``args`` at it.

    If ``args.pretrained_model_dir`` does not exist, the archive
    ``<parent dir>/<dir name>.tar.gz`` is downloaded from
    ``args.pretrained_model_url`` (unless that archive already exists) and
    extracted into the parent directory. ``args.config``,
    ``args.checkpoint`` and ``args.cmvn_file`` are then set to
    ``train.yaml``, ``final.pt`` and ``global_cmvn`` inside the directory.

    Raises:
        FileNotFoundError: if ``train.yaml`` or ``final.pt`` is missing.
            A missing ``global_cmvn`` is tolerated; it is just not announced.
    """
    model_dir = args.pretrained_model_dir
    trimmed_dir = model_dir.rstrip(os.sep)
    archive_dir = os.path.dirname(trimmed_dir) or "."
    archive_path = os.path.join(archive_dir,
                                os.path.basename(trimmed_dir) + ".tar.gz")

    if not os.path.exists(model_dir):
        if not os.path.exists(archive_path):
            download_file(args.pretrained_model_url, archive_path)
        print(f"Extracting pretrained model to {archive_dir}")
        with tarfile.open(archive_path, "r:gz") as archive:
            safe_extract_tar(archive, archive_dir)

    args.config, args.checkpoint, args.cmvn_file = (
        os.path.join(model_dir, "train.yaml"),
        os.path.join(model_dir, "final.pt"),
        os.path.join(model_dir, "global_cmvn"),
    )

    required_files = (args.config, args.checkpoint)
    missing = [path for path in required_files if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(
            "Missing pretrained model files: " + ", ".join(missing))

    print(f"Using config: {args.config}")
    print(f"Using checkpoint: {args.checkpoint}")
    if os.path.exists(args.cmvn_file):
        print(f"Using CMVN: {args.cmvn_file}")
def _constant_node_value(node):
if node is None or node.op_type != "Constant":
return None
for attr in node.attribute:
if attr.name == "value":
return numpy_helper.to_array(attr.t)
return None
def _attribute_value(attr):
    """Decode an ONNX ``AttributeProto`` into its Python value.

    Thin wrapper over ``onnx.helper.get_attribute_value``.
    """
    decoded = helper.get_attribute_value(attr)
    return decoded
def _get_attr(node, name, default=None):
for attr in node.attribute:
if attr.name == name:
return _attribute_value(attr)
return default
def _cast_to_onnx_dtype(value, to_dtype):
    """Cast a NumPy value to the element type named by an ONNX enum.

    Args:
        value: NumPy array (or NumPy scalar) to convert.
        to_dtype: ``onnx.TensorProto.DataType`` enum value, e.g. the ``to``
            attribute of a Cast node.

    Returns:
        ``value.astype(...)`` for the numeric/bool types listed below, or
        None for any other element type so the caller skips folding.
    """
    onnx_name = onnx.TensorProto.DataType.Name(to_dtype).lower()
    numpy_by_onnx_name = {
        "bool": np.bool_,
        "int8": np.int8,
        "int16": np.int16,
        "int32": np.int32,
        "int64": np.int64,
        "uint8": np.uint8,
        "uint16": np.uint16,
        "uint32": np.uint32,
        "uint64": np.uint64,
        "float16": np.float16,
        "float": np.float32,
        "double": np.float64,
    }
    target_dtype = numpy_by_onnx_name.get(onnx_name)
    return None if target_dtype is None else value.astype(target_dtype)
def _shape_from_value_info(value_info):
if not value_info.type.HasField("tensor_type"):
return None
if not value_info.type.tensor_type.shape.dim:
return None
shape = []
for dim in value_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value") and dim.dim_value > 0:
shape.append(dim.dim_value)
else:
return None
return tuple(shape)
def _collect_static_shapes(model):
    """Map tensor names to fully static shapes via ONNX shape inference.

    Graph inputs, inferred intermediate values and graph outputs contribute
    when their shape is fully static (see _shape_from_value_info).
    Initializers then contribute their declared dims, overriding any earlier
    entry with the same name.
    """
    graph = onnx.shape_inference.infer_shapes(model).graph
    shapes = {}
    for value_info in (*graph.input, *graph.value_info, *graph.output):
        static_shape = _shape_from_value_info(value_info)
        if static_shape is not None:
            shapes[value_info.name] = static_shape
    shapes.update(
        (initializer.name, tuple(initializer.dims))
        for initializer in graph.initializer)
    return shapes
def _eval_static_node(node, inputs, static_shapes):
if node.op_type == "Constant":
return _constant_node_value(node)
if node.op_type != "Shape" and any(value is None for value in inputs):
return None
try:
if node.op_type == "Add":
return np.add(inputs[0], inputs[1])
if node.op_type == "Sub":
return np.subtract(inputs[0], inputs[1])
if node.op_type == "Mul":
return np.multiply(inputs[0], inputs[1])
if node.op_type == "Div":
return np.divide(inputs[0], inputs[1])
if node.op_type == "Equal":
return np.equal(inputs[0], inputs[1])
if node.op_type == "Greater":
return np.greater(inputs[0], inputs[1])
if node.op_type == "GreaterOrEqual":
return np.greater_equal(inputs[0], inputs[1])
if node.op_type == "Less":
return np.less(inputs[0], inputs[1])
if node.op_type == "LessOrEqual":
return np.less_equal(inputs[0], inputs[1])
if node.op_type == "Where":
return np.where(inputs[0], inputs[1], inputs[2])
if node.op_type == "Concat":
axis = _get_attr(node, "axis", 0)
return np.concatenate(inputs, axis=axis)
if node.op_type == "Unsqueeze":
axes = _get_attr(node, "axes")
if axes is None and len(inputs) > 1:
axes = inputs[1]
axes = tuple(int(axis) for axis in np.asarray(axes).reshape(-1))
return np.expand_dims(inputs[0], axis=axes)
if node.op_type == "Squeeze":
axes = _get_attr(node, "axes")
if axes is None and len(inputs) > 1:
axes = inputs[1]
if axes is None:
return np.squeeze(inputs[0])
axes = tuple(int(axis) for axis in np.asarray(axes).reshape(-1))
return np.squeeze(inputs[0], axis=axes)
if node.op_type == "Cast":
return _cast_to_onnx_dtype(inputs[0], _get_attr(node, "to"))
if node.op_type == "Reshape":
return np.reshape(inputs[0], tuple(int(i) for i in inputs[1]))
if node.op_type == "Shape":
if inputs[0] is not None:
shape = inputs[0].shape
else:
shape = static_shapes.get(node.input[0])
if shape is None:
return None
return np.asarray(shape, dtype=np.int64)
if node.op_type == "Slice":
data = inputs[0]
starts = np.asarray(inputs[1]).reshape(-1)
ends = np.asarray(inputs[2]).reshape(-1)
axes = (np.asarray(inputs[3]).reshape(-1)
if len(inputs) > 3 and inputs[3] is not None else
np.arange(len(starts)))
steps = (np.asarray(inputs[4]).reshape(-1)
if len(inputs) > 4 and inputs[4] is not None else
np.ones(len(starts), dtype=np.int64))
slices = [slice(None)] * data.ndim
for start, end, axis, step in zip(starts, ends, axes, steps):
axis = int(axis)
start = int(start)
end = int(end)
step = int(step)
if end >= np.iinfo(np.int32).max:
end = None
if end <= np.iinfo(np.int32).min:
end = None
slices[axis] = slice(start, end, step)
return data[tuple(slices)]
if node.op_type == "Gather":
axis = _get_attr(node, "axis", 0)
return np.take(inputs[0], inputs[1], axis=axis)
except Exception:
return None
return None
def _constant_node(output_name, value, name):
    """Build an ONNX ``Constant`` node named ``name`` that outputs ``value``.

    ``value`` is converted with ``np.asarray`` and stored as the node's
    ``value`` tensor, which is named ``<output_name>_value``.
    """
    value_tensor = numpy_helper.from_array(np.asarray(value),
                                           name=output_name + "_value")
    return helper.make_node("Constant", [], [output_name],
                            name=name,
                            value=value_tensor)
def _node_attributes(node):
return {attr.name: helper.get_attribute_value(attr) for attr in node.attribute}
def _copy_node(node, inputs=None, outputs=None, name=None):
copied = copy.deepcopy(node)
if inputs is not None:
del copied.input[:]
copied.input.extend(inputs)
if outputs is not None:
del copied.output[:]
copied.output.extend(outputs)
if name is not None:
copied.name = name
return copied
def _producer_map(model):
return {output: node for node in model.graph.node for output in node.output}
def _unsqueeze_greater_equal_pattern(producer, value_name):
unsqueeze = producer.get(value_name)
if unsqueeze is None or unsqueeze.op_type != "Unsqueeze":
return None, None
compare = producer.get(unsqueeze.input[0])
if compare is None or compare.op_type != "GreaterOrEqual":
return None, None
return unsqueeze, compare
def _double_not_slice_pattern(producer, value_name):
    """Match ``value_name`` = Slice(Slice(Not(Unsqueeze(GreaterOrEqual(.))))).

    Each producer's op_type is checked before its ``input[0]`` is followed.
    Producers without inputs, such as Constant nodes left behind by constant
    folding, therefore just fail to match instead of raising IndexError.

    Returns:
        ``(slice_0, slice_1, unsqueeze)``: the inner Slice, the outer Slice
        and the Unsqueeze feeding the inner Not. ``(None, None, None)`` when
        the pattern does not match.
    """
    no_match = (None, None, None)
    slice_1 = producer.get(value_name)
    if slice_1 is None or slice_1.op_type != "Slice":
        return no_match
    slice_0 = producer.get(slice_1.input[0])
    if slice_0 is None or slice_0.op_type != "Slice":
        return no_match
    inner_not = producer.get(slice_0.input[0])
    if inner_not is None or inner_not.op_type != "Not":
        return no_match
    unsqueeze, _ = _unsqueeze_greater_equal_pattern(producer,
                                                    inner_not.input[0])
    if unsqueeze is None:
        return no_match
    return slice_0, slice_1, unsqueeze


def rewrite_pulsar2_bool_not(onnx_path):
    """Remove simple Not nodes that Pulsar2 quantization can cast to float.

    The encoder mask contains Not(Unsqueeze(GreaterOrEqual(...))) and another
    Not over a sliced version of that mask. Pulsar2 can quantize the Not
    input to FP32 and then fail, because bitwise Not only accepts
    bool/integer tensors. Three patterns are rewritten without Not:

    1. Not(GreaterOrEqual(a, b))               -> Less(a, b)
    2. Not(Unsqueeze(GreaterOrEqual(a, b)))    -> Unsqueeze(Less(a, b))
    3. Not(Slice(Slice(Not(Unsqueeze(GE)))))   -> Slice(Slice(Unsqueeze(GE)))
       (the two Nots cancel).

    Less(a, b) equals Not(GreaterOrEqual(a, b)) for non-NaN inputs. Nodes
    whose outputs become unused are left for a later simplification pass.
    The model is checked and saved back to ``onnx_path`` only when
    something was rewritten.
    """
    model = onnx.load(onnx_path)
    producer = _producer_map(model)
    rewritten = 0
    new_nodes = []
    for node in model.graph.node:
        if node.op_type != "Not":
            new_nodes.append(node)
            continue

        # Pattern 1: Not(GreaterOrEqual(a, b)) -> Less(a, b)
        compare = producer.get(node.input[0])
        if compare is not None and compare.op_type == "GreaterOrEqual":
            less = helper.make_node("Less",
                                    inputs=list(compare.input),
                                    outputs=list(node.output),
                                    name=node.name + "_less",
                                    **_node_attributes(compare))
            new_nodes.append(less)
            rewritten += 1
            continue

        # Pattern 2: Not(Unsqueeze(GE(a, b))) -> Unsqueeze(Less(a, b))
        unsqueeze, compare = _unsqueeze_greater_equal_pattern(
            producer, node.input[0])
        if unsqueeze is not None:
            less_output = node.output[0] + "_less"
            less = helper.make_node("Less",
                                    inputs=list(compare.input),
                                    outputs=[less_output],
                                    name=node.name + "_less",
                                    **_node_attributes(compare))
            rewritten_unsqueeze = _copy_node(
                unsqueeze,
                inputs=[less_output] + list(unsqueeze.input[1:]),
                outputs=list(node.output),
                name=node.name + "_unsqueeze")
            new_nodes.extend([less, rewritten_unsqueeze])
            rewritten += 1
            continue

        # Pattern 3: double negation around two Slices -> re-slice the
        # original (un-negated) Unsqueeze output.
        slice_0, slice_1, unsqueeze = _double_not_slice_pattern(
            producer, node.input[0])
        if slice_0 is not None:
            slice_0_output = node.output[0] + "_slice0"
            rewritten_slice_0 = _copy_node(
                slice_0,
                inputs=[unsqueeze.output[0]] + list(slice_0.input[1:]),
                outputs=[slice_0_output],
                name=node.name + "_slice0")
            rewritten_slice_1 = _copy_node(
                slice_1,
                inputs=[slice_0_output] + list(slice_1.input[1:]),
                outputs=list(node.output),
                name=node.name + "_slice1")
            new_nodes.extend([rewritten_slice_0, rewritten_slice_1])
            rewritten += 1
            continue

        new_nodes.append(node)

    if rewritten:
        del model.graph.node[:]
        model.graph.node.extend(new_nodes)
        onnx.checker.check_model(model)
        onnx.save(model, onnx_path)
        print(f"Rewrote {rewritten} bool Not node(s) in {onnx_path}")
def rewrite_pulsar2_bool_and(onnx_path):
    """Replace boolean And with arithmetic comparison for Pulsar2 quantization.

    Every two-input, single-output ``And(a, b)`` becomes
    ``Greater(Cast<int32>(a) + Cast<int32>(b), 1)``. This is True exactly
    when both inputs are True. The model is checked and saved back to
    ``onnx_path`` only if at least one node was rewritten.
    """
    model = onnx.load(onnx_path)
    replacement_nodes = []
    count = 0
    for node in model.graph.node:
        is_binary_and = (node.op_type == "And" and len(node.input) == 2
                         and len(node.output) == 1)
        if not is_binary_and:
            replacement_nodes.append(node)
            continue
        out_name = node.output[0]
        lhs_i32 = out_name + "_left_i32"
        rhs_i32 = out_name + "_right_i32"
        total = out_name + "_sum"
        one = out_name + "_threshold"
        replacement_nodes.extend([
            helper.make_node("Cast",
                             inputs=[node.input[0]],
                             outputs=[lhs_i32],
                             name=node.name + "_cast_left",
                             to=onnx.TensorProto.INT32),
            helper.make_node("Cast",
                             inputs=[node.input[1]],
                             outputs=[rhs_i32],
                             name=node.name + "_cast_right",
                             to=onnx.TensorProto.INT32),
            helper.make_node("Add",
                             inputs=[lhs_i32, rhs_i32],
                             outputs=[total],
                             name=node.name + "_add"),
            _constant_node(one, np.asarray(1, dtype=np.int32),
                           node.name + "_threshold"),
            helper.make_node("Greater",
                             inputs=[total, one],
                             outputs=list(node.output),
                             name=node.name + "_greater"),
        ])
        count += 1
    if not count:
        return
    del model.graph.node[:]
    model.graph.node.extend(replacement_nodes)
    onnx.checker.check_model(model)
    onnx.save(model, onnx_path)
    print(f"Rewrote {count} bool And node(s) in {onnx_path}")
def simplify_pulsar2_onnx(onnx_path):
    """Run onnxsim on ``onnx_path`` in place.

    Raises:
        RuntimeError: if onnxsim reports that the simplified model failed
            its validation check.
    """
    simplified, check_ok = simplify(onnx.load(onnx_path))
    if not check_ok:
        raise RuntimeError(f"onnxsim failed to validate {onnx_path}")
    onnx.checker.check_model(simplified)
    onnx.save(simplified, onnx_path)
    print(f"Simplified {onnx_path} for Pulsar2")
def _eval_constant_of_shape(node, shape_value):
    """Materialize a ``ConstantOfShape`` node once its shape input is known.

    Args:
        node: the ConstantOfShape NodeProto. Its optional ``value`` attribute
            (a one-element tensor) supplies the fill value and dtype;
            without it the fill is float32 0.
        shape_value: the constant shape tensor, or None when not known.

    Returns:
        The filled NumPy array, or None when ``shape_value`` is None.
    """
    if shape_value is None:
        return None
    fill_value = np.array(0, dtype=np.float32)
    for attr in node.attribute:
        if attr.name == "value":
            fill_value = numpy_helper.to_array(attr.t)
            break
    shape = tuple(int(dim) for dim in np.asarray(shape_value).reshape(-1))
    return np.full(shape, fill_value.reshape(-1)[0], dtype=fill_value.dtype)


def fold_static_pulsar2_subgraphs(onnx_path):
    """Fold static ONNX patterns that Pulsar2 5.0 cannot infer reliably.

    Pulsar2 5.0 can fail shape inference on ConstantOfShape when its input
    is a constant tensor value instead of an initializer. The legacy exporter
    emits this pattern for masks/padding in the encoder graphs. It can also
    fail when an Expand shape is produced by a constant-only subgraph such
    as Mul/Equal/Where. Those static pieces are folded before the model is
    handed to Pulsar2.

    Nodes are visited in graph order. Each node whose single output can be
    computed from initializers, earlier constants, or statically inferred
    shapes is replaced by a Constant node with the same output name.
    Pre-existing Constant nodes are kept unchanged and not counted; their
    values are only recorded so that downstream nodes can fold. The model is
    checked and saved back to ``onnx_path`` only when something was folded.
    """
    model = onnx.load(onnx_path)
    static_shapes = _collect_static_shapes(model)
    constants = {
        initializer.name: numpy_helper.to_array(initializer)
        for initializer in model.graph.initializer
    }
    folded = 0
    new_nodes = []
    for node in model.graph.node:
        if node.op_type == "Constant":
            value = _constant_node_value(node)
            if value is not None and len(node.output) == 1:
                constants[node.output[0]] = value
            new_nodes.append(node)
            continue
        inputs = [constants.get(name) for name in node.input]
        if node.op_type == "ConstantOfShape" and node.input:
            value = _eval_constant_of_shape(node, inputs[0])
        else:
            value = _eval_static_node(node, inputs, static_shapes)
        if value is None or len(node.output) != 1:
            new_nodes.append(node)
            continue
        constants[node.output[0]] = value
        new_nodes.append(_constant_node(node.output[0], value, node.name))
        folded += 1
    if folded:
        del model.graph.node[:]
        model.graph.node.extend(new_nodes)
        onnx.checker.check_model(model)
        onnx.save(model, onnx_path)
        print(f"Folded {folded} static node(s) in {onnx_path}")
class Encoder(torch.nn.Module):
    """Non-streaming export wrapper: full-utterance encoder plus CTC top-k.

    Args:
        encoder: the encoder, called as ``encoder(speech, lengths, -1, -1)``.
        ctc: CTC module; only its ``linear`` projection is used.
        beam_size: k of the top-k taken over the CTC scores.
    """

    def __init__(self, encoder: BaseEncoder, ctc: CTC, beam_size: int = 10):
        super().__init__()
        self.encoder, self.ctc, self.beam_size = encoder, ctc, beam_size

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Encode a padded batch and take the per-frame CTC top-k.

        Args:
            speech: (Batch, Length, ...) input features.
            speech_lengths: (Batch, ) valid lengths.

        Returns:
            Tuple of
              encoder_out: B x T x F
              encoder_out_lens: B, int32, the encoder mask summed over time
              ctc_log_probs: B x T x V, raw ``ctc.linear`` scores
                  (log_softmax is not applied)
              beam_log_probs: B x T x beam_size, top-k of those scores
              beam_log_probs_idx: B x T x beam_size, their indices
        """
        # The two -1 arguments are presumably decoding_chunk_size and
        # num_decoding_left_chunks (full context) — confirm in BaseEncoder.
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1,
                                                 -1)
        # Mask assumed (B, 1, T); summing over time gives each length.
        encoder_out_lens = encoder_mask.squeeze(1).sum(1).int()
        # Raw projection: log_softmax is monotonic, so the top-k indices are
        # the same as with it; only the returned values differ.
        ctc_log_probs = self.ctc.linear(encoder_out)
        beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
                                                        self.beam_size,
                                                        dim=2)
        return (encoder_out, encoder_out_lens, ctc_log_probs, beam_log_probs,
                beam_log_probs_idx)
class StreamingEncoder(torch.nn.Module):
    """Export wrapper that runs one streaming chunk of an encoder + CTC.

    ``forward`` inlines the encoder's per-chunk computation (between the
    forward_chunk START/END markers) with every cache passed in and returned
    as an explicit batch tensor, so one call depends only on its inputs and
    can be exported to ONNX.
    """

    def __init__(
        self,
        model,
        required_cache_size,
        beam_size,
        transformer=False,
        return_ctc_logprobs=False,
    ):
        """
        Args:
            model: ASR model; ``model.encoder`` (embed, global_cmvn,
                encoders, normalize_before, after_norm) and ``model.ctc``
                (linear) are used.
            required_cache_size (int): number of trailing frames of attention
                cache and cache mask returned for the next chunk.
            beam_size (int): k of the top-k taken over the CTC scores.
            transformer (bool): when True, no cnn cache is collected
                (presumably for encoder layers without a conv module —
                TODO confirm).
            return_ctc_logprobs (bool): when True, forward returns the full
                CTC scores instead of their top-k values and indices.
        """
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder
        self.transformer = transformer
        self.return_ctc_logprobs = return_ctc_logprobs

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
                cache_mask):
        """Streaming Encoder: run one chunk.

        Args:
            chunk_xs (torch.Tensor): chunk input, with shape
                (b, time, mel-dim), where `time == (chunk_size - 1) *
                subsample_rate + subsample.right_context + 1`.
            chunk_lens (torch.Tensor): (b,) number of valid frames of each
                request in ``chunk_xs``; used to build the padding mask.
            offset (torch.Tensor): offset with shape (b, 1);
                the trailing 1 is retained for Triton deployment.
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in
                conformer, (b, elayers, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`.
            cache_mask (torch.Tensor): mask of the valid cache frames,
                (b, 1, required_cache_size); it is concatenated with the
                (b, 1, T) chunk mask on dim 2. Within a batch each request
                may have a different amount of history; this mask marks the
                effective cache of each request.

        Uses ``self.required_cache_size`` (> 0) as the number of trailing
        frames of cache kept for the next chunk.

        Returns:
            When ``return_ctc_logprobs`` is False, the tuple:
            log_probs (torch.Tensor): top-k CTC scores cut off by beam size,
                (b, chunk_size, beam).
            log_probs_idx (torch.Tensor): indices of those top-k scores,
                (b, chunk_size, beam).
            chunk_out (torch.Tensor): encoder output of this chunk,
                (b, chunk_size, hidden-dim).
            chunk_out_lens (torch.Tensor): chunk_lens // subsampling_rate.
            r_offset (torch.Tensor): offset + chunk_out.size(1), (b, 1).
            r_att_cache (torch.Tensor): new attention cache for the next
                chunk, (b, elayers, head, cache frames, d_k * 2).
            r_cnn_cache (torch.Tensor): new cnn cache for the next chunk,
                same layout as ``cnn_cache``.
            r_cache_mask (torch.Tensor): new cache mask, same layout as
                ``cache_mask``.
            When ``return_ctc_logprobs`` is True, the first two entries are
            replaced by the full CTC scores ``log_ctc_probs``.
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
        # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # chunk mask is important for batch inferencing since
        # different sequence in a batch has different length
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        cache_size = att_cache.size(3)  # required cache size
        # Attention mask over [cached frames, current chunk frames].
        masks = torch.cat((cache_mask, chunk_mask), dim=2)
        # Positional encoding starts cache_size frames before the chunk so
        # it also covers the cached frames.
        index = offset - cache_size

        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        # Negative start: keep the last required_cache_size frames.
        next_cache_start = -self.required_cache_size
        r_cache_mask = masks[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoder.encoders):
            i_kv_cache = att_cache[i]
            # The last dim stores KEY and VALUE concatenated; split in half.
            size = att_cache.size(-1) // 2
            kv_cache = (i_kv_cache[:, :, :, :size], i_kv_cache[:, :, :, size:])
            xs, _, new_kv_cache, new_cnn_cache = layer(
                xs,
                masks,
                pos_emb,
                att_cache=kv_cache,
                cnn_cache=cnn_cache[i],
            )
            # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
            # shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
            new_att_cache = torch.cat(new_kv_cache, dim=-1)
            r_att_cache.append(
                new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
            if not self.transformer:
                r_cnn_cache.append(new_cnn_cache.unsqueeze(1))
        if self.encoder.normalize_before:
            chunk_out = self.encoder.after_norm(xs)
        else:
            chunk_out = xs

        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        # NOTE(review): when self.transformer is True, r_cnn_cache stays the
        # empty Python list [] and is returned as-is; confirm this exports.
        if not self.transformer:
            r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers

        # <---------forward_chunk END--------->

        # log_softmax is deliberately skipped: raw CTC scores are returned,
        # and top-k indices are unaffected by a monotonic log_softmax.
        # log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_ctc_probs = self.ctc.linear(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs,
                                              self.beam_size,
                                              dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        r_offset = offset + chunk_out.shape[1]
        # the below ops not supported in Tensorrt
        # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
        #                   rounding_mode='floor')
        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)
        if self.return_ctc_logprobs:
            return (
                log_ctc_probs,
                chunk_out,
                chunk_out_lens,
                r_offset,
                r_att_cache,
                r_cnn_cache,
                r_cache_mask,
            )
        else:
            return (
                log_probs,
                log_probs_idx,
                chunk_out,
                chunk_out_lens,
                r_offset,
                r_att_cache,
                r_cnn_cache,
                r_cache_mask,
            )
class StreamingSqueezeformerEncoder(torch.nn.Module):
    """Streaming export wrapper for a Squeezeformer encoder + CTC.

    Same chunk-step interface as the conformer streaming wrapper. It also
    handles the encoder's time reduction at ``reduce_idx`` layers and time
    recovery at ``recover_idx`` layers: caches are read at each layer's
    reduced time resolution and written back repeated to the full one.
    """

    def __init__(self, model, required_cache_size, beam_size):
        """
        Args:
            model: ASR model; ``model.encoder`` (embed, global_cmvn,
                encoders, reduce_idx, recover_idx, preln,
                time_reduction_layer, time_recover_layer,
                pos_enc_layer_type) and ``model.ctc`` (linear) are used.
            required_cache_size (int): number of trailing cache frames
                returned for the next chunk.
            beam_size (int): k of the top-k taken over the CTC scores.
        """
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder
        self.reduce_idx = model.encoder.reduce_idx
        self.recover_idx = model.encoder.recover_idx
        # time_reduce: None (no reduction), "normal" (reduce only) or
        # "recover" (every reduction has a matching recovery).
        if self.reduce_idx is None:
            self.time_reduce = None
        else:
            if self.recover_idx is None:
                self.time_reduce = "normal"  # no recovery at the end
            else:
                self.time_reduce = "recover"  # recovery at the end
                assert len(self.reduce_idx) == len(self.recover_idx)

    def calculate_downsampling_factor(self, i: int) -> int:
        """Return the time downsampling factor in effect at layer ``i``.

        Equal to 2 ** (number of reduce_idx entries <= i minus number of
        recover_idx entries <= i); 1 when there is no time reduction.
        """
        if self.reduce_idx is None:
            return 1
        else:
            reduce_exp, recover_exp = 0, 0
            for exp, rd_idx in enumerate(self.reduce_idx):
                if i >= rd_idx:
                    reduce_exp = exp + 1
            if self.recover_idx is not None:
                for exp, rc_idx in enumerate(self.recover_idx):
                    if i >= rc_idx:
                        recover_exp = exp + 1
            return int(2**(reduce_exp - recover_exp))

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
                cache_mask):
        """Streaming Encoder: run one chunk.

        Args:
            chunk_xs (torch.Tensor): chunk input, with shape
                (b, time, mel-dim), where `time == (chunk_size - 1) *
                subsample_rate + subsample.right_context + 1`.
            chunk_lens (torch.Tensor): (b,) number of valid frames of each
                request in ``chunk_xs``; used to build the padding mask.
            offset (torch.Tensor): offset with shape (b, 1);
                the trailing 1 is retained for Triton deployment.
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module,
                (b, elayers, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`.
            cache_mask (torch.Tensor): mask of the valid cache frames,
                (b, 1, required_cache_size); it is concatenated with the
                (b, 1, T) chunk mask on dim 2. Within a batch each request
                may have a different amount of history.

        Uses ``self.required_cache_size`` (> 0) as the number of trailing
        frames of cache kept for the next chunk.

        Returns:
            Tuple of
            log_probs (torch.Tensor): top-k CTC scores, (b, chunk_size, beam).
            log_probs_idx (torch.Tensor): their indices, (b, chunk_size, beam).
            chunk_out (torch.Tensor): encoder output of this chunk,
                (b, chunk_size, hidden-dim).
            chunk_out_lens (torch.Tensor): chunk_lens // subsampling_rate.
            r_offset (torch.Tensor): offset + chunk_out.size(1), (b, 1).
            r_att_cache (torch.Tensor): new attention cache, at most
                max_att_len frames per layer (the length produced by layer 0).
            r_cnn_cache (torch.Tensor): new cnn cache, layers on dim 1.
            r_cache_mask (torch.Tensor): new cache mask, same layout as
                ``cache_mask``.
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)
        # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # chunk mask is important for batch inferencing since
        # different sequence in a batch has different length
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        elayers, cache_size = att_cache.size(0), att_cache.size(3)
        # Attention mask over [cached frames, current chunk frames].
        att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
        # Positional encoding starts cache_size frames before the chunk.
        index = offset - cache_size

        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        # Negative start: keep the last required_cache_size frames.
        next_cache_start = -self.required_cache_size
        r_cache_mask = att_mask[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        # All-True pad mask with batch dim 1, shape (1, 1, xs.size(1)).
        mask_pad = torch.ones(1,
                              xs.size(1),
                              device=xs.device,
                              dtype=torch.bool)
        mask_pad = mask_pad.unsqueeze(1)
        max_att_len: int = 0
        # Stack of (xs, att_mask, pos_emb, mask_pad) saved before each time
        # reduction and restored at the matching recovery layer.
        recover_activations: List[Tuple[torch.Tensor, torch.Tensor,
                                        torch.Tensor, torch.Tensor]] = []
        # NOTE: ``index`` is reused from here on as the depth of
        # recover_activations (the offset-based value above is consumed).
        index = 0
        xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
        xs = self.encoder.preln(xs)
        for i, layer in enumerate(self.encoder.encoders):
            if self.reduce_idx is not None:
                if self.time_reduce is not None and i in self.reduce_idx:
                    recover_activations.append(
                        (xs, att_mask, pos_emb, mask_pad))
                    (
                        xs,
                        xs_lens,
                        att_mask,
                        mask_pad,
                    ) = self.encoder.time_reduction_layer(
                        xs, xs_lens, att_mask, mask_pad)
                    # Keep every 2nd position, presumably matching the
                    # time reduction above — TODO confirm the reduction
                    # factor is always 2.
                    pos_emb = pos_emb[:, ::2, :]
                    if self.encoder.pos_enc_layer_type == "rel_pos_repaired":
                        pos_emb = pos_emb[:, :xs.size(1) * 2 - 1, :]
                    index += 1

            if self.recover_idx is not None:
                if self.time_reduce == "recover" and i in self.recover_idx:
                    index -= 1
                    (
                        recover_tensor,
                        recover_att_mask,
                        recover_pos_emb,
                        recover_mask_pad,
                    ) = recover_activations[index]
                    # recover output length for ctc decode
                    # (repeat each frame twice along time, then trim).
                    xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
                    xs = self.encoder.time_recover_layer(xs)
                    recoverd_t = recover_tensor.size(1)
                    xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
                    att_mask = recover_att_mask
                    pos_emb = recover_pos_emb
                    mask_pad = recover_mask_pad

            factor = self.calculate_downsampling_factor(i)

            # Read this layer's cache at its reduced time resolution (every
            # factor-th frame), truncated so that cache + chunk frames match
            # pos_emb.size(1).
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=att_cache[i][:, :, ::factor, :]
                [:, :, :pos_emb.size(1) - xs.size(1), :]
                if elayers > 0 else att_cache[:, :, ::factor, :],
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
            )
            cached_att = new_att_cache[:, :, next_cache_start // factor:, :]
            cached_cnn = new_cnn_cache.unsqueeze(1)
            # Repeat each cached frame ``factor`` times along time to store
            # the cache at full resolution.
            cached_att = (cached_att.unsqueeze(3).repeat(1, 1, 1, factor,
                                                         1).flatten(2, 3))
            if i == 0:
                # record length for the first block as max length
                max_att_len = cached_att.size(2)
            r_att_cache.append(cached_att[:, :, :max_att_len, :].unsqueeze(1))
            r_cnn_cache.append(cached_cnn)

        chunk_out = xs
        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers

        # <---------forward_chunk END--------->

        # log_softmax is deliberately skipped: raw CTC scores are returned.
        # log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_ctc_probs = self.ctc.linear(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs,
                                              self.beam_size,
                                              dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        r_offset = offset + chunk_out.shape[1]
        # the below ops not supported in Tensorrt
        # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
        #                   rounding_mode='floor')
        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)

        return (
            log_probs,
            log_probs_idx,
            chunk_out,
            chunk_out_lens,
            r_offset,
            r_att_cache,
            r_cnn_cache,
            r_cache_mask,
        )
class StreamingEfficientConformerEncoder(torch.nn.Module):
    """Streaming export wrapper for the Efficient Conformer encoder + CTC.

    Same chunk-step interface as the conformer streaming wrapper, extended
    for encoders that stride (downsample in time) after the layers listed
    in ``stride_layer_idx``. Caches are read at each layer's reduced time
    resolution and written back repeated to the full resolution.
    """

    def __init__(self, model, required_cache_size, beam_size):
        """
        Args:
            model: ASR model; ``model.encoder`` (embed, global_cmvn,
                encoders, stride_layer_idx, stride, num_blocks,
                cnn_module_kernel, normalize_before, after_norm) and
                ``model.ctc`` (linear) are used.
            required_cache_size (int): number of trailing cache frames
                returned for the next chunk.
            beam_size (int): k of the top-k taken over the CTC scores.
        """
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder

        # Efficient Conformer
        self.stride_layer_idx = model.encoder.stride_layer_idx
        self.stride = model.encoder.stride
        self.num_blocks = model.encoder.num_blocks
        self.cnn_module_kernel = model.encoder.cnn_module_kernel

    def calculate_downsampling_factor(self, i: int) -> int:
        """Return the cumulative time stride seen by layer ``i``.

        This is the product of ``stride[k]`` over every stride layer
        ``stride_layer_idx[k]`` strictly before ``i``.
        """
        factor = 1
        for idx, stride_idx in enumerate(self.stride_layer_idx):
            if i > stride_idx:
                factor *= self.stride[idx]
        return factor

    def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
                cache_mask):
        """Streaming Encoder: run one chunk.

        Args:
            chunk_xs (torch.Tensor): chunk input, with shape
                (b, time, mel-dim), where `time == (chunk_size - 1) *
                subsample_rate + subsample.right_context + 1`.
            chunk_lens (torch.Tensor): (b,) number of valid frames of each
                request in ``chunk_xs``; used to build the padding mask.
            offset (torch.Tensor): offset with shape (b, 1);
                the trailing 1 is retained for Triton deployment. It is not
                modified: scaling happens on a new tensor.
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in
                conformer, (b, elayers, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`.
            cache_mask (torch.Tensor): mask of the valid cache frames,
                (b, 1, required_cache_size); it is concatenated with the
                (b, 1, T) chunk mask on dim 2. Within a batch each request
                may have a different amount of history.

        Returns:
            Tuple of
            log_probs (torch.Tensor): top-k CTC scores, (b, chunk_size, beam).
            log_probs_idx (torch.Tensor): their indices, (b, chunk_size, beam).
            chunk_out (torch.Tensor): encoder output of this chunk,
                (b, chunk_size, hidden-dim).
            chunk_out_lens (torch.Tensor): chunk_lens // subsampling_rate //
                total stride, plus 1.
            r_offset (torch.Tensor): scaled offset + chunk_out.size(1),
                (b, 1).
            r_att_cache (torch.Tensor): new attention cache, at most
                max_att_len frames per layer (the length produced by layer 0).
            r_cnn_cache (torch.Tensor): new cnn cache, left-padded to
                cnn_module_kernel - 1 frames.
            r_cache_mask (torch.Tensor): new cache mask, same layout as
                ``cache_mask``.
        """
        offset = offset.squeeze(1)  # (b, )
        # squeeze() returns a view of the caller's tensor, so scale out of
        # place; an in-place ``*=`` would also multiply the caller's offset.
        # NOTE(review): r_offset below mixes this scaled offset with
        # chunk_out frames; confirm the intended units of the offset.
        offset = offset * self.calculate_downsampling_factor(
            self.num_blocks + 1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)  # (b, 1, T)
        # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        #   Shape(att_cache): (elayers, b, head, cache_t1, d_k * 2)
        #   Shape(cnn_cache): (elayers, b, outsize, cnn_kernel)
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # chunk mask is important for batch inferencing since
        # different sequence in a batch has different length
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        cache_size = att_cache.size(3)  # required cache size
        # ``masks`` and ``att_mask`` start identical; ``att_mask`` is
        # strided per layer below, while ``masks`` feeds r_cache_mask.
        masks = torch.cat((cache_mask, chunk_mask), dim=2)
        att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size

        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        # Negative start: keep the last required_cache_size frames.
        next_cache_start = -self.required_cache_size
        r_cache_mask = masks[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        mask_pad = chunk_mask.to(torch.bool)
        max_att_len, max_cnn_len = (
            0,
            0,
        )  # for repeat_interleave of new_att_cache
        for i, layer in enumerate(self.encoder.encoders):
            factor = self.calculate_downsampling_factor(i)
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (b, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            # shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ]
            att_cache_trunc = 0
            if xs.size(1) + att_cache.size(3) / factor > pos_emb.size(1):
                # The time step is not divisible by the downsampling multiple
                # We propose to double the chunk_size.
                att_cache_trunc = (xs.size(1) + att_cache.size(3) // factor -
                                   pos_emb.size(1) + 1)
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                mask_pad=mask_pad,
                att_cache=att_cache[i][:, :, ::factor, :][:, :,
                                                          att_cache_trunc:, :],
                cnn_cache=cnn_cache[i, :, :, :]
                if cnn_cache.size(0) > 0 else cnn_cache,
            )

            if i in self.stride_layer_idx:
                # compute time dimension for next block
                efficient_index = self.stride_layer_idx.index(i)
                att_mask = att_mask[:, ::self.stride[efficient_index], ::self.
                                    stride[efficient_index], ]
                mask_pad = mask_pad[:, ::self.stride[efficient_index], ::self.
                                    stride[efficient_index], ]
                pos_emb = pos_emb[:, ::self.stride[efficient_index], :]

            # shape(new_att_cache) = [batch, head, time2, outdim]
            new_att_cache = new_att_cache[:, :, next_cache_start // factor:, :]
            # shape(new_cnn_cache) = [batch, 1, outdim, cache_t2]
            new_cnn_cache = new_cnn_cache.unsqueeze(1)  # shape(1):layerID

            # use repeat_interleave to new_att_cache
            # new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2)
            new_att_cache = (new_att_cache.unsqueeze(3).repeat(
                1, 1, 1, factor, 1).flatten(2, 3))
            # padding new_cnn_cache to cnn.lorder for casual convolution
            new_cnn_cache = F.pad(
                new_cnn_cache,
                (self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0),
            )

            if i == 0:
                # record length for the first block as max length
                max_att_len = new_att_cache.size(2)
                max_cnn_len = new_cnn_cache.size(3)

            # update real shape of att_cache and cnn_cache
            r_att_cache.append(new_att_cache[:, :,
                                             -max_att_len:, :].unsqueeze(1))
            r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:])

        if self.encoder.normalize_before:
            chunk_out = self.encoder.after_norm(xs)
        else:
            chunk_out = xs

        # shape of r_att_cache: (b, elayers, head, time2, outdim)
        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        # shape of r_cnn_cache: (b, elayers, outdim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers

        # <---------forward_chunk END--------->

        # log_softmax is deliberately skipped: raw CTC scores are returned.
        log_ctc_probs = self.ctc.linear(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs,
                                              self.beam_size,
                                              dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        r_offset = offset + chunk_out.shape[1]
        # the below ops not supported in Tensorrt
        # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
        #                   rounding_mode='floor')
        chunk_out_lens = (
            chunk_lens // self.subsampling_rate //
            self.calculate_downsampling_factor(self.num_blocks + 1))
        chunk_out_lens += 1
        r_offset = r_offset.unsqueeze(1)

        return (
            log_probs,
            log_probs_idx,
            chunk_out,
            chunk_out_lens,
            r_offset,
            r_att_cache,
            r_cnn_cache,
            r_cache_mask,
        )
class Decoder(torch.nn.Module):
    """Attention-rescoring wrapper around a ``TransformerDecoder`` for export.

    Scores ``beam_size`` hypotheses per utterance with the attention decoder
    (optionally mixed with the right-to-left decoder via ``reverse_weight``),
    adds ``ctc_weight`` times the given CTC scores and selects the best
    hypothesis per utterance.
    """

    def __init__(
        self,
        decoder: TransformerDecoder,
        ctc_weight: float = 0.5,
        reverse_weight: float = 0.0,
        beam_size: int = 10,
        decoder_fastertransformer: bool = False,
    ):
        super().__init__()
        self.decoder = decoder
        self.ctc_weight = ctc_weight
        self.reverse_weight = reverse_weight
        self.beam_size = beam_size
        # When True, forward() also returns decoder_out alongside best_index
        # (see the --decoder_fastertransformer command line flag).
        self.decoder_fastertransformer = decoder_fastertransformer

    def forward(
        self,
        encoder_out: torch.Tensor,
        encoder_lens: torch.Tensor,
        hyps_pad_sos_eos: torch.Tensor,
        hyps_lens_sos: torch.Tensor,
        r_hyps_pad_sos_eos: torch.Tensor,
        ctc_score: torch.Tensor,
    ):
        """Rescore the beam hypotheses and pick the best one per utterance.

        Args:
            encoder_out: B x T x D encoder output.
            encoder_lens: B, number of valid frames per utterance.
            hyps_pad_sos_eos: B x beam x (T2+1), hyps with sos & eos, padded
                by ignore id.
            hyps_lens_sos: B x beam, length of each hyp including sos.
            r_hyps_pad_sos_eos: B x beam x (T2+1), reversed hyps with sos &
                eos, padded by ignore id.
            ctc_score: B x beam, CTC score of each hyp.

        Returns:
            ``(decoder_out, best_index)`` when ``decoder_fastertransformer``
            is set, where decoder_out is B x beam x T2 x V and best_index is
            B; otherwise ``best_index`` alone.
        """
        # ``D`` (not ``F``) so torch.nn.functional, imported as F at module
        # level, is not shadowed.
        B, T, D = encoder_out.shape
        bz = self.beam_size
        B2 = B * bz
        # Give every hypothesis its own copy of the encoder output and mask:
        # rows b * bz .. b * bz + bz - 1 all hold utterance b.
        encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, D)
        encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
        encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
        T2 = hyps_pad_sos_eos.shape[2] - 1
        hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1)
        hyps_lens = hyps_lens_sos.view(B2)
        # Decoder input drops the last token; targets drop the leading sos.
        hyps_pad_sos = hyps_pad[:, :-1].contiguous()
        hyps_pad_eos = hyps_pad[:, 1:].contiguous()
        r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1)
        r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous()
        r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous()
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out,
            encoder_mask,
            hyps_pad_sos,
            hyps_lens,
            r_hyps_pad_sos,
            self.reverse_weight,
        )
        # NOTE(review): log_softmax is disabled here, so scores are gathered
        # from the raw decoder output -- confirm this matches the runtime.
        # decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        V = decoder_out.shape[-1]
        decoder_out = decoder_out.view(B2, T2, V)
        mask = ~make_pad_mask(hyps_lens, T2)  # B2 x T2, True on valid tokens
        # Padded targets (ignore id) become index 0 so gather stays in range;
        # their scores are zeroed by the mask.
        index = torch.unsqueeze(hyps_pad_eos * mask, 2).to(torch.long)
        score = decoder_out.gather(2, index).squeeze(2) * mask  # B2 x T2
        decoder_out = decoder_out.view(B, bz, T2, V)
        if self.reverse_weight > 0:
            r_decoder_out = r_decoder_out.view(B2, T2, V)
            r_index = torch.unsqueeze(r_hyps_pad_eos * mask, 2).to(torch.long)
            r_score = r_decoder_out.gather(2, r_index).squeeze(2) * mask
            score = (score * (1 - self.reverse_weight) +
                     self.reverse_weight * r_score)
        score = torch.sum(score, dim=1)  # B2
        score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score
        best_index = torch.argmax(score, dim=1)
        if self.decoder_fastertransformer:
            return decoder_out, best_index
        return best_index
def to_numpy(tensors):
    """Convert an iterable of torch tensors to a list of numpy arrays.

    Each tensor is moved to CPU; tensors that require grad are detached
    first.

    Args:
        tensors: iterable of ``torch.Tensor`` (callers pass lists/tuples of
            model inputs or outputs).

    Returns:
        list of ``numpy.ndarray``, one per element yielded by ``tensors``.
    """
    out = []
    # NOTE(review): ``torch.tensor`` is the tensor factory *function*, not the
    # ``torch.Tensor`` class, so this comparison never matches and a bare
    # tensor is NOT wrapped -- it is iterated along dim 0 instead, giving one
    # array per row. Switching to ``isinstance(tensors, torch.Tensor)``
    # changes the result for any caller that passes a single tensor, so such
    # call sites must be reviewed together with that change.
    if type(tensors) == torch.tensor:
        tensors = [tensors]
    for tensor in tensors:
        # numpy() refuses tensors that require grad, hence the detach().
        if tensor.requires_grad:
            tensor = tensor.detach().cpu().numpy()
        else:
            tensor = tensor.cpu().numpy()
        out.append(tensor)
    return out
def test(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
for a, b in zip(xlist, blist):
try:
torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
except AssertionError as error:
if tolerate_small_mismatch:
print(error)
else:
raise
def export_offline_encoder(model, configs, args, logger, encoder_onnx_path):
    """Export the non-streaming encoder to ONNX and check it against PyTorch.

    The ``Encoder`` wrapper is traced with static shapes (batch 1, 1024
    frames). The saved file is then processed by fold_static_pulsar2_subgraphs,
    simplify_pulsar2_onnx and rewrite_pulsar2_bool_not. Finally onnxruntime is
    run on the same random inputs and its outputs are compared with PyTorch.

    Args:
        model: ASR model; ``model.encoder`` and ``model.ctc`` are exported.
        configs: model config dict (reads ``input_dim`` and ``model_conf``).
        args: parsed command line arguments (reads ``beam_size``).
        logger: logger for the success message.
        encoder_onnx_path: output path of the ONNX encoder.

    Returns:
        dict with ``beam_size``, ``reverse_weight`` and ``ctc_weight``.
    """
    bz = 1
    seq_len = 1024
    beam_size = args.beam_size
    feature_size = configs["input_dim"]
    speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
    speech_lens = torch.randint(low=10,
                                high=seq_len,
                                size=(bz, ),
                                dtype=torch.int32)
    encoder = Encoder(model.encoder, model.ctc, beam_size)
    encoder.eval()
    torch.onnx.export(
        encoder,
        (speech, speech_lens),
        encoder_onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=["speech", "speech_lengths"],
        output_names=[
            "encoder_out",
            "encoder_out_lens",
            "ctc_log_probs",
            "beam_log_probs",
            "beam_log_probs_idx",
        ],
        # Static shapes; re-enable the mapping below for dynamic B / T.
        dynamic_axes=None,
        # dynamic_axes={
        #     "speech": {
        #         0: "B",
        #         1: "T"
        #     },
        #     "speech_lengths": {
        #         0: "B"
        #     },
        #     "encoder_out": {
        #         0: "B",
        #         1: "T_OUT"
        #     },
        #     "encoder_out_lens": {
        #         0: "B"
        #     },
        #     "ctc_log_probs": {
        #         0: "B",
        #         1: "T_OUT"
        #     },
        #     "beam_log_probs": {
        #         0: "B",
        #         1: "T_OUT"
        #     },
        #     "beam_log_probs_idx": {
        #         0: "B",
        #         1: "T_OUT"
        #     },
        # },
        verbose=False,
        dynamo=False,
    )
    fold_static_pulsar2_subgraphs(encoder_onnx_path)
    simplify_pulsar2_onnx(encoder_onnx_path)
    rewrite_pulsar2_bool_not(encoder_onnx_path)
    with torch.no_grad():
        o0, o1, o2, o3, o4 = encoder(speech, speech_lens)
    providers = ["CPUExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(encoder_onnx_path,
                                               providers=providers)
    # Feed plain numpy arrays with exactly the traced shapes. Feeding the
    # list returned by to_numpy() for a single tensor would leave the final
    # shape up to onnxruntime's list-to-tensor conversion.
    ort_inputs = {
        "speech": speech.numpy(),
        "speech_lengths": speech_lens.numpy(),
    }
    ort_outs = ort_session.run(None, ort_inputs)
    # check encoder output
    test(to_numpy([o0, o1, o2, o3, o4]), ort_outs)
    logger.info("export offline onnx encoder succeed!")
    onnx_config = {
        "beam_size": args.beam_size,
        "reverse_weight": configs["model_conf"]["reverse_weight"],
        "ctc_weight": configs["model_conf"]["ctc_weight"],
    }
    return onnx_config
def export_online_encoder(model, configs, args, logger, encoder_onnx_path):
    """Export the chunk-wise streaming encoder to ONNX and check it.

    One chunk of ``decoding_window`` input frames, plus attention/conv caches
    sized for ``num_decoding_left_chunks`` left chunks, is traced with static
    shapes. The saved file is processed by the pulsar2 helpers and verified
    with onnxruntime against the PyTorch outputs.

    Args:
        model: ASR model wrapped by one of the streaming encoder classes.
        configs: model config dict (reads ``input_dim``, ``encoder`` and
            ``encoder_conf``).
        args: parsed command line arguments (decoding_chunk_size,
            num_decoding_left_chunks, beam_size, return_ctc_logprobs).
        logger: logger for the success message.
        encoder_onnx_path: output path of the ONNX encoder.

    Returns:
        dict of streaming settings for the runtime config.
    """
    decoding_chunk_size = args.decoding_chunk_size
    subsampling = model.encoder.embed.subsampling_rate
    context = model.encoder.embed.right_context + 1
    # Number of input frames fed per chunk (the length of chunk_xs).
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    batch_size = 1
    audio_len = decoding_window
    feature_size = configs["input_dim"]
    output_size = configs["encoder_conf"]["output_size"]
    num_layers = configs["encoder_conf"]["num_blocks"]
    # in transformer the cnn module will not be available, so the cnn cache
    # length is 0
    cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1
    transformer = not cnn_module_kernel
    num_decoding_left_chunks = args.num_decoding_left_chunks
    required_cache_size = decoding_chunk_size * num_decoding_left_chunks
    if configs["encoder"] == "squeezeformer":
        encoder = StreamingSqueezeformerEncoder(model, required_cache_size,
                                                args.beam_size)
    elif configs["encoder"] == "efficientConformer":
        encoder = StreamingEfficientConformerEncoder(model,
                                                     required_cache_size,
                                                     args.beam_size)
    else:
        encoder = StreamingEncoder(
            model,
            required_cache_size,
            args.beam_size,
            transformer,
            args.return_ctc_logprobs,
        )
    encoder.eval()
    # begin to export encoder
    chunk_xs = torch.randn(batch_size,
                           audio_len,
                           feature_size,
                           dtype=torch.float32)
    chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len
    offset = torch.arange(0, batch_size, dtype=torch.int32).unsqueeze(1)
    # (b, elayers, head, cache_t1, d_k * 2)
    head = configs["encoder_conf"]["attention_heads"]
    d_k = configs["encoder_conf"]["output_size"] // head
    att_cache = torch.randn(
        batch_size,
        num_layers,
        head,
        required_cache_size,
        d_k * 2,
        dtype=torch.float32,
    )
    cnn_cache = torch.randn(
        batch_size,
        num_layers,
        output_size,
        cnn_module_kernel,
        dtype=torch.float32,
    )
    cache_mask = torch.ones(batch_size,
                            1,
                            required_cache_size,
                            dtype=torch.float32)
    input_names = [
        "chunk_xs",
        "chunk_lens",
        "offset",
        "att_cache",
        "cnn_cache",
        "cache_mask",
    ]
    output_names = [
        "log_probs",
        "log_probs_idx",
        "chunk_out",
        "chunk_out_lens",
        "r_offset",
        "r_att_cache",
        "r_cnn_cache",
        "r_cache_mask",
    ]
    if args.return_ctc_logprobs:
        output_names = [
            "ctc_log_probs",
            "chunk_out",
            "chunk_out_lens",
            "r_offset",
            "r_att_cache",
            "r_cnn_cache",
            "r_cache_mask",
        ]
    input_tensors = (
        chunk_xs,
        chunk_lens,
        offset,
        att_cache,
        cnn_cache,
        cache_mask,
    )
    if transformer:
        assert (args.return_ctc_logprobs is
                False), "return_ctc_logprobs is not supported in transformer"
        # drop "r_cnn_cache" from the exported output names
        output_names.pop(6)
    # Only used if the commented dynamic_axes argument below is re-enabled:
    # the first dimension is dynamic, all other dimensions are fixed.
    all_names = input_names + output_names
    dynamic_axes = {}
    for name in all_names:
        dynamic_axes[name] = {0: "B"}
    torch.onnx.export(
        encoder,
        input_tensors,
        encoder_onnx_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        # dynamic_axes=dynamic_axes,
        dynamic_axes=None,
        verbose=False,
        dynamo=False,
    )
    fold_static_pulsar2_subgraphs(encoder_onnx_path)
    simplify_pulsar2_onnx(encoder_onnx_path)
    rewrite_pulsar2_bool_not(encoder_onnx_path)
    with torch.no_grad():
        torch_outs = list(
            encoder(chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
                    cache_mask))
    if transformer:
        # Mirror output_names.pop(6): remove r_cnn_cache from the reference
        # outputs and keep the rest (pop()'s return value is discarded).
        torch_outs.pop(6)
    ort_session = onnxruntime.InferenceSession(
        encoder_onnx_path, providers=["CPUExecutionProvider"])
    # Feed only the inputs the exported graph still declares; unused inputs
    # (e.g. cnn_cache for transformer encoders) are pruned from the graph.
    graph_inputs = {node.name for node in ort_session.get_inputs()}
    ort_inputs = {
        name: value
        for name, value in zip(input_names, to_numpy(input_tensors))
        if name in graph_inputs
    }
    ort_outs = ort_session.run(None, ort_inputs)
    test(to_numpy(torch_outs), ort_outs, rtol=1e-03, atol=1e-05)
    logger.info("export to onnx streaming encoder succeed!")
    onnx_config = {
        "subsampling_rate": subsampling,
        "context": context,
        "decoding_chunk_size": decoding_chunk_size,
        "num_decoding_left_chunks": num_decoding_left_chunks,
        "beam_size": args.beam_size,
        "feat_size": feature_size,
        "decoding_window": decoding_window,
        "cnn_module_kernel_cache": cnn_module_kernel,
        "return_ctc_logprobs": args.return_ctc_logprobs,
    }
    return onnx_config
def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
                             decoder_fastertransformer):
    """Export the attention-rescoring decoder to ONNX and check it.

    The ``Decoder`` wrapper is traced with static shapes (batch 1,
    ``beam_size`` hypotheses of 32 tokens, 32 encoder frames). The saved
    file is processed by the pulsar2 helpers, and onnxruntime is run on the
    same random inputs to compare against PyTorch.

    Args:
        model: ASR model providing ``decoder``, ``ctc_weight`` and
            ``reverse_weight``.
        configs: model config dict (reads ``encoder_conf.output_size``).
        args: parsed command line arguments (reads ``beam_size``).
        logger: logger for the success message.
        decoder_onnx_path: output path of the ONNX decoder.
        decoder_fastertransformer: also export ``decoder_out`` as an output.
    """
    bz, seq_len = 1, 32
    beam_size = args.beam_size
    decoder = Decoder(
        model.decoder,
        model.ctc_weight,
        model.reverse_weight,
        beam_size,
        decoder_fastertransformer,
    )
    decoder.eval()
    hyps_pad_sos_eos = torch.randint(low=3,
                                     high=1000,
                                     size=(bz, beam_size, seq_len),
                                     dtype=torch.int32)
    hyps_lens_sos = torch.randint(low=3,
                                  high=seq_len,
                                  size=(bz, beam_size),
                                  dtype=torch.int32)
    r_hyps_pad_sos_eos = torch.randint(low=3,
                                       high=1000,
                                       size=(bz, beam_size, seq_len),
                                       dtype=torch.int32)
    output_size = configs["encoder_conf"]["output_size"]
    encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32)
    encoder_out_lens = torch.randint(low=3,
                                     high=seq_len,
                                     size=(bz, ),
                                     dtype=torch.int32)
    ctc_score = torch.randn(bz, beam_size, dtype=torch.float32)
    input_names = [
        "encoder_out",
        "encoder_out_lens",
        "hyps_pad_sos_eos",
        "hyps_lens_sos",
        "r_hyps_pad_sos_eos",
        "ctc_score",
    ]
    output_names = ["best_index"]
    if decoder_fastertransformer:
        output_names.insert(0, "decoder_out")
    input_tensors = [
        encoder_out,
        encoder_out_lens,
        hyps_pad_sos_eos,
        hyps_lens_sos,
        r_hyps_pad_sos_eos,
        ctc_score,
    ]
    torch.onnx.export(
        decoder,
        tuple(input_tensors),
        decoder_onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        # Static shapes; re-enable the mapping below for dynamic B / T / T2.
        dynamic_axes=None,
        # dynamic_axes={
        #     "encoder_out": {
        #         0: "B",
        #         1: "T"
        #     },
        #     "encoder_out_lens": {
        #         0: "B"
        #     },
        #     "hyps_pad_sos_eos": {
        #         0: "B",
        #         2: "T2"
        #     },
        #     "hyps_lens_sos": {
        #         0: "B"
        #     },
        #     "r_hyps_pad_sos_eos": {
        #         0: "B",
        #         2: "T2"
        #     },
        #     "ctc_score": {
        #         0: "B"
        #     },
        #     "best_index": {
        #         0: "B"
        #     },
        # },
        verbose=False,
        dynamo=False,
    )
    fold_static_pulsar2_subgraphs(decoder_onnx_path)
    simplify_pulsar2_onnx(decoder_onnx_path)
    rewrite_pulsar2_bool_not(decoder_onnx_path)
    rewrite_pulsar2_bool_and(decoder_onnx_path)
    with torch.no_grad():
        o0 = decoder(*input_tensors)
    providers = ["CPUExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(decoder_onnx_path,
                                               providers=providers)
    # Inputs that play no role are removed from the exported graph (e.g.
    # r_hyps_pad_sos_eos when reverse_weight == 0), so feed only the inputs
    # the session actually declares.
    graph_inputs = {node.name for node in ort_session.get_inputs()}
    ort_inputs = {
        name: value
        for name, value in zip(input_names, to_numpy(input_tensors))
        if name in graph_inputs
    }
    ort_outs = ort_session.run(None, ort_inputs)
    # check decoder output; o0 is a (decoder_out, best_index) tuple in
    # fastertransformer mode and a single tensor otherwise
    if decoder_fastertransformer:
        test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05)
    else:
        test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05)
    logger.info("export to onnx decoder succeed!")
if __name__ == "__main__":
    # The module logger is set to INFO but has no handler of its own; give
    # the root logger one so the "export ... succeed!" messages are shown.
    # Other loggers keep the root's default WARNING threshold.
    logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s")
    parser = argparse.ArgumentParser(description="export x86_gpu model")
    parser.add_argument(
        "--pretrained_model_dir",
        default=DEFAULT_PRETRAINED_MODEL_DIR,
        help=("pretrained model directory containing train.yaml, final.pt, "
              "and global_cmvn"),
    )
    parser.add_argument(
        "--pretrained_model_url",
        default=DEFAULT_PRETRAINED_MODEL_URL,
        help="pretrained model tar.gz URL used when pretrained_model_dir is missing",
    )
    parser.add_argument(
        "--reverse_weight",
        default=-1.0,
        type=float,
        required=False,
        help="reverse weight for bitransformer," +
        "default value is in config file",
    )
    parser.add_argument(
        "--ctc_weight",
        default=-1.0,
        type=float,
        required=False,
        help="ctc weight, default value is in config file",
    )
    parser.add_argument(
        "--beam_size",
        default=10,
        type=int,
        required=False,
        help="beam size would be ctc output size",
    )
    parser.add_argument(
        "--output_onnx_dir",
        default="onnx_model",
        help="output onnx encoder and decoder directory",
    )
    # Arguments for the streaming encoder. Both the streaming (online) and
    # the offline encoder are always exported, so there is no --streaming
    # switch.
    parser.add_argument(
        "--decoding_chunk_size",
        default=16,
        type=int,
        required=False,
        help="the decoding chunk size, <=0 is not supported",
    )
    parser.add_argument(
        "--num_decoding_left_chunks",
        default=5,
        type=int,
        required=False,
        help="number of left chunks, <= 0 is not supported",
    )
    parser.add_argument(
        "--decoder_fastertransformer",
        action="store_true",
        help="return decoder_out and best_index for ft",
    )
    parser.add_argument(
        "--return_ctc_logprobs",
        action="store_true",
        help="return full ctc_log_probs for TLG streaming encoder",
    )
    args = parser.parse_args()
    # Validate before downloading/loading anything; ``assert`` would be
    # stripped under ``python -O``.
    if args.decoding_chunk_size <= 0:
        parser.error("--decoding_chunk_size must be > 0")
    if args.num_decoding_left_chunks <= 0:
        parser.error("--num_decoding_left_chunks must be > 0")
    prepare_pretrained_model(args)
    torch.manual_seed(0)
    torch.set_printoptions(precision=10)
    # NOTE(review): the config may come from a downloaded archive;
    # yaml.safe_load would be safer if the config uses no Python-specific tags.
    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if os.path.exists(args.cmvn_file):
        if 'cmvn' not in configs:
            configs['cmvn'] = "global_cmvn"
            configs['cmvn_conf'] = {}
        else:
            assert configs['cmvn'] == "global_cmvn"
            assert configs['cmvn_conf'] is not None
        configs['cmvn_conf']["cmvn_file"] = args.cmvn_file
        configs['cmvn_conf'].setdefault(
            "is_json_cmvn", configs.get("is_json_cmvn", True))
    elif configs.get('cmvn', None) == 'global_cmvn':
        raise FileNotFoundError(
            f"Expected global_cmvn in pretrained model dir: {args.cmvn_file}")
    if (args.reverse_weight != -1.0
            and "reverse_weight" in configs["model_conf"]):
        configs["model_conf"]["reverse_weight"] = args.reverse_weight
        print("Update reverse weight to", args.reverse_weight)
    if args.ctc_weight != -1:
        print("Update ctc weight to ", args.ctc_weight)
        configs["model_conf"]["ctc_weight"] = args.ctc_weight
    configs["encoder_conf"]["use_dynamic_chunk"] = False
    model, configs = init_model(args, configs)
    model.eval()
    os.makedirs(args.output_onnx_dir, exist_ok=True)
    online_config = export_online_encoder(
        model, configs, args, logger,
        os.path.join(args.output_onnx_dir, "encoder_online.onnx"))
    offline_config = export_offline_encoder(
        model, configs, args, logger,
        os.path.join(args.output_onnx_dir, "encoder_offline.onnx"))
    decoder_onnx_path = os.path.join(args.output_onnx_dir, "decoder.onnx")
    export_rescoring_decoder(
        model,
        configs,
        args,
        logger,
        decoder_onnx_path,
        args.decoder_fastertransformer,
    )
    # Both encoders are exported, so keep both configs instead of letting the
    # offline one overwrite the streaming settings. The only shared key,
    # beam_size, holds args.beam_size in both.
    onnx_config = {**online_config, **offline_config}
    config_dir = os.path.join(args.output_onnx_dir, "config.yaml")
    with open(config_dir, "w") as out:
        yaml.dump(onnx_config, out)