| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
This script checks that exported ONNX models produce the same output
as the given torchscript model for the same input.
| |
| We use the pre-trained model from |
| https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 |
| as an example to show how to use this file. |
| |
| 1. Download the pre-trained model |
| |
| cd egs/librispeech/ASR |
| |
| repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 |
| GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| repo=$(basename $repo_url) |
| |
| pushd $repo |
| git lfs pull --include "exp/pretrained.pt" |
| |
| cd exp |
| ln -s pretrained.pt epoch-99.pt |
| popd |
| |
| 2. Export the model via torchscript (torch.jit.script()) |
| |
| ./zipformer/export.py \ |
| --tokens $repo/data/lang_bpe_500/tokens.txt \ |
| --use-averaged-model 0 \ |
| --epoch 99 \ |
| --avg 1 \ |
| --exp-dir $repo/exp/ \ |
| --jit 1 |
| |
| It will generate the following file in $repo/exp: |
| - jit_script.pt |
| |
| 3. Export the model to ONNX |
| |
| ./zipformer/export-onnx.py \ |
| --tokens $repo/data/lang_bpe_500/tokens.txt \ |
| --use-averaged-model 0 \ |
| --epoch 99 \ |
| --avg 1 \ |
| --exp-dir $repo/exp/ |
| |
| It will generate the following 3 files inside $repo/exp: |
| |
| - encoder-epoch-99-avg-1.onnx |
| - decoder-epoch-99-avg-1.onnx |
| - joiner-epoch-99-avg-1.onnx |
| |
| 4. Run this file |
| |
| ./zipformer/onnx_check.py \ |
| --jit-filename $repo/exp/jit_script.pt \ |
| --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ |
| --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ |
| --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx |
| """ |
|
|
| import argparse |
| import logging |
|
|
| import torch |
| from onnx_pretrained import OnnxModel |
|
|
|
|
def get_parser():
    """Build the command-line parser for this checker.

    Every option is a required string path to one of the models being
    compared: the torchscript model and the three ONNX models (encoder,
    decoder and joiner).

    Returns:
      An ``argparse.ArgumentParser`` using ``ArgumentDefaultsHelpFormatter``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (flag, help text) for each required path option, registered in order.
    path_options = (
        ("--jit-filename", "Path to the torchscript model"),
        ("--onnx-encoder-filename", "Path to the onnx encoder model"),
        ("--onnx-decoder-filename", "Path to the onnx decoder model"),
        ("--onnx-joiner-filename", "Path to the onnx joiner model"),
    )
    for flag, help_text in path_options:
        parser.add_argument(flag, required=True, type=str, help=help_text)

    return parser
|
|
|
|
def test_encoder(
    torch_model: torch.jit.ScriptModule,
    onnx_model: OnnxModel,
):
    """Check that the ONNX encoder matches the torchscript encoder.

    Runs 3 iterations. Each uses a random batch size N in [1, 20) and a
    random number of frames T in [30, 50). The random features have shape
    (N, T, 80). Per-utterance lengths are drawn from [30, T], and the first
    utterance always uses the full length T.

    The torchscript encoder output is passed through
    ``torch_model.joiner.encoder_proj`` before being compared with the output
    of ``onnx_model.run_encoder``.

    Raises:
      AssertionError: if the output shapes differ, the output lengths differ,
        or the outputs are not close (atol=1e-5).
    """
    C = 80
    for i in range(3):
        N = torch.randint(low=1, high=20, size=(1,)).item()
        T = torch.randint(low=30, high=50, size=(1,)).item()
        logging.info(f"test_encoder: iter {i}, N={N}, T={T}")

        x = torch.rand(N, T, C)
        x_lens = torch.randint(low=30, high=T + 1, size=(N,))
        # Make max(x_lens) equal the padded length x.shape[1] == T.
        x_lens[0] = T

        torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens)
        # NOTE(review): presumably the exported ONNX encoder already includes
        # encoder_proj, hence the projection on the torch side only.
        torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out)

        onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens)

        # torch.allclose broadcasts its inputs, so check shapes explicitly:
        # otherwise a mismatch could pass or fail with an opaque error.
        assert torch_encoder_out.shape == onnx_encoder_out.shape, (
            torch_encoder_out.shape,
            onnx_encoder_out.shape,
        )
        # Compare the output lengths as well, not just the frames.
        # .tolist() sidesteps any dtype/container difference in the lengths.
        assert torch_encoder_out_lens.tolist() == onnx_encoder_out_lens.tolist(), (
            torch_encoder_out_lens,
            onnx_encoder_out_lens,
        )
        assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), (
            (torch_encoder_out - onnx_encoder_out).abs().max()
        )
