| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Helper functions for training models with pytorch-quantization""" |
| | import logging |
| | import re |
| |
|
| | import pytorch_quantization |
| | import pytorch_quantization.nn as quant_nn |
| | import torch |
| | from pytorch_quantization import calib |
| | from pytorch_quantization.tensor_quant import QuantDescriptor |
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | name_width = 50 |
| | qname_width = 70 |
| |
|
| | |
| |
|
| |
|
def add_arguments(parser):
    """Add arguments to parser for functions defined in quant_trainer."""
    group = parser.add_argument_group("quant_trainer arguments")
    # Bind the method once; every option below registers on the same group.
    add = group.add_argument
    add("--wprec", type=int, default=8, help="weight precision")
    add("--aprec", type=int, default=8, help="activation precision")
    add("--quant-per-tensor", action="store_true", help="per tensor weight scaling")
    add("--quant-disable", action="store_true", help="disable all quantizers")
    add("--quant-disable-embeddings", action="store_true", help="disable all embeddings quantizers")
    add("--quant-disable-keyword", type=str, nargs="+", help="disable quantizers by keyword")
    add("--quant-disable-layer-module", type=str, help="disable quantizers by keyword under layer.")
    add("--quant-enable-layer-module", type=str, help="enable quantizers by keyword under layer")
    add("--calibrator", default="max", help="which quantization range calibrator to use")
    add("--percentile", default=None, type=float, help="percentile for PercentileCalibrator")
    add("--fuse-qkv", action="store_true", help="use the same scale factor for qkv")
    add("--clip-gelu", metavar="N", type=float, help="clip gelu output maximum value to N")
    add(
        "--recalibrate-weights",
        action="store_true",
        help=(
            "recalibrate weight amaxes by taking the max of the weights."
            " amaxes will be computed with the current quantization granularity (axis)."
        ),
    )
| |
|
| |
|
def set_default_quantizers(args):
    """Set default quantizers before creating the model."""
    # Map the user-facing calibrator name to pytorch-quantization's calib method.
    calib_methods = {"max": "max", "percentile": "histogram", "mse": "histogram"}
    if args.calibrator not in calib_methods:
        raise ValueError(f"Invalid calibrator {args.calibrator}")
    # The percentile calibrator additionally needs --percentile at calibration time.
    if args.calibrator == "percentile" and args.percentile is None:
        raise ValueError("Specify --percentile when using percentile calibrator")
    calib_method = calib_methods[args.calibrator]

    # Activations: --aprec bits; weights: --wprec bits, per-channel on axis 0
    # unless --quant-per-tensor was requested.
    input_desc = QuantDescriptor(num_bits=args.aprec, calib_method=calib_method)
    weight_desc = QuantDescriptor(num_bits=args.wprec, axis=(None if args.quant_per_tensor else (0,)))
    quant_nn.QuantLinear.set_default_quant_desc_input(input_desc)
    quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc)
| |
|
| |
|
def configure_model(model, args, calib=False, eval=False):
    """Function called before the training loop.

    Applies the command-line quantization configuration (disable/enable
    rules, weight recalibration, QKV fusion, GELU clipping) to *model*.
    The disable/enable/recalibrate steps are skipped while calibrating.
    """
    logger.info("Configuring Model for Quantization")
    logger.info(f"using quantization package {pytorch_quantization.__file__}")

    if not calib:
        if args.quant_disable_embeddings:
            set_quantizer_by_name(model, ["embeddings"], which="weight", _disabled=True)
        if args.quant_disable:
            # An empty pattern matches every module name.
            set_quantizer_by_name(model, [""], _disabled=True)
        if args.quant_disable_keyword:
            set_quantizer_by_name(model, args.quant_disable_keyword, _disabled=True)
        # Layer-scoped disable/enable share the same name pattern; only the
        # target state differs.
        for layer_keyword, disabled in (
            (args.quant_disable_layer_module, True),
            (args.quant_enable_layer_module, False),
        ):
            if layer_keyword:
                set_quantizer_by_name(model, [r"layer.\d+." + layer_keyword], _disabled=disabled)
        if args.recalibrate_weights:
            recalibrate_weights(model)
        if args.fuse_qkv:
            fuse_qkv(model, args)

    if args.clip_gelu:
        clip_gelu(model, args.clip_gelu)

    print_quant_summary(model)
| |
|
| |
|
def enable_calibration(model):
    """Enable calibration of all *_input_quantizer modules in model."""
    logger.info("Enabling Calibration")
    for name, module in model.named_modules():
        if not name.endswith("_quantizer"):
            continue
        if module._calibrator is None:
            # No calibrator attached: just turn the quantizer off entirely.
            module.disable()
        else:
            # Collect statistics instead of quantizing during calibration.
            module.disable_quant()
            module.enable_calib()
        logger.info(f"{name:80}: {module}")
| |
|
| |
|
def finish_calibration(model, args):
    """Disable calibration and load amax for all "*_input_quantizer modules in model."""
    logger.info("Loading calibrated amax")
    for name, module in model.named_modules():
        if not name.endswith("_quantizer"):
            continue
        if module._calibrator is None:
            # Quantizer was fully disabled for calibration; re-enable it.
            module.enable()
            continue
        # Max calibrators take no arguments; histogram calibrators are read
        # out at the requested percentile.
        if isinstance(module._calibrator, calib.MaxCalibrator):
            module.load_calib_amax()
        else:
            module.load_calib_amax("percentile", percentile=args.percentile)
        module.enable_quant()
        module.disable_calib()
    model.cuda()
    print_quant_summary(model)
