julse's picture
Upload 551 files
be611b4 verified
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from operator import attrgetter
import torch.nn as nn
import torch.distributed as dist
from ..pq.utils import get_layers, attrsetter
from .modules import IntConv2d, IntLinear, IntEmbedding, ActivationQuantizer
# Maps each supported torch.nn module type to the scalar-quantized class
# that quantize_model_ substitutes for it.
MAPPING = {
    nn.Linear: IntLinear,
    nn.Embedding: IntEmbedding,
    nn.Conv2d: IntConv2d,
}
def quantize_model_(model, p=0.2, bits=8, update_step=3000):
    """
    Replaces all modules with their scalar quantized counterpart and
    registers hooks to quantize the post-activations of those modules.

    The model is modified in place: every layer whose module is an instance
    of a type listed in MAPPING is swapped for the mapped quantized class.
    Layers of any other type are logged and left untouched.

    Args:
        - model: a nn.Module
        - p: amount of noise (0 for no noise, 1 to quantize all the weights/activations)
        - bits: number of bits
        - update_step: update quantization parameters every update_step steps

    Returns:
        The layer names returned by get_layers(model, "(.*?)"). Note that
        this includes the names of layers that were skipped because their
        module type is not supported.
    """
    # candidate layers
    quantized_layers = get_layers(model, "(.*?)")

    # book-keeping: only the master process (or a non-distributed run) logs.
    # Loop-invariant, so it is computed once.
    is_master_process = (not dist.is_initialized()) or dist.get_rank() == 0

    # quantization params copied onto every quantized module
    q_params = {
        "p": p,
        "update_step": update_step,
        "bits": bits,
        "method": "histogram",
        "counter": 0,
    }

    supported_types = tuple(MAPPING)

    for layer in quantized_layers:
        # recover module
        module = attrgetter(layer)(model)
        if is_master_process:
            logging.info(f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}")

        if not isinstance(module, supported_types):
            if is_master_process:
                logging.info(f"Module {module} not yet supported for quantization")
            continue

        # Pick the quantized counterpart. An exact type match is preferred;
        # subclasses of a supported type pass the isinstance check above but
        # are not keys of MAPPING, so resolve them through their base type
        # instead of failing with a KeyError.
        QuantizedModule = MAPPING.get(type(module))
        if QuantizedModule is None:
            QuantizedModule = next(
                quantized_cls
                for base_cls, quantized_cls in MAPPING.items()
                if isinstance(module, base_cls)
            )

        # Build the quantized module without running __init__ and give it the
        # original module's state plus the quantization params. The original
        # module's __dict__ is read, not modified.
        quantized_module = QuantizedModule.__new__(QuantizedModule)
        quantized_module.__dict__.update(module.__dict__)
        quantized_module.__dict__.update(q_params)

        # activation quantization: the quantizer is created only for its
        # effect on quantized_module (hook registration, per the docstring
        # above -- confirm in ActivationQuantizer), so it is not kept.
        ActivationQuantizer(quantized_module, p=0, bits=bits, method="histogram")

        # replace layer by its quantized counterpart
        attrsetter(layer)(model, quantized_module)

    return quantized_layers