# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import types
import logging
import torch
from . import default_layers

_logger = logging.getLogger(__name__)


class LayerInfo:
    def __init__(self, name, module):
        self.module = module
        self.name = name
        self.type = type(module).__name__


def _setattr(model, name, module):
    name_list = name.split(".")
    for name in name_list[:-1]:
        model = getattr(model, name)
    setattr(model, name_list[-1], module)
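
# `_setattr` walks a dotted module path so nested submodules can be replaced
# in place, which is what `_wrap_model`/`_unwrap_model` rely on. A minimal
# illustration (the `block.conv` path is hypothetical):
#
#     # replaces model.block.conv with an identity module
#     _setattr(model, 'block.conv', torch.nn.Identity())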


class Compressor:
    """
    Abstract base PyTorch compressor
    """

    def __init__(self, model, config_list, optimizer=None):
        """
        Record necessary info in class members

        Parameters
        ----------
        model : pytorch model
            the model user wants to compress
        config_list : list
            the configurations that users specify for compression
        optimizer : pytorch optimizer
            optimizer used to train the model
        """
        assert isinstance(model, torch.nn.Module)
        self.validate_config(model, config_list)

        self.bound_model = model
        self.config_list = config_list
        self.optimizer = optimizer

        self.modules_to_compress = None
        self.modules_wrapper = []
        self.is_wrapped = False

        self._fwd_hook_handles = {}
        self._fwd_hook_id = 0

        self.reset()

        if not self.modules_wrapper:
            _logger.warning('Nothing is configured to compress, please check your model and config_list')

    def validate_config(self, model, config_list):
        """
        Subclasses can optionally implement this method to check whether config_list is valid
        """
        pass

    def reset(self, checkpoint=None):
        """
        Reset the model state dict and the model wrappers
        """
        self._unwrap_model()
        if checkpoint is not None:
            self.bound_model.load_state_dict(checkpoint)

        self.modules_to_compress = None
        self.modules_wrapper = []

        for layer, config in self._detect_modules_to_compress():
            wrapper = self._wrap_modules(layer, config)
            self.modules_wrapper.append(wrapper)

        self._wrap_model()

    def _detect_modules_to_compress(self):
        """
        Detect all modules that should be compressed, and save the result in
        `self.modules_to_compress`. The model will be instrumented and users
        should never edit it after calling this method.
        """
        if self.modules_to_compress is None:
            self.modules_to_compress = []
            for name, module in self.bound_model.named_modules():
                if module == self.bound_model:
                    continue
                layer = LayerInfo(name, module)
                config = self.select_config(layer)
                if config is not None:
                    self.modules_to_compress.append((layer, config))
        return self.modules_to_compress

    def _wrap_model(self):
        """
        Wrap all modules that need to be compressed
        """
        for wrapper in reversed(self.get_modules_wrapper()):
            _setattr(self.bound_model, wrapper.name, wrapper)
        self.is_wrapped = True

    def _unwrap_model(self):
        """
        Unwrap all modules that need to be compressed
        """
        for wrapper in self.get_modules_wrapper():
            _setattr(self.bound_model, wrapper.name, wrapper.module)
        self.is_wrapped = False

    def compress(self):
        """
        Compress the model with the algorithm implemented by the subclass.
        The model will be instrumented and users should never edit it after
        calling this method. `self.modules_to_compress` records all the
        to-be-compressed layers.

        Returns
        -------
        torch.nn.Module
            model with specified modules compressed.
        """
        return self.bound_model
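
    # A minimal usage sketch (the `MyPruner` subclass and the `sparsity` key
    # are hypothetical here; `sparsity` is read by concrete pruners, not by
    # this base class):
    #
    #     model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
    #     pruner = MyPruner(model, [{'sparsity': 0.5, 'op_types': ['Linear']}])
    #     model = pruner.compress()  # instrumented model; train/evaluate as usual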

    def set_wrappers_attribute(self, name, value):
        """
        Register an attribute used in the wrapped module's forward method.
        If the value is a torch.Tensor, it is registered as a buffer in the
        wrapper, which will be saved by model.state_dict. Otherwise, the value
        is just a regular variable in the wrapper.

        Parameters
        ----------
        name : str
            name of the variable
        value : any
            value of the variable
        """
        for wrapper in self.get_modules_wrapper():
            if isinstance(value, torch.Tensor):
                wrapper.register_buffer(name, value.clone())
            else:
                setattr(wrapper, name, value)
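
    # Example: a pruner subclass can use this to stash per-wrapper state, e.g.
    # a "mask already computed" flag (the attribute name `if_calculated` is
    # illustrative):
    #
    #     self.set_wrappers_attribute("if_calculated", False)
    #     # later, inside calc_mask: wrapper.if_calculated = True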

    def get_modules_to_compress(self):
        """
        To obtain all the to-be-compressed modules.

        Returns
        -------
        list
            a list of the layers, each of which is a tuple (`layer`, `config`),
            `layer` is `LayerInfo`, `config` is a `dict`
        """
        return self.modules_to_compress

    def get_modules_wrapper(self):
        """
        To obtain all the wrapped modules.

        Returns
        -------
        list
            a list of the wrapped modules
        """
        return self.modules_wrapper

    def select_config(self, layer):
        """
        Find the configuration for `layer` by parsing `self.config_list`

        Parameters
        ----------
        layer : LayerInfo
            one layer

        Returns
        -------
        config or None
            the retrieved configuration for this layer; if None, this layer
            should not be compressed
        """
        ret = None
        for config in self.config_list:
            config = config.copy()
            # expand config if key `default` is in config['op_types']
            if 'op_types' in config and 'default' in config['op_types']:
                expanded_op_types = []
                for op_type in config['op_types']:
                    if op_type == 'default':
                        expanded_op_types.extend(default_layers.weighted_modules)
                    else:
                        expanded_op_types.append(op_type)
                config['op_types'] = expanded_op_types
            # check if the layer satisfies this config's conditions;
            # later configs in the list take precedence over earlier ones
            if 'op_types' in config and layer.type not in config['op_types']:
                continue
            if 'op_names' in config and layer.name not in config['op_names']:
                continue
            ret = config
        if ret is None or 'exclude' in ret:
            return None
        return ret
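
    # Example of the matching rules above (sparsity values and layer names are
    # illustrative): with
    #
    #     config_list = [
    #         {'sparsity': 0.5, 'op_types': ['default']},   # all weighted layer types
    #         {'sparsity': 0.8, 'op_names': ['conv1']},     # override for conv1
    #         {'exclude': True, 'op_names': ['fc']},        # fc is not compressed
    #     ]
    #
    # `conv1` ends up with sparsity 0.8 because the later match wins, `fc` is
    # skipped because its last matching config carries `exclude`, and every
    # other weighted layer gets sparsity 0.5.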

    def update_epoch(self, epoch):
        """
        If users want to update the model every epoch, they can override this
        method. This method should be called at the beginning of each epoch

        Parameters
        ----------
        epoch : num
            the current epoch number
        """
        pass

    def _wrap_modules(self, layer, config):
        """
        This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            the configuration for compressing this layer
        """
        raise NotImplementedError()

    def add_activation_collector(self, collector):
        self._fwd_hook_id += 1
        self._fwd_hook_handles[self._fwd_hook_id] = []
        for wrapper in self.get_modules_wrapper():
            handle = wrapper.register_forward_hook(collector)
            self._fwd_hook_handles[self._fwd_hook_id].append(handle)
        return self._fwd_hook_id

    def remove_activation_collector(self, fwd_hook_id):
        if fwd_hook_id not in self._fwd_hook_handles:
            raise ValueError("%s is not a valid collector id" % str(fwd_hook_id))
        for handle in self._fwd_hook_handles[fwd_hook_id]:
            handle.remove()
        del self._fwd_hook_handles[fwd_hook_id]

    def patch_optimizer(self, *tasks):
        def patch_step(old_step):
            def new_step(_, *args, **kwargs):
                # call the original optimizer step method
                output = old_step(*args, **kwargs)
                # run the scheduled tasks (e.g. mask recalculation) after each step
                for task in tasks:
                    task()
                return output
            return new_step
        if self.optimizer is not None:
            self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)


class PrunerModuleWrapper(torch.nn.Module):
    def __init__(self, module, module_name, module_type, config, pruner):
        """
        Wrap a module to enable data parallel, forward method customization and buffer registration.

        Parameters
        ----------
        module : pytorch module
            the module user wants to compress
        config : dict
            the configurations that users specify for compression
        module_name : str
            the name of the module to compress, wrapper module shares same name
        module_type : str
            the type of the module to compress
        pruner : Pruner
            the pruner used to calculate mask
        """
        super().__init__()
        # origin layer information
        self.module = module
        self.name = module_name
        self.type = module_type
        # config and pruner
        self.config = config
        self.pruner = pruner

        # register buffer for mask
        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
        else:
            self.register_buffer("bias_mask", None)

    def forward(self, *inputs):
        # apply mask to weight, bias
        self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
        if hasattr(self.module, 'bias') and self.module.bias is not None:
            self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
        return self.module(*inputs)


class Pruner(Compressor):
    """
    Prune to an exact pruning level specification

    Attributes
    ----------
    mask_dict : dict
        Dictionary for saving masks, `key` should be layer name and
        `value` should be a tensor which has the same shape with layer's weight
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        if optimizer is not None:
            self.patch_optimizer(self.update_mask)

    def compress(self):
        self.update_mask()
        return self.bound_model

    def update_mask(self):
        for wrapper_idx, wrapper in enumerate(self.get_modules_wrapper()):
            masks = self.calc_mask(wrapper, wrapper_idx=wrapper_idx)
            if masks is not None:
                for k in masks:
                    assert hasattr(wrapper, k), "there is no attribute '%s' in wrapper" % k
                    setattr(wrapper, k, masks[k])

    def calc_mask(self, wrapper, **kwargs):
        """
        Pruners should overload this method to provide masks for weight tensors.
        The mask must have the same shape and type as the weight.
        It will be applied with `mul()` operation on the weight.
        This method is effectively hooked to the `forward()` method of the model.

        Parameters
        ----------
        wrapper : Module
            calculate mask for `wrapper.module`'s weight
        """
        raise NotImplementedError("Pruners must overload calc_mask()")

    def _wrap_modules(self, layer, config):
        """
        Create a wrapper module to replace the original one.

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the mask
        config : dict
            the configuration for generating the mask
        """
        _logger.debug("Module detected to compress : %s.", layer.name)
        wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
        assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
        # move newly registered buffers to the same device as the weight
        wrapper.to(layer.module.weight.device)
        return wrapper

    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export pruned model weights, masks and onnx model (optional)

        Parameters
        ----------
        model_path : str
            path to save pruned model state_dict
        mask_path : str
            (optional) path to save mask dict
        onnx_path : str
            (optional) path to save onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting onnx file.
            the tensor is placed on cpu if `device` is None
        """
        assert model_path is not None, 'model_path must be specified'
        mask_dict = {}
        self._unwrap_model()  # used for generating correct state_dict name without wrapper state

        for wrapper in self.get_modules_wrapper():
            weight_mask = wrapper.weight_mask
            bias_mask = wrapper.bias_mask
            if weight_mask is not None:
                mask_sum = weight_mask.sum().item()
                mask_num = weight_mask.numel()
                _logger.debug('Layer: %s Sparsity: %.4f', wrapper.name, 1 - mask_sum / mask_num)
                wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
            if bias_mask is not None:
                wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
            # save mask to dict
            mask_dict[wrapper.name] = {"weight": weight_mask, "bias": bias_mask}

        torch.save(self.bound_model.state_dict(), model_path)
        _logger.info('Model state_dict saved to %s', model_path)
        if mask_path is not None:
            torch.save(mask_dict, mask_path)
            _logger.info('Mask dict saved to %s', mask_path)
        if onnx_path is not None:
            assert input_shape is not None, 'input_shape must be specified to export onnx model'
            # input info needed
            if device is None:
                device = torch.device('cpu')
            input_data = torch.Tensor(*input_shape)
            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
            _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)

        self._wrap_model()
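
    # Typical export call (paths and input shape are illustrative):
    #
    #     pruner.export_model(model_path='pruned_model.pth',
    #                         mask_path='mask.pth',
    #                         onnx_path='pruned_model.onnx',
    #                         input_shape=[1, 3, 224, 224])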

    def load_model_state_dict(self, model_state):
        """
        Load the state dict saved from the unwrapped model.

        Parameters
        ----------
        model_state : dict
            state dict saved from the unwrapped model
        """
        if self.is_wrapped:
            self._unwrap_model()
            self.bound_model.load_state_dict(model_state)
            self._wrap_model()
        else:
            self.bound_model.load_state_dict(model_state)


class QuantizerModuleWrapper(torch.nn.Module):
    def __init__(self, module, module_name, module_type, config, quantizer):
        """
        Wrap a module to enable data parallel, forward method customization and buffer registration.

        Parameters
        ----------
        module : pytorch module
            the module user wants to compress
        config : dict
            the configurations that users specify for compression
        module_name : str
            the name of the module to compress, wrapper module shares same name
        module_type : str
            the type of the module to compress
        quantizer : Quantizer
            the quantizer used to quantize the module
        """
        super().__init__()
        # origin layer information
        self.module = module
        self.name = module_name
        self.type = module_type
        # config and quantizer
        self.config = config
        self.quantizer = quantizer

        # register buffer and parameter:
        # `old_weight` stores the original full-precision weight, while `weight`
        # stores the quantized weight. `weight` is a buffer rather than a
        # parameter because in pytorch a parameter is a leaf node; if `weight`
        # stayed a leaf, gradients could not flow back to update `old_weight`.
        if 'weight' in config['quant_types']:
            if not _check_weight(self.module):
                _logger.warning('Module %s does not have parameter "weight"', self.name)
            else:
                self.module.register_parameter('old_weight', torch.nn.Parameter(self.module.weight))
                delattr(self.module, 'weight')
                self.module.register_buffer('weight', self.module.old_weight)

    def forward(self, *inputs):
        if 'input' in self.config['quant_types']:
            inputs = self.quantizer.quant_grad.apply(
                inputs,
                QuantType.QUANT_INPUT,
                self)

        if 'weight' in self.config['quant_types'] and _check_weight(self.module):
            self.quantizer.quant_grad.apply(
                self.module.old_weight,
                QuantType.QUANT_WEIGHT,
                self)
            result = self.module(*inputs)
        else:
            result = self.module(*inputs)

        if 'output' in self.config['quant_types']:
            result = self.quantizer.quant_grad.apply(
                result,
                QuantType.QUANT_OUTPUT,
                self)
        return result


class Quantizer(Compressor):
    """
    Base class for PyTorch quantizers
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        self.quant_grad = QuantGrad
        if self.optimizer is not None:
            self.patch_optimizer(self.step_with_optimizer)
            for wrapper in self.get_modules_wrapper():
                if 'weight' in wrapper.config['quant_types']:
                    # old_weight is registered to keep track of the weight before
                    # quantization and it is trainable, therefore it should be
                    # added to the optimizer.
                    self.optimizer.add_param_group({"params": wrapper.module.old_weight})

    def quantize_weight(self, wrapper, **kwargs):
        """
        Quantizers should overload this method to quantize weight.
        This method is effectively hooked to :meth:`forward` of the model.
        The weight to quantize is available as `wrapper.module.old_weight`.

        Parameters
        ----------
        wrapper : QuantizerModuleWrapper
            the wrapper for origin module
        """
        raise NotImplementedError('Quantizer must overload quantize_weight()')

    def quantize_output(self, output, wrapper, **kwargs):
        """
        Quantizers should overload this method to quantize output.
        This method is effectively hooked to :meth:`forward` of the model.

        Parameters
        ----------
        output : Tensor
            output that needs to be quantized
        wrapper : QuantizerModuleWrapper
            the wrapper for origin module
        """
        raise NotImplementedError('Quantizer must overload quantize_output()')

    def quantize_input(self, *inputs, wrapper, **kwargs):
        """
        Quantizers should overload this method to quantize input.
        This method is effectively hooked to :meth:`forward` of the model.

        Parameters
        ----------
        inputs : Tensor
            inputs that need to be quantized
        wrapper : QuantizerModuleWrapper
            the wrapper for origin module
        """
        raise NotImplementedError('Quantizer must overload quantize_input()')

    def _wrap_modules(self, layer, config):
        """
        Create a wrapper forward function to replace the original one.

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the quantization operation
        config : dict
            the configuration for quantization
        """
        assert 'quant_types' in config, 'must provide quant_types in config'
        assert isinstance(config['quant_types'], list), 'quant_types must be list type'
        assert 'quant_bits' in config, 'must provide quant_bits in config'
        assert isinstance(config['quant_bits'], (int, dict)), 'quant_bits must be dict type or int type'

        if isinstance(config['quant_bits'], dict):
            for quant_type in config['quant_types']:
                assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)

    def step_with_optimizer(self):
        pass
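
# A minimal sketch of a concrete quantizer: naive symmetric 8-bit weight
# quantization. It only illustrates the override points of `Quantizer`; the
# fixed bit width and the max-abs scale are assumptions of this sketch, not an
# NNI algorithm:
#
#     class NaiveQuantizerSketch(Quantizer):
#         def quantize_weight(self, wrapper, **kwargs):
#             weight = wrapper.module.old_weight.data
#             scale = weight.abs().max() / 127 + 1e-12  # guard against all-zero weights
#             quantized = torch.clamp(torch.round(weight / scale), -128, 127) * scale
#             wrapper.module.weight = quantized  # fills the `weight` buffer used in forward
#             return quantized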


class QuantType:
    """
    Enum class for quantization type.
    """
    QUANT_INPUT = 0
    QUANT_WEIGHT = 1
    QUANT_OUTPUT = 2


class QuantGrad(torch.autograd.Function):
    """
    Base class for overriding backward function of quantization operation.
    """
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type):
        """
        This method should be overridden by subclasses to provide a customized
        backward function; the default implementation is the Straight-Through Estimator

        Parameters
        ----------
        tensor : Tensor
            input of quantization operation
        grad_output : Tensor
            gradient of the output of quantization operation
        quant_type : QuantType
            the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
            you can define different behavior for different types.

        Returns
        -------
        tensor
            gradient of the input of quantization operation
        """
        return grad_output

    @staticmethod
    def forward(ctx, tensor, quant_type, wrapper, **kwargs):
        ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
        if quant_type == QuantType.QUANT_INPUT:
            return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
        elif quant_type == QuantType.QUANT_WEIGHT:
            return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
        elif quant_type == QuantType.QUANT_OUTPUT:
            return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
        else:
            raise ValueError("unrecognized QuantType.")

    @classmethod
    def backward(cls, ctx, grad_output):
        tensor, quant_type = ctx.saved_variables
        output = cls.quant_backward(tensor, grad_output, quant_type)
        return output, None, None, None


def _check_weight(module):
    try:
        return isinstance(module.weight.data, torch.Tensor)
    except AttributeError:
        return False
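
# A minimal sketch of customizing the backward pass: clip weight gradients to
# the quantization range instead of passing them straight through. The [-1, 1]
# clipping bound is an arbitrary assumption for illustration:
#
#     class ClippedQuantGrad(QuantGrad):
#         @staticmethod
#         def quant_backward(tensor, grad_output, quant_type):
#             if quant_type == QuantType.QUANT_WEIGHT:
#                 # zero gradients where the input fell outside [-1, 1]
#                 return grad_output * (tensor.abs() <= 1.0).type_as(grad_output)
#             return grad_output
#
#     # a quantizer subclass would select it with: self.quant_grad = ClippedQuantGrad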