| |
|
| |
|
| | |
| |
|
| |
|
def fuse_qkv(model, args):
    """Adjust quantization ranges to match an implementation where the QKV projections are implemented with a single GEMM.
    Force the weight and output scale factors to match by taking the max of (Q,K,V).
    """

    def _fuse3(quant_q, quant_k, quant_v):
        # All three quantizers must already have calibrated amax buffers.
        if any(not hasattr(m, "_amax") for m in (quant_q, quant_k, quant_v)):
            print(" WARNING: NO AMAX BUFFER")
            return
        q = quant_q._amax.detach().item()
        k = quant_k._amax.detach().item()
        v = quant_v._amax.detach().item()

        fused = max(q, k, v)
        for m in (quant_q, quant_k, quant_v):
            m._amax.fill_(fused)
        logger.info(f" q={q:5.2f} k={k:5.2f} v={v:5.2f} -> {fused:5.2f}")

    for name, mod in model.named_modules():
        if name.endswith(".attention.self"):
            logger.info(f"FUSE_QKV: {name:{name_width}}")
            # Activation (matmul-input) scales always get fused; weight scales
            # only when quantization is per-tensor.
            _fuse3(mod.matmul_q_input_quantizer, mod.matmul_k_input_quantizer, mod.matmul_v_input_quantizer)
            if args.quant_per_tensor:
                _fuse3(mod.query._weight_quantizer, mod.key._weight_quantizer, mod.value._weight_quantizer)
| |
|
| |
|
def clip_gelu(model, maxval):
    """Clip activations generated by GELU to maxval when quantized.
    Implemented by adjusting the amax of the following input_quantizer.
    """
    for name, mod in model.named_modules():
        # GELU feeds the FFN output.dense; exclude the attention projection.
        follows_gelu = name.endswith(".output.dense") and not name.endswith("attention.output.dense")
        if not follows_gelu:
            continue
        before = mod._input_quantizer._amax.data.detach().item()
        # In-place clamp of the amax buffer caps the representable range.
        mod._input_quantizer._amax.data.detach().clamp_(max=maxval)
        after = mod._input_quantizer._amax.data.detach().item()
        logger.info(f"CLIP_GELU: {name:{name_width}} amax: {before:5.2f} -> {after:5.2f}")
| |
|
| |
|
def expand_amax(model):
    """Expand per-tensor amax to be per channel, where each channel is assigned the per-tensor amax."""
    for name, mod in model.named_modules():
        if not hasattr(mod, "_weight_quantizer") or mod._weight_quantizer.axis is None:
            continue
        # One amax entry per output channel (dim 0 of the weight).
        num_channels = mod.weight.shape[0]
        old_amax = mod._weight_quantizer._amax.detach()
        mod._weight_quantizer._amax = torch.ones(num_channels, dtype=old_amax.dtype, device=old_amax.device) * old_amax
        print(f"expanding {name} {old_amax} -> {mod._weight_quantizer._amax}")
| |
|
| |
|
def recalibrate_weights(model):
    """Performs max calibration on the weights and updates amax.

    For every module carrying a ``_weight_quantizer``, the amax buffer is
    recomputed from the current weights, reducing over every tensor axis
    except the quantization axis (per-channel when ``axis`` is set,
    per-tensor otherwise).
    """
    for name, mod in model.named_modules():
        if hasattr(mod, "_weight_quantizer"):
            # Fixed: used mod.weight_quantizer (no underscore), inconsistent with
            # the guard above, and the print lacked the f-prefix so it emitted
            # the literal "{name:{name_width}}".
            if not hasattr(mod._weight_quantizer, "_amax"):
                print(f"RECALIB: {name:{name_width}} WARNING: NO AMAX BUFFER")
                continue

            # Reduce over all axes except the quantization axis.
            # E.g. a per-channel quantizer with axis=(0,) on a 2-D weight
            # reduces over axis 1, keeping one amax per output channel.
            axis_set = set() if mod._weight_quantizer.axis is None else set(mod._weight_quantizer.axis)
            reduce_axis = set(range(len(mod.weight.size()))) - axis_set
            amax = pytorch_quantization.utils.reduce_amax(mod.weight, axis=reduce_axis, keepdims=True).detach()
            logger.info(f"RECALIB: {name:{name_width}} {mod._weight_quantizer._amax.flatten()} -> {amax.flatten()}")
            mod._weight_quantizer._amax = amax