|
|
|
|
def test_decoder(
    torch_model: torch.jit.ScriptModule,
    onnx_model: OnnxModel,
):
    """Check that the ONNX decoder matches the torchscript decoder.

    Runs 10 iterations. Each uses a random batch size N in [1, 100) and a
    random int64 matrix of shape (N, context_size) with values in
    [1, vocab_size). ``context_size`` and ``vocab_size`` are read from
    ``onnx_model``.

    The torchscript decoder is called with
    ``need_pad=torch.tensor([False])``. Its output is passed through
    ``torch_model.joiner.decoder_proj`` and squeezed on dim 1 before being
    compared with the output of ``onnx_model.run_decoder``.

    Raises:
      AssertionError: if the output shapes differ or the outputs are not
        close (atol=1e-4).
    """
    context_size = onnx_model.context_size
    vocab_size = onnx_model.vocab_size
    for i in range(10):
        N = torch.randint(1, 100, size=(1,)).item()
        logging.info(f"test_decoder: iter {i}, N={N}")
        x = torch.randint(
            low=1,
            high=vocab_size,
            size=(N, context_size),
            dtype=torch.int64,
        )
        torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False]))
        torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out)
        torch_decoder_out = torch_decoder_out.squeeze(1)

        onnx_decoder_out = onnx_model.run_decoder(x)

        # torch.allclose broadcasts its inputs. With N == 1 a missing/extra
        # singleton dim would otherwise pass unnoticed, so check shapes first.
        assert torch_decoder_out.shape == onnx_decoder_out.shape, (
            torch_decoder_out.shape,
            onnx_decoder_out.shape,
        )
        assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), (
            (torch_decoder_out - onnx_decoder_out).abs().max()
        )
|
|
|
|
def test_joiner(
    torch_model: torch.jit.ScriptModule,
    onnx_model: OnnxModel,
):
    """Check that the ONNX joiner matches the torchscript joiner.

    Runs 10 iterations, each with a random batch size N in [1, 100). It
    draws random encoder/decoder activations whose widths are the input
    dims of ``joiner.encoder_proj`` / ``joiner.decoder_proj``. It then
    feeds them to the scripted joiner as-is, and to
    ``onnx_model.run_joiner`` after applying those projections.

    Raises:
      AssertionError: if the output shapes differ or the outputs are not
        close (atol=1e-4).
    """
    encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1]
    decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1]
    for i in range(10):
        N = torch.randint(1, 100, size=(1,)).item()
        logging.info(f"test_joiner: iter {i}, N={N}")
        encoder_out = torch.rand(N, encoder_dim)
        decoder_out = torch.rand(N, decoder_dim)

        projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out)
        projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out)

        # NOTE(review): the scripted joiner gets the raw inputs (presumably it
        # applies encoder_proj/decoder_proj itself), while the ONNX joiner gets
        # pre-projected inputs. Confirm against the export script.
        torch_joiner_out = torch_model.joiner(encoder_out, decoder_out)
        onnx_joiner_out = onnx_model.run_joiner(
            projected_encoder_out, projected_decoder_out
        )

        # torch.allclose broadcasts its inputs, so check shapes explicitly.
        assert torch_joiner_out.shape == onnx_joiner_out.shape, (
            torch_joiner_out.shape,
            onnx_joiner_out.shape,
        )
        assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), (
            (torch_joiner_out - onnx_joiner_out).abs().max()
        )
|
|
|
|
@torch.no_grad()
def main():
    """Parse the command line, load the models and run every comparison.

    The torchscript model is loaded with ``torch.jit.load``, and the three
    ONNX files are wrapped in a single ``OnnxModel``. The encoder, decoder
    and joiner checks then run in that order, all under ``torch.no_grad()``.
    """
    args = get_parser().parse_args()
    logging.info(vars(args))

    torch_model = torch.jit.load(args.jit_filename)
    onnx_model = OnnxModel(
        encoder_model_filename=args.onnx_encoder_filename,
        decoder_model_filename=args.onnx_decoder_filename,
        joiner_model_filename=args.onnx_joiner_filename,
    )

    # (log message, check) pairs, executed in order.
    checks = (
        ("Test encoder", test_encoder),
        ("Test decoder", test_decoder),
        ("Test joiner", test_joiner),
    )
    for title, run_check in checks:
        logging.info(title)
        run_check(torch_model, onnx_model)

    logging.info("Finished checking ONNX models")
|
|
|
|
# Limit torch to one intra-op thread and one inter-op thread.
# NOTE(review): torch.set_num_interop_threads may only be called once, before
# any inter-op parallel work starts. That is presumably why these run at
# import time rather than inside main(). Confirm.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# Disable the TorchScript profiling executor, profiling mode and graph-executor
# optimizations. Presumably this stops the scripted model from being
# re-profiled/re-compiled each time the random batch size or frame count used
# by the checks changes. TODO confirm.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
    # Fixed seed so the random batch sizes, lengths and inputs used by the
    # checks are reproducible from run to run.
    torch.manual_seed(20220727)
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()
|
|