gptq量化相关代码
#43
by
BigMaoGoGoGo
- opened
- gptq_quantization.py +324 -0
- modeling_chatglm.py +18 -5
- quantization.py +17 -3
gptq_quantization.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import contextlib
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import transformers
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
LOGGER = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
QUANT_LAYERS = [nn.Linear, nn.Conv2d, transformers.Conv1D]
|
| 14 |
+
|
| 15 |
+
def is_transformer_conv1d(layer):
    """Tell whether ``layer`` is an instance of ``transformers.Conv1D``."""
    conv1d_type = transformers.Conv1D
    return isinstance(layer, conv1d_type)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# The two functions below only support per-channel symmetric weight quantization.
|
| 20 |
+
def get_weight_scale(weight, weight_bit_width):
    """Compute symmetric quantization scales along the last dimension.

    Each scale maps the largest absolute value found along the last
    dimension of ``weight`` onto ``2 ** (weight_bit_width - 1) - 1``, the
    largest positive level of a signed integer of that width.

    Args:
        weight: Tensor to derive the scales from; the last dimension is
            reduced away.
        weight_bit_width: Number of bits of the signed integer target.

    Returns:
        A half-precision tensor of scales with the last dimension removed.

    Raises:
        ValueError: If ``weight_bit_width`` is smaller than 2, because the
            largest positive level would then be zero or negative and the
            resulting scales would be infinite or have the wrong sign.
    """
    if weight_bit_width < 2:
        raise ValueError(
            f"weight_bit_width must be at least 2 for symmetric quantization, got {weight_bit_width}"
        )
    max_level = 2 ** (weight_bit_width - 1) - 1
    return (weight.abs().max(dim=-1).values / max_level).half()
