|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
try: |
|
|
import torch.ao.quantization as quantization |
|
|
except ImportError: |
|
|
import torch.quantization as quantization |
|
|
|
|
|
|
|
|
def emulate_int(w, bits, method, scale=None, zero_point=None):
    """Dispatch to the module-level ``emulate_int8_<method>`` function.

    Args:
        w: Value handed to the selected function as its first positional
            argument.
        bits: Forwarded as the ``bits`` keyword.
        method: Suffix appended to ``"emulate_int8_"`` to form the name that
            is looked up in this module's globals.
        scale: Forwarded unchanged as the ``scale`` keyword.
        zero_point: Forwarded unchanged as the ``zero_point`` keyword.

    Returns:
        Whatever the selected function returns.

    Raises:
        KeyError: If this module has no global named
            ``emulate_int8_<method>``.
    """
    forwarded = {"scale": scale, "zero_point": zero_point, "bits": bits}
    emulator = globals()[f"emulate_int8_{method}"]
    return emulator(w, **forwarded)
|
|
|
|
|
|
|
|
def quantize(w, scale, zero_point, bits=8):
    """Round ``w`` onto an unsigned ``bits``-bit grid and map it back.

    Computes ``(clamp(round(w / scale + zero_point), 0, 2**bits - 1)
    - zero_point) * scale``.

    Args:
        w: Tensor to quantize.
        scale: Step size of the grid; ``w`` is divided by it.
        zero_point: Offset added before rounding and subtracted afterwards.
        bits: Bit width of the grid; the top level is ``2**bits - 1``
            (255 for the default of 8).

    Returns:
        The de-quantized tensor produced by the formula above.
    """
    top_level = 2**bits - 1
    levels = torch.round(w / scale + zero_point)
    levels = torch.clamp(levels, 0, top_level)
    return (levels - zero_point) * scale
|
|
|
|
|
|
|
|
def emulate_int8_histogram(w, scale=None, zero_point=None, bits=8):
    """Quantize ``w`` with parameters estimated by a ``HistogramObserver``.

    When ``scale`` is ``None`` a fresh, default-constructed
    ``HistogramObserver`` is run on a float32 copy of ``w`` and its
    ``calculate_qparams()`` result supplies ``scale`` and ``zero_point``.
    The observer is built without reference to ``bits``; ``bits`` only
    affects the clamping range used by ``quantize``.

    Args:
        w: Tensor to quantize.
        scale: Pre-computed scale. If given, it and ``zero_point`` are used
            as-is and no observer is run.
        zero_point: Pre-computed zero point; only consulted when ``scale``
            is not ``None``.
        bits: Bit width forwarded to ``quantize``.

    Returns:
        Tuple ``(quantized_w, scale, zero_point)``.
    """
    if scale is None:
        obs = quantization.observer.HistogramObserver()
        obs.to(device=w.device)
        obs(w.float())
        scale, zero_point = obs.calculate_qparams()
        # Follow w's own device and dtype instead of hopping through
        # ``.cuda()``, which raises on hosts without CUDA.
        scale = scale.to(device=w.device, dtype=w.dtype)
        zero_point = zero_point.to(device=w.device, dtype=w.dtype)
    return quantize(w, scale, zero_point, bits=bits), scale, zero_point
|
|
|
|
|
|
|
|
def emulate_int8_channel(w, scale=None, zero_point=None, bits=8):
    """Quantize ``w`` with per-channel parameters from a min/max observer.

    When ``scale`` is ``None`` a ``PerChannelMinMaxObserver`` configured with
    ``ch_axis=-1`` and the ``torch.per_channel_symmetric`` qscheme is run on
    ``w`` and its ``calculate_qparams()`` result supplies ``scale`` and
    ``zero_point``. The observer is built without reference to ``bits``;
    ``bits`` only affects the clamping range used by ``quantize``.

    Args:
        w: Tensor to quantize.
        scale: Pre-computed scale. If given, it and ``zero_point`` are used
            as-is and no observer is run.
        zero_point: Pre-computed zero point; only consulted when ``scale``
            is not ``None``.
        bits: Bit width forwarded to ``quantize``.

    Returns:
        Tuple ``(quantized_w, scale, zero_point)``.
    """
    if scale is None:
        obs = quantization.observer.PerChannelMinMaxObserver(
            ch_axis=-1, qscheme=torch.per_channel_symmetric
        )
        obs.to(device=w.device)
        obs(w)
        # ``calculate_qparams`` yields the (scale, zero_point) pair directly;
        # the channel axis formerly unpacked from ``get_qparams`` was unused.
        # NOTE(review): ``get_qparams`` appears to be absent from current
        # observers - confirm against the torch versions this must support.
        scale, zero_point = obs.calculate_qparams()
        # Follow w's own device and dtype instead of hopping through
        # ``.cuda()``, which raises on hosts without CUDA.
        scale = scale.to(device=w.device, dtype=w.dtype)
        zero_point = zero_point.to(device=w.device, dtype=w.dtype)
    return quantize(w, scale, zero_point, bits=bits), scale, zero_point
|
|
|
|
|
|
|
|
def emulate_int8_tensor(w, scale=None, zero_point=None, bits=8):
    """Quantize ``w`` with per-tensor parameters from a ``MinMaxObserver``.

    When ``scale`` is ``None`` a fresh, default-constructed
    ``MinMaxObserver`` is run on ``w`` and its ``calculate_qparams()`` result
    supplies ``scale`` and ``zero_point``. The observer is built without
    reference to ``bits``; ``bits`` only affects the clamping range used by
    ``quantize``.

    Args:
        w: Tensor to quantize.
        scale: Pre-computed scale. If given, it and ``zero_point`` are used
            as-is and no observer is run.
        zero_point: Pre-computed zero point; only consulted when ``scale``
            is not ``None``.
        bits: Bit width forwarded to ``quantize``.

    Returns:
        Tuple ``(quantized_w, scale, zero_point)``.
    """
    if scale is None:
        obs = quantization.observer.MinMaxObserver()
        obs.to(device=w.device)
        obs(w)
        scale, zero_point = obs.calculate_qparams()
        # Follow w's own device and dtype instead of hopping through
        # ``.cuda()``, which raises on hosts without CUDA.
        scale = scale.to(device=w.device, dtype=w.dtype)
        zero_point = zero_point.to(device=w.device, dtype=w.dtype)
    return quantize(w, scale, zero_point, bits=bits), scale, zero_point
|
|
|