| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Script for calibrating a pretrained ASR model for quantization |
| """ |
|
|
| from argparse import ArgumentParser |
|
|
| import torch |
| from omegaconf import open_dict |
|
|
| from nemo.collections.asr.models import EncDecCTCModel |
| from nemo.utils import logging |
|
|
| try: |
| from pytorch_quantization import calib |
| from pytorch_quantization import nn as quant_nn |
| from pytorch_quantization import quant_modules |
| from pytorch_quantization.tensor_quant import QuantDescriptor |
| except ImportError: |
| raise ImportError( |
| "pytorch-quantization is not installed. Install from " |
| "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization." |
| ) |
|
|
| can_gpu = torch.cuda.is_available() |
|
|
|
|
def main():
    """Calibrate a pretrained EncDecCTC ASR model for INT8 quantization.

    Restores the model with ``encoder.quantize = True``, feeds
    ``num_calib_batch`` batches of the evaluation dataset through it to
    collect activation statistics, then computes amax values and saves one
    calibrated ``.nemo`` checkpoint per calibration method/percentile.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: 'QuartzNet15x5Base-En'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument(
        "--dont_normalize_text",
        default=False,
        # Fix: was action='store_false' with default=False, which made the flag a
        # no-op (transcripts were never normalized regardless of the flag).
        action='store_true',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument('--num_calib_batch', default=1, type=int, help="Number of batches for calibration.")
    parser.add_argument('--calibrator', type=str, choices=["max", "histogram"], default="max")
    parser.add_argument('--percentile', nargs='+', type=float, default=[99.9, 99.99, 99.999, 99.9999])
    parser.add_argument("--amp", action="store_true", help="Use AMP in calibration.")
    parser.set_defaults(amp=False)

    args = parser.parse_args()
    torch.set_grad_enabled(False)  # calibration only needs forward passes

    # Route all quantized inputs through the requested calibrator.
    quant_desc_input = QuantDescriptor(calib_method=args.calibrator)
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

    # Restore the checkpoint with quantization enabled in the encoder config so
    # the quantized layer variants are built.
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.restore_from(restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model, override_config_path=asr_model_cfg)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.from_pretrained(model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model, override_config_path=asr_model_cfg)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            # Fix: the flag means "don't normalize", so it must be inverted here.
            'normalize_transcripts': not args.dont_normalize_text,
            'shuffle': True,
        }
    )
    # Disable dither/padding so calibration statistics match inference features.
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Put every quantizer into calibration mode: gather statistics, don't quantize.
    for name, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # Feed exactly `num_calib_batch` batches through the model.
    for i, test_batch in enumerate(asr_model.test_dataloader()):
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with torch.amp.autocast(asr_model.device.type, enabled=args.amp):
            _ = asr_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
        # Fix: `if i >= num_calib_batch` ran one extra batch (i is 0-based).
        if i + 1 >= args.num_calib_batch:
            break

    # Compute amax from the collected statistics and save one checkpoint per method.
    model_name = args.asr_model.replace(".nemo", "") if args.asr_model.endswith(".nemo") else args.asr_model
    if args.calibrator != "histogram":
        compute_amax(asr_model, method="max")
        asr_model.save_to(F"{model_name}-max-{args.num_calib_batch*args.batch_size}.nemo")
    else:
        for percentile in args.percentile:
            print(F"{percentile} percentile calibration")
            # Fix: the percentile value was never forwarded, so every iteration
            # produced an identical checkpoint.
            compute_amax(asr_model, method="percentile", percentile=percentile)
            asr_model.save_to(F"{model_name}-percentile-{percentile}-{args.num_calib_batch*args.batch_size}.nemo")

        for method in ["mse", "entropy"]:
            print(F"{method} calibration")
            compute_amax(asr_model, method=method)
            asr_model.save_to(F"{model_name}-{method}-{args.num_calib_batch*args.batch_size}.nemo")
|
|
|
|
def compute_amax(model, **kwargs):
    """Load calibrated amax values into every TensorQuantizer of *model*.

    ``kwargs`` (e.g. ``method``, ``percentile``) are forwarded to
    ``load_calib_amax`` for histogram-based calibrators; ``MaxCalibrator``
    takes no extra arguments. Moves the model back to GPU when available.
    """
    for name, module in model.named_modules():
        if not isinstance(module, quant_nn.TensorQuantizer):
            continue
        if module._calibrator is not None:
            is_max = isinstance(module._calibrator, calib.MaxCalibrator)
            if is_max:
                module.load_calib_amax()
            else:
                module.load_calib_amax(**kwargs)
        print(F"{name:40}: {module}")
    if can_gpu:
        model.cuda()
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|