Spaces:
Running
Running
| import copy | |
| import os | |
| import sys | |
| from tabnanny import verbose | |
| from typing import List, Optional, Tuple | |
| import torch | |
| from ...third_party.nni_new.algorithms.compression.pytorch.pruning import L1FilterPruner | |
| from ...third_party.nni_new.compression.pytorch.speedup import ModelSpeedup | |
| from ...common.others import get_cur_time_str | |
def _prune_module(model, pruner, model_input_size, device, verbose=False, need_return_mask=False):
    """Run a configured NNI pruner on ``model`` and physically shrink it.

    Args:
        model: The model to prune (modified in place by the speedup pass).
        pruner: An already-constructed NNI pruner wrapping ``model``'s twin copy.
        model_input_size: Input shape for the dummy tensor, e.g. ``(1, 3, 224, 224)``.
        device: 'cpu' or 'cuda'.
        verbose (bool, optional): Currently unused; kept for interface compatibility.
        need_return_mask (bool, optional): Also return the fixed mask produced by
            ``ModelSpeedup`` for debugging. Defaults to False.

    Returns:
        The pruned (speed-up) model, or ``(pruned_model, mask)`` if
        ``need_return_mask`` is True.
    """
    pruner.compress()

    # Export weights/masks to unique temp paths; pid + timestamp avoids
    # collisions when several processes prune concurrently.
    pid = os.getpid()
    timestamp = get_cur_time_str()
    tmp_model_path = './tmp_weight-{}-{}.pth'.format(pid, timestamp)
    tmp_mask_path = './tmp_mask-{}-{}.pth'.format(pid, timestamp)
    pruner.export_model(model_path=tmp_model_path, mask_path=tmp_mask_path)
    os.remove(tmp_model_path)  # only the mask file is needed for the speedup pass

    try:
        # Speed up: replace masked channels with genuinely smaller layers.
        dummy_input = torch.rand(model_input_size).to(device)
        pruned_model = model
        pruned_model.eval()
        model_speedup = ModelSpeedup(pruned_model, dummy_input, tmp_mask_path, device)
        fixed_mask = model_speedup.speedup_model()
    finally:
        # BUG FIX: previously the mask temp file leaked if speedup_model() raised.
        os.remove(tmp_mask_path)

    if need_return_mask:
        return pruned_model, fixed_mask
    return pruned_model
def l1_prune_model(model: torch.nn.Module, pruned_layers_name: Optional[List[str]], sparsity: float,
                   model_input_size: Tuple[int], device: str, verbose=False, need_return_mask=False, dep_aware=False):
    """Get the pruned model via L1 Filter Pruning.

    Reference:
        Li H, Kadav A, Durdanovic I, et al. Pruning filters for efficient
        convnets[J]. arXiv preprint arXiv:1608.08710, 2016.

    Args:
        model (torch.nn.Module): A PyTorch model.
        pruned_layers_name (Optional[List[str]]): Which layers will be pruned. If it's `None`, all layers will be pruned.
        sparsity (float): Target sparsity. The pruned model is smaller if sparsity is higher.
        model_input_size (Tuple[int]): Typically be `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
        device (str): Typically be 'cpu' or 'cuda'.
        verbose (bool, optional): Currently unused; kept for backward compatibility. Defaults to False.
        need_return_mask (bool, optional): Return the fine-grained mask generated by NNI framework for debug. Defaults to False.
        dep_aware (bool, optional): Refers to the argument `dependency_aware` in NNI framework. Defaults to False.

    Returns:
        torch.nn.Module: Pruned model. If `need_return_mask` is True, a tuple
        `(pruned_model, mask)` is returned instead (`mask` is None when
        `sparsity` is 0 and no pruning was performed).
    """
    model = copy.deepcopy(model).to(device)
    if sparsity == 0:
        # BUG FIX: keep the return arity consistent with need_return_mask so
        # callers that unpack `(model, mask)` don't crash on the no-op path.
        return (model, None) if need_return_mask else model

    # The pruner consumes `model` to generate the mask; the speedup pass then
    # runs against an untouched copy.
    pruned_model = copy.deepcopy(model).to(device)
    model.eval()

    config = {
        'op_types': ['Conv2d', 'ConvTranspose2d'],
        'sparsity': sparsity
    }
    if pruned_layers_name is not None:
        config['op_names'] = pruned_layers_name
    config_list = [config]

    pruner = L1FilterPruner(model, config_list, dependency_aware=dep_aware,
                            dummy_input=torch.rand(model_input_size).to(device) if dep_aware else None)
    return _prune_module(pruned_model, pruner, model_input_size, device, verbose, need_return_mask)