| import json |
| import shlex |
| import subprocess |
| from typing import Tuple |
|
|
| import torch |
|
|
|
|
def outlier_hook(module, input):
    """Forward pre-hook that records outlier indices for a ``torch.nn.Linear``.

    ``OutlierTracer.initialize`` registers this function on linear modules via
    ``register_forward_pre_hook``. The result is stored in
    ``tracer.hvalue2outlier_idx`` under the key returned by
    ``tracer.get_hvalue(module.weight)``.

    Parameters:
        module (`torch.nn.Linear`):
            The module about to run its forward pass.
        input (`tuple`):
            Positional inputs of the forward call; only ``input[0]`` is read.

    Returns:
        None, so the module's inputs are passed through unchanged.
    """
    assert isinstance(module, torch.nn.Linear)
    tracer = OutlierTracer.get_instance()
    hvalue = tracer.get_hvalue(module.weight)

    if hvalue in tracer.hvalue2outlier_idx:
        # This weight already has an entry: detach every hook the tracer holds.
        for hook in tracer.hooks:
            hook.remove()
        return

    outlier_idx = find_outlier_dims(module.weight)
    tracer.outliers.append(outlier_idx)
    tracer.hvalues.append(hvalue)

    if len(tracer.outliers) > 1:
        # Not the first traced layer: use the indices derived from the weight.
        # ``tracer.outliers[-1]`` is the entry appended just above.
        latest = tracer.outliers[-1]
        if latest.numel() > 0:
            assert latest.max() < module.weight.shape[1]
        tracer.hvalue2outlier_idx[hvalue] = latest
        return

    # First traced layer: derive the indices from the activations instead,
    # combining a z-score test with an absolute-magnitude test.
    activations = input[0]
    # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
    # activations (e.g. after a transpose), ``reshape`` copies when it must.
    merged = activations.reshape(-1, activations.shape[-1])

    # (1) z-score test on the standard deviation.
    # NOTE(review): with reduction_dim=1 the statistics are taken across the
    # last axis, so these indices refer to rows of ``merged`` - confirm that
    # this is intended.
    outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)

    # (2) magnitude test: positions along the last axis where any |x| > 6.
    leading_axes = list(range(len(activations.shape) - 1))
    dims = (torch.abs(activations) > 6).sum(dim=leading_axes)
    outlier_idx2 = torch.where(dims > 0)[0]

    outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
    tracer.hvalue2outlier_idx[hvalue] = outlier_idx
|
|
|
|
class OutlierTracer:
    """Singleton that stores outlier indices per linear-layer weight.

    Instances must be obtained through ``OutlierTracer.get_instance()``;
    calling the constructor directly raises ``RuntimeError``. The instance is
    unusable until ``initialize(model)`` has been called.
    """

    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self, model):
        """Reset all state and hook every ``torch.nn.Linear`` inside ``model``.

        Parameters:
            model (`torch.nn.Module`):
                Module tree whose linear layers get ``outlier_hook`` registered
                as a forward pre-hook.
        """
        self.last_w = None
        self.current_outlier_dims = None
        self.hvalues = []
        self.outliers = []
        self.hvalue2outlier_idx = {}
        self.initialized = True
        self.hooks = []

        for m in model.modules():
            if isinstance(m, torch.nn.Linear):
                # Keep the handle so the hook can be removed later.
                self.hooks.append(m.register_forward_pre_hook(outlier_hook))

    def is_initialized(self):
        """Return True once ``initialize`` has been called on this instance."""
        return getattr(self, "initialized", False)

    def get_hvalue(self, weight):
        """Return the data pointer of the storage backing ``weight``.

        The pointer is used as the dictionary key identifying a weight tensor.
        """
        data = weight.data
        # Prefer the untyped storage where available: ``Tensor.storage()``
        # returns a deprecated TypedStorage on recent PyTorch and warns on
        # every call. Both expose the same underlying data pointer.
        untyped_storage = getattr(data, "untyped_storage", None)
        if untyped_storage is not None:
            return untyped_storage().data_ptr()
        return data.storage().data_ptr()

    def get_outliers(self, weight):
        """Return the outlier indices recorded for ``weight``.

        Returns:
            The stored index tensor, or None when the tracer has not been
            initialized or nothing was recorded for this weight.
        """
        if not self.is_initialized():
            print("Outlier tracer is not initialized...")
            return None
        return self.hvalue2outlier_idx.get(self.get_hvalue(weight))

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            # ``__new__`` bypasses ``__init__``, which deliberately raises.
            cls._instance = cls.__new__(cls)
        return cls._instance
