| |
| |
| |
|
|
| """ |
| This script exports a CTC model from PyTorch to ONNX. |
| |
| Note that the model is trained using both transducer and CTC loss. This script |
| exports only the CTC head. |
| |
| We use the pre-trained model from |
| https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 |
| as an example to show how to use this file. |
| |
| 1. Download the pre-trained model |
| |
| cd egs/librispeech/ASR |
| |
| repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 |
| GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| repo=$(basename $repo_url) |
| |
| pushd $repo |
| git lfs pull --include "exp/pretrained.pt" |
| |
| cd exp |
| ln -s pretrained.pt epoch-99.pt |
| popd |
| |
| 2. Export the model to ONNX |
| |
| ./zipformer/export-onnx-ctc.py \ |
| --use-transducer 0 \ |
| --use-ctc 1 \ |
| --tokens $repo/data/lang_bpe_500/tokens.txt \ |
| --use-averaged-model 0 \ |
| --epoch 99 \ |
| --avg 1 \ |
| --exp-dir $repo/exp \ |
| --num-encoder-layers "2,2,3,4,3,2" \ |
| --downsampling-factor "1,2,4,8,4,2" \ |
| --feedforward-dim "512,768,1024,1536,1024,768" \ |
| --num-heads "4,4,4,8,4,4" \ |
| --encoder-dim "192,256,384,512,384,256" \ |
| --query-head-dim 32 \ |
| --value-head-dim 12 \ |
| --pos-head-dim 4 \ |
| --pos-dim 48 \ |
| --encoder-unmasked-dim "192,192,256,256,256,192" \ |
| --cnn-module-kernel "31,31,15,15,15,31" \ |
| --decoder-dim 512 \ |
| --joiner-dim 512 \ |
| --causal False \ |
| --chunk-size 16 \ |
| --left-context-frames 128 |
| |
| It will generate the following 2 files inside $repo/exp |
| (plus model.fp16.onnx when --fp16 1 is given): |
| |
| - model.onnx |
| - model.int8.onnx |
| |
| See ./onnx_pretrained_ctc.py for how to |
| use the exported ONNX models. |
| """ |
|
|
| import argparse |
| import logging |
| from pathlib import Path |
| from typing import Dict, Tuple |
|
|
| import k2 |
| import onnx |
| import torch |
| import torch.nn as nn |
| from onnxruntime.quantization import QuantType, quantize_dynamic |
| from scaling_converter import convert_scaled_to_non_scaled |
| from train import add_model_arguments, get_model, get_params |
| from zipformer import Zipformer2 |
|
|
| from icefall.checkpoint import ( |
| average_checkpoints, |
| average_checkpoints_with_averaged_model, |
| find_checkpoints, |
| load_checkpoint, |
| ) |
| from icefall.utils import make_pad_mask, num_tokens, str2bool |
|
|
|
|
def get_parser():
    """Build the command-line parser for exporting the CTC head to ONNX.

    Besides the export/checkpoint options defined here, all model
    architecture options are registered by ``add_model_arguments``.

    Returns:
      An ``argparse.ArgumentParser`` whose help output shows defaults.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        # Epoch checkpoints are numbered from 1 (main() rejects epoch < 1).
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. If True, it would export the "
        "averaged model over the epoch range from `epoch-avg` (excluded) to "
        "`epoch` (or, when --iter is positive, over the corresponding range "
        "of iteration checkpoints). "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/lang_bpe_500/tokens.txt",
        help="Path to the tokens.txt",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--fp16",
        type=str2bool,
        default=False,
        help="Whether to export models in fp16",
    )

    add_model_arguments(parser)

    return parser
|
|
|
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Append key/value metadata entries to an ONNX model file in place.

    The model is loaded from ``filename``. Each item of ``meta_data`` is
    appended to the model's ``metadata_props``, and the model is saved back
    to the same path.

    Args:
      filename:
        Path of the ONNX model to rewrite.
      meta_data:
        Key-value pairs to append as model metadata.
    """
    onnx_model = onnx.load(filename)

    for meta_key, meta_value in meta_data.items():
        entry = onnx_model.metadata_props.add()
        entry.key = meta_key
        entry.value = meta_value

    onnx.save(onnx_model, filename)