|
| 23 |
+
|
| 24 |
+
def fake_quantize_weight(weight, weight_scale):
    """Round ``weight`` to the nearest multiple of its per-row scale.

    Args:
        weight: 2-D tensor whose rows are quantized independently.
        weight_scale: 1-D tensor holding one scale per row of ``weight``.

    Returns:
        A tensor shaped like ``weight`` whose entries are integer multiples
        of the matching row scale. Rows whose scale is zero come back as
        zeros instead of NaN.
    """
    scale = weight_scale[:, None]
    # A zero scale (e.g. an all-zero row) would make the division produce
    # NaN/inf; divide by 1 there instead. The final multiplication by the
    # real (zero) scale still maps such rows to zero.
    safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    return torch.round(weight / safe_scale) * scale
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class GPTQLayerWrapper:
    """Collect GPTQ statistics for one layer and quantize its weight.

    The wrapper accumulates a Hessian estimate ``H`` from the inputs fed to
    ``record_h`` and, in ``quant_weight``, replaces ``layer.weight`` with a
    fake-quantized weight (values that are integer multiples of the per-row
    scale) computed column by column with error compensation.
    """

    def __init__(self, layer_name, layer, weight_bit_width):
        """Set up empty statistics for ``layer``.

        Args:
            layer_name: Name used in log messages.
            layer: The wrapped layer; must expose a ``weight`` attribute.
            weight_bit_width: Bit width of the signed integer target.
        """
        super().__init__()
        self.layer_name = layer_name
        self.layer = layer
        self.device = layer.weight.device
        # ``quant_weight`` works on a 2-D view of the weight: Conv2d weights
        # are flattened from dim 1 on and transformers.Conv1D weights are
        # transposed. The Hessian must be square in the column count of that
        # view, not of the raw weight.
        if isinstance(layer, nn.Conv2d):
            columns = layer.weight[0].numel()
        elif is_transformer_conv1d(layer):
            columns = layer.weight.shape[0]
        else:
            columns = layer.weight.shape[1]
        self.columns = columns
        self.H = torch.zeros((columns, columns), device=self.device)
        self.nsamples = 0
        self.is_record = True
        self.weight_bit_width = weight_bit_width
        self.weight_scale = None

    def record_h(self, x):
        """Fold one batch of layer inputs into the Hessian estimate.

        Does nothing while ``is_record`` is False. Otherwise ``H`` is kept as
        a running average: the old value is re-weighted by
        ``nsamples / (nsamples + batch)`` and ``(2 / nsamples) * x @ x.T`` of
        the new batch is added.

        Args:
            x: Input tensor of the wrapped layer.
        """
        if not self.is_record:
            return

        x = x.detach().clone()
        if len(x.shape) == 2:
            x = x.unsqueeze(0)
        batch = x.shape[0]
        if isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer):
            if len(x.shape) == 3:
                x = x.reshape((-1, x.shape[-1]))
            # Features become rows so that x @ x.T is (columns, columns).
            x = x.t()

        if isinstance(self.layer, nn.Conv2d):
            # Unfold the input into the patches the convolution multiplies
            # with the flattened kernel.
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride
            )
            x = unfold(x)
            x = x.permute([1, 0, 2])
            x = x.flatten(1)

        self.H *= self.nsamples / (self.nsamples + batch)
        self.nsamples += batch
        x = math.sqrt(2 / self.nsamples) * x.float()
        self.H += x.matmul(x.t())

    def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1):
        """Quantize the layer weight in place using the recorded Hessian.

        On success ``layer.weight`` is replaced by a non-trainable parameter
        holding the fake-quantized weight, ``weight_scale`` holds the per-row
        scales and ``H`` is released. If the Hessian cannot be inverted a
        warning is logged, ``H`` is released and the weight is left as is
        (``weight_scale`` is still set).

        Args:
            blocksize: Number of columns processed together before the
                accumulated error is propagated to the remaining columns.
            percdamp: Fraction of the mean Hessian diagonal added to the
                diagonal before inversion.
            groupsize: Must be -1; group quantization is not implemented.

        Raises:
            RuntimeError: If ``groupsize`` is not -1.
        """
        if groupsize != -1:
            raise RuntimeError("Group quantization of gptq quantizer is not supported for now")
        weight = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            weight = weight.flatten(1)
        if is_transformer_conv1d(self.layer):
            weight = weight.t()
        weight = weight.float()

        weight_scale = get_weight_scale(weight, self.weight_bit_width)
        # todo: use buffer to store scale
        self.weight_scale = weight_scale
        # Largest level of the signed integer grid. Error compensation can
        # push a weight beyond the range the scale was computed for, so the
        # rounded levels are clamped to stay representable.
        qmax = 2 ** (self.weight_bit_width - 1) - 1
        # Avoid 0/0 for rows whose scale is zero; such rows quantize to zero.
        safe_scale = torch.where(weight_scale == 0, torch.ones_like(weight_scale), weight_scale)

        H = self.H
        # Columns that never saw a non-zero input: pin their diagonal to 1 so
        # the matrix stays invertible and zero the matching weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        weight[:, dead] = 0

        Q = torch.zeros_like(weight)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.device)
        H[diag, diag] += damp
        try:
            # Upper Cholesky factor of the inverse Hessian.
            H = torch.linalg.cholesky(H)
            H = torch.cholesky_inverse(H)
            H = torch.linalg.cholesky(H, upper=True)
        except Exception:
            logging.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error")
            del self.H
            return

        if H.isnan().any():
            logging.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error")
            del self.H
            return

        hinv = H

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            w1 = weight[:, i1:i2].clone()
            q1 = torch.zeros_like(w1)
            total_err = torch.zeros_like(w1)
            hinv1 = hinv[i1:i2, i1:i2]

            for i in range(count):
                w = w1[:, i]
                d = hinv1[i, i]

                levels = torch.clamp(torch.round(w / safe_scale), -qmax, qmax)
                q = levels * weight_scale

                q1[:, i] = q
                err = (w - q) / d
                # Spread this column's quantization error over the columns
                # of the block that are not quantized yet.
                w1[:, i:] -= err.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0))
                total_err[:, i] = err

            Q[:, i1:i2] = q1

            # Propagate the block's accumulated error to all later columns.
            weight[:, i2:] -= total_err.matmul(hinv[i1:i2, i2:])

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if is_transformer_conv1d(self.layer):
            Q = Q.t()
        shape = self.layer.weight.shape
        dtype = self.layer.weight.data.dtype
        del self.layer.weight
        setattr(self.layer, "weight", nn.Parameter(Q.reshape(shape).to(dtype), requires_grad=False))
        del self.H
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class GPTQBlockWrapper:
    """Group the GPTQ layer wrappers of one block and manage their hooks.

    Every sub-module of ``block`` whose type is listed in ``QUANT_LAYERS``
    gets a ``GPTQLayerWrapper`` plus a forward pre-hook that feeds the
    module's first positional input to that wrapper's ``record_h``.
    """

    def __init__(self, block_name: str, block: nn.Module, weight_bit_width=8):
        """Wrap the quantizable layers found inside ``block``.

        Args:
            block_name: Name of the block; used as prefix of the layer names.
            block: Module whose sub-modules are scanned.
            weight_bit_width: Bit width handed to every layer wrapper.
        """
        self.layer_wrappers = {}
        self.hook_handles = []
        # block order in the whole network
        self.order = 0
        self.block_name = block_name

        def get_hook(layer_name):
            def record_hook(_, x):
                # ``x`` is the tuple of positional inputs of the module.
                self.layer_wrappers[layer_name].record_h(x[0])
            return record_hook

        quant_types = tuple(QUANT_LAYERS)
        for layer_name, layer in block.named_modules():
            if isinstance(layer, quant_types):
                full_layer_name = f"{block_name}.{layer_name}" if layer_name else f"{block_name}"
                self.layer_wrappers[full_layer_name] = GPTQLayerWrapper(full_layer_name, layer, weight_bit_width)
                handle = layer.register_forward_pre_hook(get_hook(full_layer_name))
                self.hook_handles.append(handle)

    def quant_block(self):
        """Quantize every wrapped layer, then remove the recording hooks.

        The hooks are removed even when quantizing a layer raises, so a
        failure does not leave recording hooks attached to the model.
        """
        try:
            for wrapper in self.layer_wrappers.values():
                wrapper.quant_weight()
        finally:
            for h in self.hook_handles:
                h.remove()

    def set_order(self, idx):
        """Store the position of this block in the network's execution order."""
        self.order = idx

    def get_order(self):
        """Return the stored execution-order position of this block."""
        return self.order

    def enable(self):
        """Turn on Hessian recording for every wrapped layer."""
        for wrapper in self.layer_wrappers.values():
            wrapper.is_record = True

    def disable(self):
        """Turn off Hessian recording for every wrapped layer."""
        for wrapper in self.layer_wrappers.values():
            wrapper.is_record = False
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class GPTQuantizer:
    """Drive block-by-block GPTQ calibration of a model.

    ``wrap_model`` attaches a ``GPTQBlockWrapper`` to every sub-module whose
    type is in ``block_type``. ``record_order`` discovers the order in which
    those blocks run, and ``start_calib_iter`` calibrates and quantizes one
    block per call.
    """

    def __init__(self, block_type: Optional[List[type]] = None):
        """Create a quantizer.

        Args:
            block_type: Module types to treat as one calibration block. With
                ``None`` (or an empty list) no module is wrapped.
        """
        self.gptq_block_wrappers = {}
        self.block_type = block_type

    def wrap_model(self, model: nn.Module, weight_bit_width=8):
        """Wrap every block of ``model`` and return ``model`` itself.

        Wrappers are stored in ``gptq_block_wrappers`` under the block's
        dotted path from ``model``, so blocks that share a local name under
        different parents do not overwrite each other.

        Args:
            model: Root module to scan.
            weight_bit_width: Bit width handed to every block wrapper.
        """
        block_types = tuple(self.block_type or ())

        def wrap_block(m, prefix=""):
            for name, child in m.named_children():
                child_prefix = f"{prefix}.{name}" if prefix else name
                if isinstance(child, block_types):
                    self.gptq_block_wrappers[child_prefix] = GPTQBlockWrapper(child_prefix, child, weight_bit_width)
                    LOGGER.debug(f"Calibrate block {child_prefix} as a whole block in GPTQ")
                else:
                    wrap_block(child, child_prefix)

        wrap_block(model)
        return model

    @property
    def calibration_iters(self):
        """Number of wrapped blocks, i.e. calibration passes to run."""
        return len(self.gptq_block_wrappers)

    @contextlib.contextmanager
    def record_order(self):
        """Context manager that records the execution order of the blocks.

        While active, Hessian recording is switched off and a forward
        pre-hook on the first wrapped layer of each block notes the order in
        which blocks are first reached. On exit the orders are stored on the
        block wrappers, the temporary hooks are removed and recording is
        switched back on. Exceptions raised inside the ``with`` body are
        logged as warnings and suppressed.
        """
        counter = 0
        record_handles = []
        orders = {}
        try:
            def get_record_order_hook(block_name):
                def record_hook(*args, **kwargs):
                    nonlocal counter
                    if block_name not in orders:
                        orders[block_name] = counter
                        counter += 1
                return record_hook

            for block_name, block_wrapper in self.gptq_block_wrappers.items():
                # disable the record
                for layer_wrapper in block_wrapper.layer_wrappers.values():
                    layer_wrapper.is_record = False

                one_layer_wrapper_in_block = list(block_wrapper.layer_wrappers.values())[0]
                handles = one_layer_wrapper_in_block.layer.register_forward_pre_hook(get_record_order_hook(block_name))
                record_handles.append(handles)
            yield
        except Exception as e:
            LOGGER.warning(e)
        finally:
            for block_name, order in orders.items():
                self.gptq_block_wrappers[block_name].set_order(order)

            for h in record_handles:
                h.remove()

            for block_wrapper in self.gptq_block_wrappers.values():
                # re-enable the record
                for layer_wrapper in block_wrapper.layer_wrappers.values():
                    layer_wrapper.is_record = True

    @contextlib.contextmanager
    def start_calib_iter(self, i):
        """Context manager for the ``i``-th calibration pass.

        Recording is enabled only for the block whose order equals ``i``;
        when the ``with`` body finishes (normally or not) that block is
        quantized.

        Args:
            i: Execution-order index of the block to calibrate.

        Raises:
            RuntimeError: If no wrapped block has order ``i``.
        """
        assert i < len(self.gptq_block_wrappers)
        target_block_wrapper = None
        for block_wrapper in self.gptq_block_wrappers.values():
            if block_wrapper.get_order() == i:
                block_wrapper.enable()
                target_block_wrapper = block_wrapper
            else:
                block_wrapper.disable()
        if target_block_wrapper is None:
            # Fail before calibration runs instead of hitting
            # ``None.quant_block()`` afterwards.
            raise RuntimeError(f"No GPTQ block has calibration order {i}; run record_order() first")
        try:
            yield
        finally:
            target_block_wrapper.quant_block()

    def release_reference(self):
        """Drop the wrappers' layer references and empty the CUDA cache."""
        # delete reference so that `torch.cuda.empty_cache()` can
        # release all the gpu memory cache used during calibration
        for block_wrapper in self.gptq_block_wrappers.values():
            for layer_wrapper in block_wrapper.layer_wrappers.values():
                del layer_wrapper.layer

        torch.cuda.empty_cache()
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def locate_parent(root: nn.Module, full_path: str):
    """Resolve the owner of the attribute named by a dotted path.

    Walks from ``root`` through every component of ``full_path`` except the
    last one via ``getattr``.

    Returns:
        A ``(parent, name)`` tuple: the object reached by the walk and the
        last path component.
    """
    *ancestors, leaf = full_path.split('.')
    node = root
    for attr in ancestors:
        node = getattr(node, attr)
    return node, leaf
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@torch.no_grad()
def gptq_quantize(model, tokenizer, weight_bit_width, calib_data):
    """Quantize the ``GLMBlock`` layers of ``model`` in place with GPTQ.

    The first prompt is run once to find the execution order of the blocks.
    Then, for each block in turn, all prompts are run through ``model.chat``
    to calibrate that block, which is quantized at the end of its pass.
    Finally every wrapped layer is replaced in its parent module by a
    ``QuantizedLinear`` built from the quantized weight and its scale.

    Args:
        model: Model exposing ``chat(tokenizer, prompt, history=[])``.
        tokenizer: Tokenizer forwarded to ``model.chat``.
        weight_bit_width: Bit width of the quantized weights.
        calib_data: Non-empty sequence of calibration prompts.

    Raises:
        ValueError: If ``calib_data`` is empty or ``None``.
    """
    if not calib_data:
        # Fail early with a clear message instead of an IndexError on
        # ``calib_data[0]`` below.
        raise ValueError("calib_data must contain at least one prompt for GPTQ calibration")

    from .modeling_chatglm import GLMBlock
    from .quantization import QuantizedLinear

    quantizer = GPTQuantizer([GLMBlock])
    calib_model = quantizer.wrap_model(model, weight_bit_width)
    with quantizer.record_order():
        calib_model.chat(tokenizer, calib_data[0], history=[])

    logging.info("Start doing calibration using GPTQ ")
    for i in range(quantizer.calibration_iters):
        logging.info(f"Process: {i + 1}/{quantizer.calibration_iters}")
        # todo: should add early return to speed up the calibration
        # todo: add cpu offload to reduce the gpu memory requirements.
        with quantizer.start_calib_iter(i):
            for prompt in calib_data:
                model.chat(tokenizer, prompt, history=[])

    # replace the fp16 linear with quantized linear
    for block_wrapper in quantizer.gptq_block_wrappers.values():
        for layer_name, layer_wrapper in block_wrapper.layer_wrappers.items():
            layer = layer_wrapper.layer
            parent, name_in_parent = locate_parent(model, layer_name)
            quantized_layer = QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight_tensor=layer.weight,
                bias_tensor=layer.bias,
                weight_scale=layer_wrapper.weight_scale,
                in_features=layer.in_features,
                out_features=layer.out_features,
                bias=True,
                dtype=torch.half,
                device=layer_wrapper.device,
                empty_init=False
            )
            parent.add_module(name_in_parent, quantized_layer)

    # release the memory cache used during calibration
    quantizer.release_reference()