|
|
|
|
def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
    """Return indices whose standard deviation marks them as outliers.

    The standard deviation of ``weight`` is taken along ``reduction_dim``,
    giving one value per remaining position. Positions are then selected in
    one of three ways.

    Parameters:
        weight (`torch.Tensor`):
            Tensor to analyse.
        reduction_dim (`int`, defaults to 0):
            Dimension that is reduced when computing the standard deviation.
        zscore (`float`, defaults to 4.0):
            Threshold for the z-score test; used only when ``topk`` is None.
        topk (`int`, *optional*):
            If given, return the ``topk`` positions with the largest standard
            deviation instead of applying the z-score test. Must be given
            when ``rdm`` is True.
        rdm (`bool`, defaults to False):
            If True, ignore the data and return ``topk`` random indices drawn
            from ``[0, weight.shape[1])``.

    Returns:
        A 1-D integer tensor of selected indices.
    """
    if rdm:
        return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()

    std = weight.std(reduction_dim)

    if topk is not None:
        _, idx = torch.topk(std.abs(), k=topk, dim=0)
        return idx

    # Standardise the per-position deviations and keep those above threshold.
    zstd = (std - std.mean()) / std.std()
    return torch.where(zstd > zscore)[0]
|
|
|
|
def execute_and_return(command_string: str) -> Tuple[str, str]:
    """Run a command and return its decoded, stripped stdout and stderr.

    The command string is tokenised with ``shlex.split`` and executed without
    a shell. Both streams are captured, decoded as UTF-8 and stripped of
    surrounding whitespace.

    Parameters:
        command_string (`str`):
            The command line to execute.

    Returns:
        A ``(stdout, stderr)`` tuple of strings.
    """
    process = subprocess.Popen(
        shlex.split(command_string),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    # communicate() waits for the process to exit and returns both streams.
    raw_out, raw_err = process.communicate()
    return raw_out.decode("UTF-8").strip(), raw_err.decode("UTF-8").strip()
|
|
|
|
def replace_linear(
    model,
    linear_replacement,
    skip_modules=("lm_head",),
    copy_weights=False,
    post_processing_function=None,
):
    """
    Replace linear modules with a new Linear module.

    The module tree is walked recursively and modified in place.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments
            (`in_features`, `out_features`, `bias`). If other arguments need to be passed,
            use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
            List of modules names not to convert. Defaults to `lm_head`.
        copy_weights (`bool`):
            Copy the weights (and bias) from the old linear module to the new one.
        post_processing_function (`str`):
            A function name of the replacement linear class that is called
            with the newly created module after it has been installed. Names that
            the replacement class does not define are ignored.

    Returns:
        The same `model` object that was passed in.
    """
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)

        if not isinstance(module, torch.nn.Linear) or name in skip_modules:
            continue

        new_module = linear_replacement(
            module.in_features,
            module.out_features,
            module.bias is not None,
        )
        # Overwriting the value of an existing key is safe while iterating.
        model._modules[name] = new_module

        if copy_weights:
            new_module.weight = module.weight
            new_module.bias = module.bias

        if post_processing_function is not None:
            # Resolve the hook on the replacement's class (not on the module
            # being replaced) and apply it to the new module. Looking it up on
            # the class keeps the explicit-argument call valid for ordinary
            # methods as well as static helpers.
            func = getattr(type(new_module), post_processing_function, None)
            if func is not None:
                func(new_module)
    return model
|
|
|
|
def pack_dict_to_tensor(source_dict):
    """
    Serialise a dictionary into a 1-D ``torch.uint8`` tensor.

    The dictionary is dumped to JSON, encoded as UTF-8, and each byte becomes
    one tensor element, so the result can be stored alongside other tensors
    (e.g. quant_state items in a state_dict).

    Parameters:
    - source_dict: The dictionary to be packed; must be JSON-serialisable.

    Returns:
    A torch tensor containing the packed data.
    """
    encoded = json.dumps(source_dict).encode("utf-8")
    return torch.tensor(list(encoded), dtype=torch.uint8)
|
|
|
|
def unpack_tensor_to_dict(tensor_data):
    """
    Decode a byte tensor back into the Python object it was packed from.

    The tensor is moved to the CPU, its raw bytes are decoded as UTF-8 and the
    resulting text is parsed as JSON.

    Parameters:
    - tensor_data: The torch tensor containing the packed data.

    Returns:
    The parsed JSON value (a dictionary when the tensor was produced by
    ``pack_dict_to_tensor``).
    """
    raw = tensor_data.cpu().numpy()
    return json.loads(bytes(raw).decode("utf-8"))
|
|
|
|
# Maps each weight-format name to its integer code.
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {
    "row": 0,
    "col32": 1,
    "col_turing": 2,
    "col_ampere": 3,
}
# Reverse lookup: integer code back to the weight-format name.
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {
    code: layout for layout, code in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()
}
|
|