|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
from operator import attrgetter |
|
|
|
|
|
import torch.nn as nn |
|
|
import torch.distributed as dist |
|
|
|
|
|
from ..pq.utils import get_layers, attrsetter |
|
|
from .modules import IntConv2d, IntLinear, IntEmbedding, ActivationQuantizer |
|
|
|
|
|
|
|
|
# Map each supported torch.nn layer type to its scalar-quantized counterpart.
MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d}
|
|
|
|
|
|
|
|
def quantize_model_(model, p=0.2, bits=8, update_step=3000):
    """
    Replaces all modules with their scalar quantized counterpart and
    registers hooks to quantize the post-activations of those modules.

    Args:
        - model: a nn.Module
        - p: amount of noise (0 for no noise, 1 to quantize all the weights/activations)
        - bits: number of bits
        - update_step: update quantization parameters every update_step steps

    Returns:
        The list of layer names considered for quantization (unsupported
        layers are logged and skipped but still appear in the list).
    """
    # The catch-all regex selects every named submodule of the model.
    quantized_layers = get_layers(model, "(.*?)")

    # Log only on rank 0 in distributed runs. This is invariant across the
    # loop, so compute it once; the short-circuit `or` already guarantees
    # dist is initialized before get_rank() is called.
    is_master_process = (not dist.is_initialized()) or dist.get_rank() == 0

    for layer in quantized_layers:

        module = attrgetter(layer)(model)
        if is_master_process:
            logging.info(f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}")

        # Quantization hyperparameters shared by all quantized module types.
        q_params = {
            "p": p,
            "update_step": update_step,
            "bits": bits,
            "method": "histogram",
            "counter": 0,
        }

        if isinstance(module, tuple(MAPPING.keys())):
            # Instantiate the quantized counterpart without running __init__,
            # then transplant the original module's state plus the quantization
            # hyperparameters. Copy the dict so the original module's own
            # __dict__ is not mutated by the update.
            QuantizedModule = MAPPING[module.__class__]
            quantized_module = QuantizedModule.__new__(QuantizedModule)
            params = module.__dict__.copy()
            params.update(q_params)
            quantized_module.__dict__.update(params)
        else:
            if is_master_process:
                logging.info(f"Module {module} not yet supported for quantization")
            continue

        # Quantize the post-activations of the module. ActivationQuantizer
        # registers forward hooks in its constructor, so constructing it is
        # enough; the binding exists only for readability.
        _activation_quantizer = ActivationQuantizer(
            quantized_module, p=0, bits=bits, method="histogram"
        )

        # Swap the original module for its quantized version in the model tree.
        attrsetter(layer)(model, quantized_module)

    # Return name of quantized layers.
    return quantized_layers
|
|
|