|
modeling_chatglm.py
CHANGED
|
@@ -1408,12 +1408,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1408 |
break
|
| 1409 |
yield input_ids
|
| 1410 |
|
| 1411 |
-
def quantize(
|
|
|
|
|
|
|
| 1412 |
if bits == 0:
|
| 1413 |
return
|
| 1414 |
|
| 1415 |
-
from .quantization import quantize
|
| 1416 |
-
|
| 1417 |
if self.quantized:
|
| 1418 |
logger.info("Already quantized.")
|
| 1419 |
return self
|
|
@@ -1421,6 +1423,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1421 |
self.quantized = True
|
| 1422 |
|
| 1423 |
self.config.quantization_bit = bits
|
| 1424 |
-
|
| 1425 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1426 |
return self
|
|
|
|
| 1408 |
break
|
| 1409 |
yield input_ids
|
| 1410 |
|
| 1411 |
+
def quantize(
|
| 1412 |
+
self, bits: int, empty_init=False, quant_algo_type: str="min_max",
|
| 1413 |
+
calib_data: Optional[List[str]]=None, tokenizer=None, **kwargs):
|
| 1414 |
if bits == 0:
|
| 1415 |
return
|
| 1416 |
|
| 1417 |
+
from .quantization import quantize, QuantAlgoType
|
| 1418 |
+
from .gptq_quantization import gptq_quantize
|
| 1419 |
if self.quantized:
|
| 1420 |
logger.info("Already quantized.")
|
| 1421 |
return self
|
|
|
|
| 1423 |
self.quantized = True
|
| 1424 |
|
| 1425 |
self.config.quantization_bit = bits
|
| 1426 |
+
quant_algo_type = QuantAlgoType(quant_algo_type)
|
| 1427 |
+
if quant_algo_type == QuantAlgoType.min_max:
|
| 1428 |
+
self.transformer = quantize(
|
| 1429 |
+
self.transformer, bits, empty_init=empty_init, algo_type=quant_algo_type, calib_data=calib_data, tokenizer=tokenizer, **kwargs)
|
| 1430 |
+
elif quant_algo_type == QuantAlgoType.gptq:
|
| 1431 |
+
if calib_data is None or tokenizer is None:
|
| 1432 |
+
raise RuntimeError("If using gptq to quantize the model, "
|
| 1433 |
+
"calibration data (e.g. some string prompts) and tokenizer should be provided")
|
| 1434 |
+
gptq_quantize(
|
| 1435 |
+
self, tokenizer, bits, calib_data
|
| 1436 |
+
)
|
| 1437 |
+
else:
|
| 1438 |
+
raise RuntimeError("Unsupported quantization algorithm type")
|
| 1439 |
return self
|
quantization.py
CHANGED
|
@@ -8,7 +8,7 @@ import ctypes
|
|
| 8 |
from transformers.utils import logging
|
| 9 |
|
| 10 |
from typing import List
|
| 11 |
-
from
|
| 12 |
|
| 13 |
logger = logging.get_logger(__name__)
|
| 14 |
|
|
@@ -41,6 +41,17 @@ except Exception as exception:
|
|
| 41 |
logger.warning("Failed to load cpm_kernels:" + str(exception))
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
class W8A16Linear(torch.autograd.Function):
|
| 45 |
@staticmethod
|
| 46 |
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
|
|
@@ -118,7 +129,7 @@ def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, sourc
|
|
| 118 |
|
| 119 |
|
| 120 |
class QuantizedLinear(Linear):
|
| 121 |
-
def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs):
|
| 122 |
super(QuantizedLinear, self).__init__(*args, **kwargs)
|
| 123 |
self.weight_bit_width = weight_bit_width
|
| 124 |
|
|
@@ -131,7 +142,10 @@ class QuantizedLinear(Linear):
|
|
| 131 |
)
|
| 132 |
self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
|
| 133 |
else:
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
| 135 |
self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
|
| 136 |
if weight_bit_width == 4:
|
| 137 |
self.weight = compress_int4_weight(self.weight)
|
|
|
|
| 8 |
from transformers.utils import logging
|
| 9 |
|
| 10 |
from typing import List
|
| 11 |
+
from enum import Enum
|
| 12 |
|
| 13 |
logger = logging.get_logger(__name__)
|
| 14 |
|
|
|
|
| 41 |
logger.warning("Failed to load cpm_kernels:" + str(exception))
|
| 42 |
|
| 43 |
|
| 44 |
+
class QuantAlgoType(Enum):
    """Names of the supported weight-quantization algorithms."""

    min_max = 'min_max'
    gptq = 'gptq'

    @classmethod
    def _missing_(cls, value):
        """Reject an unknown value with a message listing the valid ones."""
        known_values = [member.value for member in cls]
        message = f"Unsupported quantization algorithm type. Support list: {known_values}. Got: '{value}'"
        raise ValueError(message)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
class W8A16Linear(torch.autograd.Function):
|
| 56 |
@staticmethod
|
| 57 |
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
class QuantizedLinear(Linear):
|
| 132 |
+
def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, weight_scale=None, empty_init=False, *args, **kwargs):
|
| 133 |
super(QuantizedLinear, self).__init__(*args, **kwargs)
|
| 134 |
self.weight_bit_width = weight_bit_width
|
| 135 |
|
|
|
|
| 142 |
)
|
| 143 |
self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
|
| 144 |
else:
|
| 145 |
+
if weight_scale is None:
|
| 146 |
+
self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
|
| 147 |
+
else:
|
| 148 |
+
self.weight_scale = weight_scale
|
| 149 |
self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
|
| 150 |
if weight_bit_width == 4:
|
| 151 |
self.weight = compress_int4_weight(self.weight)
|