|
|
|
|
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
    """Write a float16 copy of an exported float32 ONNX model.

    The conversion is done by onnxmltools' ``convert_float_to_float16`` with
    ``keep_io_types=True`` (per the onnxmltools docs, this keeps the graph
    inputs/outputs in float32 — confirm for the installed version).

    Args:
      onnx_fp32_path:
        Path of the existing float32 ONNX model to read.
      onnx_fp16_path:
        Path where the converted float16 model is written.
    """
    # Imported here so onnxmltools is only required when this is called.
    import onnxmltools
    from onnxmltools.utils.float16_converter import convert_float_to_float16

    tools = onnxmltools.utils
    fp32_model = tools.load_model(onnx_fp32_path)
    tools.save_model(
        convert_float_to_float16(fp32_model, keep_io_types=True),
        onnx_fp16_path,
    )
|
|
|
|
class OnnxModel(nn.Module):
    """Export wrapper chaining encoder_embed -> Zipformer -> ctc_output.

    Only the CTC branch of the trained model is kept, so the exported graph
    maps input features and their lengths directly to CTC outputs.
    """

    def __init__(
        self,
        encoder: Zipformer2,
        encoder_embed: nn.Module,
        ctc_output: nn.Module,
    ):
        """
        Args:
          encoder:
            A Zipformer encoder.
          encoder_embed:
            The first downsampling layer for zipformer; it is applied to the
            input features before the encoder.
          ctc_output:
            The CTC output layer; it is applied to the encoder output.
        """
        super().__init__()
        self.encoder = encoder
        self.encoder_embed = encoder_embed
        self.ctc_output = ctc_output

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the CTC branch (see Zipformer.forward for encoder details).

        Args:
          x:
            A 3-D tensor of shape (N, T, C)
          x_lens:
            A 1-D tensor of shape (N,). Its dtype is torch.int64
        Returns:
          A tuple (log_probs, log_probs_len):
            - log_probs, a 3-D tensor of shape (N, T', vocab_size)
            - log_probs_len, a 1-D int64 tensor of shape (N,)
        """
        embedded, embedded_lens = self.encoder_embed(x, x_lens)
        padding_mask = make_pad_mask(embedded_lens)

        # Swap the first two axes for the encoder and swap them back after it;
        # presumably the encoder is time-major (T', N, C) — confirm in zipformer.py.
        enc_in = embedded.permute(1, 0, 2)
        enc_out, out_lens = self.encoder(enc_in, embedded_lens, padding_mask)
        batch_major = enc_out.permute(1, 0, 2)

        return self.ctc_output(batch_major), out_lens
|
|
|
|
def export_ctc_model_onnx(
    model: OnnxModel,
    filename: str,
    opset_version: int = 11,
    feature_dim: int = 80,
) -> None:
    """Export the given model to ONNX format and attach model metadata.

    The exported model has two inputs:

        - x, a tensor of shape (N, T, C); dtype is torch.float32
        - x_lens, a tensor of shape (N,); dtype is torch.int64

    and it has two outputs:

        - log_probs, a tensor of shape (N, T', vocab_size)
        - log_probs_len, a tensor of shape (N,)

    After export, ``model_type``/``version``/``model_author``/``comment``
    metadata entries are appended to the saved file via ``add_meta_data``.

    Args:
      model:
        The input model
      filename:
        The filename to save the exported ONNX model. ``main`` passes a
        ``pathlib.Path`` here.
      opset_version:
        The opset version to use.
      feature_dim:
        Size C of the last axis of the dummy input used for tracing.
        Defaults to 80 (the value previously hard-coded); pass the model's
        input feature dimension if it differs.
    """
    # Dummy batch (one 100-frame utterance) used only for tracing; the batch
    # and time axes are declared dynamic in the export call below.
    x = torch.zeros(1, 100, feature_dim, dtype=torch.float32)
    x_lens = torch.tensor([100], dtype=torch.int64)

    model = torch.jit.trace(model, (x, x_lens))

    torch.onnx.export(
        model,
        (x, x_lens),
        filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "x_lens"],
        output_names=["log_probs", "log_probs_len"],
        # NOTE: the output time axis reuses the name "T" although its length
        # is the subsampled T' (see OnnxModel.forward).
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "log_probs": {0: "N", 1: "T"},
            "log_probs_len": {0: "N"},
        },
    )

    meta_data = {
        "model_type": "zipformer2_ctc",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "non-streaming zipformer2 CTC",
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=filename, meta_data=meta_data)
|
|
|
|
def _find_iter_checkpoints(params, count: int) -> list:
    """Return the first ``count`` checkpoints found for --iter, validated.

    The list comes from ``find_checkpoints(exp_dir, iteration=-iter)``; it is
    presumably ordered newest first — confirm in icefall.checkpoint.

    Raises:
      ValueError: if no checkpoints, or fewer than ``count``, are found.
    """
    filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[:count]
    if len(filenames) == 0:
        raise ValueError(
            f"No checkpoints found for"
            f" --iter {params.iter}, --avg {params.avg}"
        )
    if len(filenames) < count:
        raise ValueError(
            f"Not enough checkpoints ({len(filenames)}) found for"
            f" --iter {params.iter}, --avg {params.avg}"
        )
    return filenames


def _load_checkpoint_weights(model: nn.Module, params, device: torch.device) -> None:
    """Load weights without --use-averaged-model (plain checkpoint averaging).

    - --iter > 0: average the ``avg`` checkpoints found for --iter.
    - --avg == 1: load epoch-<epoch>.pt directly.
    - otherwise: average epoch-<epoch-avg+1>.pt ... epoch-<epoch>.pt,
      skipping epoch numbers below 1.
    All state dicts are loaded with ``strict=False``.
    """
    if params.iter > 0:
        filenames = _find_iter_checkpoints(params, params.avg)
        logging.info(f"averaging {filenames}")
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )
    elif params.avg == 1:
        load_checkpoint(
            f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
        )
    else:
        start = params.epoch - params.avg + 1
        # Epochs are numbered from 1, so earlier (non-existent) ones are skipped.
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(max(start, 1), params.epoch + 1)
        ]
        logging.info(f"averaging {filenames}")
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )


def _load_averaged_model_weights(
    model: nn.Module, params, device: torch.device
) -> None:
    """Load weights with --use-averaged-model.

    Uses ``average_checkpoints_with_averaged_model`` between a start
    checkpoint (excluded) and an end checkpoint: iteration checkpoints when
    --iter > 0, otherwise epoch-<epoch-avg>.pt and epoch-<epoch>.pt.

    Raises:
      ValueError: if not enough iteration checkpoints are found, if --avg is
        not positive, or if epoch - avg is smaller than 1.
    """
    if params.iter > 0:
        filenames = _find_iter_checkpoints(params, params.avg + 1)
        filename_start = filenames[-1]
        filename_end = filenames[0]
        logging.info(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
    else:
        # Explicit checks instead of asserts, which are stripped under -O.
        if params.avg <= 0:
            raise ValueError(f"--avg must be positive, got {params.avg}")
        start = params.epoch - params.avg
        if start < 1:
            raise ValueError(
                "--epoch minus --avg must be >= 1 when using "
                f"--use-averaged-model, got {start}"
            )
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        logging.info(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )

    model.load_state_dict(
        average_checkpoints_with_averaged_model(
            filename_start=filename_start,
            filename_end=filename_end,
            device=device,
        ),
        strict=False,
    )


@torch.no_grad()
def main():
    """Load a trained checkpoint, keep its CTC branch and export it to ONNX.

    Writes ``model.onnx`` and a dynamically int8-quantized
    ``model.int8.onnx`` (MatMul weights only) into --exp-dir, plus
    ``model.fp16.onnx`` when --fp16 is set.
    """
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    token_table = k2.SymbolTable.from_file(params.tokens)
    params.blank_id = token_table["<blk>"]
    # NOTE(review): the +1 presumably adds back the blank symbol that
    # num_tokens() excludes — confirm in icefall.utils.
    params.vocab_size = num_tokens(token_table) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    # Moved once here; checkpoint loading below fills the weights in place.
    model.to(device)

    if params.use_averaged_model:
        _load_averaged_model_weights(model, params, device)
    else:
        _load_checkpoint_weights(model, params, device)

    model.to("cpu")
    model.eval()

    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)

    # Keep only the CTC branch: encoder_embed -> encoder -> ctc_output.
    model = OnnxModel(
        encoder=model.encoder,
        encoder_embed=model.encoder_embed,
        ctc_output=model.ctc_output,
    )

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"num parameters: {num_param}")

    opset_version = 13

    logging.info("Exporting ctc model")
    filename = params.exp_dir / "model.onnx"
    export_ctc_model_onnx(
        model,
        filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported to {filename}")

    logging.info("Generate int8 quantization models")

    filename_int8 = params.exp_dir / "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )

    if params.fp16:
        filename_fp16 = params.exp_dir / "model.fp16.onnx"
        export_onnx_fp16(filename, filename_fp16)
|
|
|
|
if __name__ == "__main__":
    # Timestamped log lines that include the source file and line number.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()
|
|