import dataclasses
import gc
import glob
import os

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
import torch
from torch import Tensor
from torch.nn import functional as F
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


@dataclasses.dataclass
class CompressionConfig:
    """Configuration for group-wise quantization."""

    num_bits: int  # bit width of the quantized values (must be <= 8)
    group_size: int  # number of elements that share one scale
    group_dim: int  # tensor dimension along which groups are formed
    symmetric: bool  # True: signed range around zero; False: affine min/max
    enabled: bool = True


default_compression_config = CompressionConfig(
    num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True
)
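
# Rough memory arithmetic for the default config above (a sketch, not a
# measurement): float16 weights take 2 bytes per element, while symmetric int8
# stores 1 byte per element plus one scale per group of 256 elements, so the
# compressed weights are roughly half the float16 size with under 2% overhead
# for the scales.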


class CLinear(nn.Module):
    """Linear layer whose weight is stored group-wise quantized.

    The weight is compressed once at construction time and decompressed on
    the fly in every forward pass, trading compute for memory.
    """

    def __init__(self, weight=None, bias=None, device=None):
        super().__init__()
        if weight is None:
            self.weight = None
        elif isinstance(weight, Tensor):
            # Quantize a dense weight tensor.
            self.weight = compress(weight.data.to(device), default_compression_config)
        else:
            # The weight is already a packed tuple produced by compress().
            self.weight = weight
        self.bias = bias

    def forward(self, input: Tensor) -> Tensor:
        # Dequantize to a dense tensor before the matmul.
        weight = decompress(self.weight, default_compression_config)
        if self.bias is None:
            return F.linear(input.to(weight.dtype), weight)
        return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype))
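
# Minimal usage sketch for CLinear (hypothetical shapes; assumes a CUDA
# device, though "cpu" also works):
#
#   dense = nn.Linear(1024, 1024, bias=False)
#   clin = CLinear(dense.weight, dense.bias, "cuda")
#   y = clin(torch.randn(8, 1024, device="cuda"))
#
# The output matches the dense layer up to int8 quantization error.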


def compress_module(module, target_device):
    """Recursively replace every exact nn.Linear in `module` with a CLinear."""
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        # Match nn.Linear exactly (not subclasses) so custom layers are kept.
        if type(target_attr) == torch.nn.Linear:
            setattr(
                module,
                attr_str,
                CLinear(target_attr.weight, target_attr.bias, target_device),
            )
    for name, child in module.named_children():
        compress_module(child, target_device)
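
# Typical use: quantize an already-loaded dense model in place (a sketch; the
# checkpoint name is illustrative):
#
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b").to("cuda")
#   compress_module(model, "cuda")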


def get_compressed_list(module, prefix=""):
    """Return the state-dict names of all nn.Linear weights under `module`."""
    compressed_list = []
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.Linear:
            full_name = (
                f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
            )
            compressed_list.append(full_name)
    for name, child in module.named_children():
        child_prefix = f"{prefix}.{name}" if prefix else name
        for each in get_compressed_list(child, child_prefix):
            compressed_list.append(each)
    return compressed_list


def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""):
    """Recursively replace nn.Linear layers with CLinear layers that wrap
    already-compressed weights from `compressed_state_dict`."""
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.Linear:
            full_name = (
                f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
            )
            setattr(
                module,
                attr_str,
                CLinear(
                    compressed_state_dict[full_name], target_attr.bias, target_device
                ),
            )
    for name, child in module.named_children():
        child_prefix = f"{prefix}.{name}" if prefix else name
        apply_compressed_weight(
            child, compressed_state_dict, target_device, child_prefix
        )


def load_compress_model(model_path, device, torch_dtype, use_fast=False):
    """Load a model from `model_path` shard by shard, quantizing linear
    weights as they are read so the dense model never fully materializes."""
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast)
    base_pattern = os.path.join(model_path, "pytorch_model*.bin")
    files = glob.glob(base_pattern)
    if len(files) == 0:
        raise ValueError(f"Cannot find any model weights matching: {base_pattern}")

    # Build the model skeleton on the meta device so no memory is allocated
    # for weights that will be streamed in below.
    with init_empty_weights():
        config = AutoConfig.from_pretrained(
            model_path, low_cpu_mem_usage=True, torch_dtype=torch_dtype
        )
        model = AutoModelForCausalLM.from_config(config)
        linear_weights = get_compressed_list(model)

    compressed_state_dict = {}

    for filename in tqdm(files):
        tmp_state_dict = torch.load(filename)
        for name in tmp_state_dict:
            if name in linear_weights:
                tensor = tmp_state_dict[name].to(device).data.to(torch_dtype)
                compressed_state_dict[name] = compress(
                    tensor, default_compression_config
                )
            else:
                compressed_state_dict[name] = tmp_state_dict[name].to(device)
            # Release the dense tensor before reading the next one.
            tmp_state_dict[name] = None
            tensor = None
            gc.collect()
            torch.cuda.empty_cache()

    for name in model.state_dict():
        if name not in linear_weights:
            set_module_tensor_to_device(
                model, name, device, value=compressed_state_dict[name]
            )
    apply_compressed_weight(model, compressed_state_dict, device)

    model.to(device)

    return model, tokenizer
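
# Example call (a sketch; the path is illustrative and must contain
# pytorch_model*.bin shards):
#
#   model, tokenizer = load_compress_model(
#       "/path/to/model", "cuda", torch.float16
#   )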


def compress(tensor, config):
    """Simulate group-wise quantization."""
    if not config.enabled:
        return tensor

    group_size, num_bits, group_dim, symmetric = (
        config.group_size,
        config.num_bits,
        config.group_dim,
        config.symmetric,
    )
    assert num_bits <= 8

    original_shape = tensor.shape
    num_groups = (original_shape[group_dim] + group_size - 1) // group_size
    new_shape = (
        original_shape[:group_dim]
        + (num_groups, group_size)
        + original_shape[group_dim + 1 :]
    )

    # Pad so the length along group_dim is a multiple of group_size.
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len != 0:
        pad_shape = (
            original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :]
        )
        tensor = torch.cat(
            [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
            dim=group_dim,
        )
    data = tensor.view(new_shape)

    # Quantize.
    if symmetric:
        B = 2 ** (num_bits - 1) - 1
        scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
        data = data * scale
        data = data.clamp_(-B, B).round_().to(torch.int8)
        return data, scale, original_shape
    else:
        B = 2**num_bits - 1
        mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
        mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]

        scale = B / (mx - mn)
        data = data - mn
        data.mul_(scale)

        data = data.clamp_(0, B).round_().to(torch.uint8)
        return data, mn, scale, original_shape
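
# Worked example for the symmetric int8 path (a sketch): with num_bits=8,
# B = 127. A group whose largest magnitude is 0.5 gets scale = 127 / 0.5 = 254,
# so the value 0.25 is stored as round(0.25 * 254) = 64 and dequantizes to
# 64 / 254 ~= 0.252.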


def decompress(packed_data, config):
    """Simulate group-wise dequantization."""
    if not config.enabled:
        return packed_data

    group_size, num_bits, group_dim, symmetric = (
        config.group_size,
        config.num_bits,
        config.group_dim,
        config.symmetric,
    )

    # Dequantize.
    if symmetric:
        data, scale, original_shape = packed_data
        data = data / scale
    else:
        data, mn, scale, original_shape = packed_data
        data = data / scale
        data.add_(mn)

    # Unpad: drop the zero padding added in compress().
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len:
        padded_original_shape = (
            original_shape[:group_dim]
            + (original_shape[group_dim] + pad_len,)
            + original_shape[group_dim + 1 :]
        )
        data = data.reshape(padded_original_shape)
        indices = [slice(0, x) for x in original_shape]
        # Index with a tuple of slices; a plain list is treated as advanced
        # indexing by newer PyTorch versions.
        return data[tuple(indices)].contiguous()
    else:
        return data.view(original_shape)
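

if __name__ == "__main__":
    # Round-trip sanity check (a sketch, CPU-only). The shapes are arbitrary;
    # 300 is deliberately not a multiple of group_size to exercise the
    # padding/unpadding path. Reconstruction error should be small for
    # symmetric int8 quantization.
    x = torch.randn(16, 300)
    packed = compress(x, default_compression_config)
    y = decompress(packed, default_compression_config)
    assert y.shape == x.shape
    print("max abs error:", (x - y).abs().max().item())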