| |
|
| |
|
def print_model_summary(model, name_width=25, line_width=180, ignore=None):
    """Print model quantization configuration.

    Args:
        model: model whose weighted modules are listed.
        name_width: minimum width of the module-name column; it is widened to
            the longest weighted-module name found. (Fixed: previously this
            parameter was dead — it was unconditionally reset to 0.)
        line_width: entries longer than this are split onto two log lines.
        ignore: a type, a name substring, or a list of either; matching
            modules are skipped.
    """
    if ignore is None:
        ignore = []
    elif not isinstance(ignore, list):
        ignore = [ignore]

    # Widen the name column to fit the longest weighted-module name.
    for name, mod in model.named_modules():
        if not hasattr(mod, "weight"):
            continue
        name_width = max(name_width, len(name))

    for name, mod in model.named_modules():
        input_q = getattr(mod, "_input_quantizer", None)
        weight_q = getattr(mod, "_weight_quantizer", None)
        if not hasattr(mod, "weight"):
            continue
        if type(mod) in ignore:
            continue
        if any(isinstance(s, str) and s in name for s in ignore):
            continue
        # Fixed: a module with a weight but no quantizers previously crashed
        # with AttributeError on None.extra_repr().
        act_str = f"Act:{input_q.extra_repr()}" if input_q is not None else "Act:None"
        wgt_str = f"Wgt:{weight_q.extra_repr()}" if weight_q is not None else "Wgt:None"
        s = f"{name:{name_width}} {act_str} {wgt_str}"
        if len(s) <= line_width:
            logger.info(s)
        else:
            logger.info(f"{name:{name_width}} {act_str}")
            logger.info(f'{" ":{name_width}} {wgt_str}')
| |
|
| |
|
def print_quant_summary(model):
    """Print summary of all quantizer modules in the model."""
    total = 0
    for name, mod in model.named_modules():
        if not isinstance(mod, pytorch_quantization.nn.TensorQuantizer):
            continue
        print(f"{name:80} {mod}")
        total += 1
    print(f"{total} TensorQuantizers found in model")
| |
|
| |
|
def set_quantizer(name, mod, quantizer, k, v):
    """Set attributes for mod.quantizer.

    Looks up ``mod.<quantizer>`` and sets attribute *k* to *v* on it;
    logs a warning when the quantizer submodule does not exist.
    """
    quantizer_mod = getattr(mod, quantizer, None)
    if quantizer_mod is None:
        logger.warning(f"{name} has no {quantizer}")
        return
    # The attribute must already exist on the quantizer; we only override.
    assert hasattr(quantizer_mod, k)
    setattr(quantizer_mod, k, v)
| |
|
| |
|
def set_quantizers(name, mod, which="both", **kwargs):
    """Set quantizer attributes for mod.

    *which* selects the target: "input", "weight", or "both". Each keyword
    argument is applied to the selected quantizer submodule(s).
    """
    msg = f"Warning: changing {which} quantizers of {name:{qname_width}}"
    for attr, value in kwargs.items():
        msg += f" {attr}={value}"
        if which in ["input", "both"]:
            set_quantizer(name, mod, "_input_quantizer", attr, value)
        if which in ["weight", "both"]:
            set_quantizer(name, mod, "_weight_quantizer", attr, value)
    logger.info(msg)
| |
|
| |
|
def set_quantizer_by_name(model, names, **kwargs):
    """Set quantizer attributes for layers where name contains a substring in names.

    Each entry of *names* is treated as a regex searched against module names.
    Modules owning quantizers are updated via set_quantizers; bare
    TensorQuantizer modules (name ending in "_quantizer") get the attributes
    set directly.
    """
    for name, mod in model.named_modules():
        if hasattr(mod, "_input_quantizer") or hasattr(mod, "_weight_quantizer"):
            for pattern in names:
                if re.search(pattern, name):
                    set_quantizers(name, mod, **kwargs)
        elif name.endswith("_quantizer"):
            for pattern in names:
                if not re.search(pattern, name):
                    continue
                msg = f"Warning: changing {name:{name_width}}"
                for attr, value in kwargs.items():
                    msg += f" {attr}={value}"
                    setattr(mod, attr, value)
                logger.info(msg)
| |
|