Spaces:
Running
Running
| from abc import abstractmethod | |
| from copy import deepcopy | |
| import enum | |
| import torch | |
| from torch import nn | |
| import os | |
| from .model_fbs import DomainDynamicConv2d | |
| #from methods.utils.data import get_source_dataloader, get_source_normal_aug_dataloader, get_target_dataloaders | |
| #from models.resnet_cifar.model_manager import ResNetCIFARManager | |
| from utils.common.others import get_cur_time_str | |
| from utils.dl.common.env import set_random_seed | |
| from utils.dl.common.model import get_model_latency, get_model_size, get_module, set_module | |
| from utils.common.log import logger | |
| from utils.third_party.nni_new.compression.pytorch.speedup import ModelSpeedup | |
| from utils.third_party.nni_new.compression.pytorch.utils.mask_conflict import GroupMaskConflict, ChannelMaskConflict, CatMaskPadding | |
def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None, fix_group=False, fix_channel=True, fix_padding=False):
    """Resolve conflicts between per-layer pruning masks before speedup.

    `masks` may be a dict of NNI-style masks or a path to a serialized mask
    file. The model is jit-traced at most once here, and the resulting graph
    is shared by all requested conflict fixers.

    Returns the (possibly rewritten) mask dict.
    """
    if isinstance(masks, str):
        # caller handed us a path to a saved mask file
        assert os.path.exists(masks)
        masks = torch.load(masks)
    assert len(masks) > 0, 'Mask tensor cannot be empty'
    if traced is None:
        # Trace once so GroupMaskConflict / ChannelMaskConflict can reuse the
        # same traced graph instead of re-tracing per fixer.
        assert model is not None and dummy_input is not None
        was_training = model.training
        model.eval()  # tracing must happen in eval mode
        traced = torch.jit.trace(model, dummy_input)
        model.train(was_training)
    if fix_group:
        masks = GroupMaskConflict(masks, model, dummy_input, traced).fix_mask()
    if fix_channel:
        masks = ChannelMaskConflict(masks, model, dummy_input, traced).fix_mask()
    if fix_padding:
        masks = CatMaskPadding(masks, model, dummy_input, traced).fix_mask()
    return masks
class FeatureBoosting(nn.Module):
    """Fixed channel-wise rescaling of an NCHW feature map.

    Holds a frozen per-channel weight vector, reshaped to (1, C, 1, 1) so it
    broadcasts over the batch and spatial dimensions in `forward`.
    """

    def __init__(self, w: torch.Tensor):
        super().__init__()
        assert w.dim() == 1
        # frozen parameter: it is part of the state_dict but never trained
        self.w = nn.Parameter(w.reshape(1, -1, 1, 1), requires_grad=False)

    def forward(self, x):
        return x * self.w
class FBSSubModelExtractor:
    """Extracts a statically-pruned submodel from an FBS-style dynamic model.

    The input model contains `DomainDynamicConv2d` layers whose channel gates
    (`cached_w`, populated by a forward pass) are computed per input. This
    class freezes those gates into hard pruning masks, physically shrinks the
    network with NNI's `ModelSpeedup`, and re-applies the surviving gate
    values as fixed `FeatureBoosting` scaling modules.

    Subclasses define the gate-selection policy via `generate_pruning_strategy`
    and `get_final_w`.
    """

    def extract_submodel_via_a_sample(self, fbs_model: nn.Module, sample: torch.Tensor):
        """Extract a pruned submodel specialized for a single input.

        Args:
            fbs_model: model containing `DomainDynamicConv2d` layers.
            sample: one input of shape (1, C, H, W).

        Returns:
            The pruned (speedup'ed) model; logs its size and the squared
            output difference vs. the original gated model.
        """
        assert sample.dim() == 4 and sample.size(0) == 1
        fbs_model.eval()
        # one forward pass populates each layer's `cached_w` and gives a reference output
        o1 = fbs_model(sample)
        pruning_info = {}   # layer name -> gate vector w (reused as boosting weights)
        pruning_masks = {}  # NNI-style masks: 1 = keep filter, 0 = prune
        for layer_name, layer in fbs_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue
            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)
            w = get_module(fbs_model, layer_name).cached_w.squeeze()
            # filters with a non-zero gate survive
            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w
            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
            # '.0' because each DomainDynamicConv2d is replaced below with
            # nn.Sequential(conv, bn, identity), so the conv becomes '<name>.0'
            pruning_masks[layer_name + '.0'] = cur_pruning_mask
        # replace each dynamic layer with a plain conv/bn sequential (gate removed);
        # the trailing Identity is a placeholder later swapped for FeatureBoosting
        no_gate_model = deepcopy(fbs_model)
        for name, layer in no_gate_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue
            # layer.bn.weight.data.mul_(pruning_info[name])
            set_module(no_gate_model, name, nn.Sequential(layer.raw_conv2d, layer.bn, nn.Identity()))
        # fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
        # ModelSpeedup takes the masks from a file; use a unique temp path so
        # concurrent processes do not clobber each other
        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)
        pruned_model = no_gate_model
        pruned_model.eval()
        model_speedup = ModelSpeedup(pruned_model, sample, tmp_mask_path, sample.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)
        # add feature boosting module: re-apply the surviving gate values as a
        # fixed per-channel scale at slot '.2' (the Identity placeholder)
        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(pruned_model, layer_name + '.2', FeatureBoosting(feature_boosting_w))
        pruned_model_size = get_model_size(pruned_model, True)
        pruned_model.eval()
        # sanity check: the pruned model should reproduce the gated model's output
        o2 = pruned_model(sample)
        diff = ((o1 - o2) ** 2).sum()
        logger.info(f'pruned model size: {pruned_model_size:.3f}MB, diff: {diff}')
        return pruned_model

    def get_final_w(self, fbs_model: nn.Module, samples: torch.Tensor, layer_name: str, w: torch.Tensor):
        """Hook: collapse a batch of per-sample gate vectors `w` into one vector.

        Overridden by subclasses; the base class does nothing.
        """
        pass

    def generate_pruning_strategy(self, fbs_model: nn.Module, samples: torch.Tensor):
        """Hook: run whatever forward pass is needed to populate `cached_w`.

        Overridden by subclasses; the base class does nothing.
        """
        pass

    def extract_submodel_via_samples(self, fbs_model: nn.Module, samples: torch.Tensor):
        """Extract a pruned submodel specialized for a batch of samples.

        Same pipeline as `extract_submodel_via_a_sample`, but the gate vector
        per layer is chosen by `get_final_w` from the batch's per-sample gates.

        Returns:
            (pruned_model, pruning_info) where pruning_info maps layer name to
            the chosen gate vector.
        """
        assert samples.dim() == 4
        # work on a copy so the caller's model keeps its gates intact
        fbs_model = deepcopy(fbs_model)
        # fbs_model.eval()
        # fbs_model(samples)
        self.generate_pruning_strategy(fbs_model, samples)
        pruning_info = {}
        pruning_masks = {}
        for layer_name, layer in fbs_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue
            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)
            w = get_module(fbs_model, layer_name).cached_w.squeeze()  # 2-dim: (num samples, num filters)
            # collapse the per-sample gate matrix into a single gate vector
            w = self.get_final_w(fbs_model, samples, layer_name, w)
            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w
            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
            pruning_masks[layer_name + '.0'] = cur_pruning_mask
        no_gate_model = deepcopy(fbs_model)
        for name, layer in no_gate_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue
            # layer.bn.weight.data.mul_(pruning_info[name])
            set_module(no_gate_model, name, nn.Sequential(layer.raw_conv2d, layer.bn, nn.Identity()))
        # fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)
        pruned_model = no_gate_model
        pruned_model.eval()
        # speedup only needs one sample to trace shapes
        model_speedup = ModelSpeedup(pruned_model, samples[0:1], tmp_mask_path, samples.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)
        # add feature boosting module (see extract_submodel_via_a_sample)
        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(pruned_model, layer_name + '.2', FeatureBoosting(feature_boosting_w))
        return pruned_model, pruning_info

    def extract_submodel_via_samples_and_last_submodel(self, fbs_model: nn.Module, samples: torch.Tensor,
                                                       last_submodel: nn.Module, last_pruning_info: dict):
        """Like `extract_submodel_via_samples`, but warm-starts shared filters.

        For layers whose kept-filter set overlaps the previous submodel's
        (per `last_pruning_info`), the weights of filters kept in BOTH models
        are copied over from `last_submodel` into the new pruned model.

        Args:
            last_submodel / last_pruning_info: output of a previous call, or
                None on the first call (no warm start).

        Returns:
            (pruned_model, pruning_info)
        """
        assert samples.dim() == 4
        fbs_model = deepcopy(fbs_model)
        # fbs_model.eval()
        # fbs_model(samples)
        self.generate_pruning_strategy(fbs_model, samples)
        pruning_info = {}
        pruning_masks = {}
        # some tricks
        incrementally_updated_layers = []  # layers whose shared filters get warm-started
        for layer_name, layer in fbs_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue
            cur_pruning_mask = {'weight': torch.zeros_like(layer.raw_conv2d.weight.data)}
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'] = torch.zeros_like(layer.raw_conv2d.bias.data)
            w = get_module(fbs_model, layer_name).cached_w.squeeze()  # 2-dim
            w = self.get_final_w(fbs_model, samples, layer_name, w)
            unpruned_filters_index = w.nonzero(as_tuple=True)[0]
            pruning_info[layer_name] = w
            cur_pruning_mask['weight'][unpruned_filters_index, ...] = 1.
            if layer.raw_conv2d.bias is not None:
                cur_pruning_mask['bias'][unpruned_filters_index, ...] = 1.
            pruning_masks[layer_name + '.0'] = cur_pruning_mask
            # some tricks
            if last_pruning_info is not None:
                last_w = last_pruning_info[layer_name]
                # fraction of previously-kept filters that are kept again
                intersection_ratio = ((w > 0) * (last_w > 0)).sum() / (last_w > 0).sum()
                # NOTE(review): `> 0.` admits any overlap at all, so in practice
                # nearly every layer qualifies -- looks like a placeholder
                # threshold; confirm intended value.
                if intersection_ratio > 0.:
                    incrementally_updated_layers += [layer_name]  # that is, only similar layers are transferable
        no_gate_model = deepcopy(fbs_model)
        for name, layer in no_gate_model.named_modules():
            if not isinstance(layer, DomainDynamicConv2d):
                continue
            # layer.bn.weight.data.mul_(pruning_info[name])
            set_module(no_gate_model, name, nn.Sequential(layer.raw_conv2d, layer.bn, nn.Identity()))
        # fixed_pruning_masks = fix_mask_conflict(pruning_masks, fbs_model, sample.size(), None, True, True, True)
        tmp_mask_path = f'tmp_mask_{get_cur_time_str()}_{os.getpid()}.pth'
        torch.save(pruning_masks, tmp_mask_path)
        pruned_model = no_gate_model
        pruned_model.eval()
        model_speedup = ModelSpeedup(pruned_model, samples[0:1], tmp_mask_path, samples.device)
        model_speedup.speedup_model()
        os.remove(tmp_mask_path)
        # add feature boosting module
        for layer_name, feature_boosting_w in pruning_info.items():
            feature_boosting_w = feature_boosting_w[feature_boosting_w.nonzero(as_tuple=True)[0]]
            set_module(pruned_model, layer_name + '.2', FeatureBoosting(feature_boosting_w))
        # some tricks
        # incrementally updating (borrow some weights from last_pruned_model):
        # walk both gate vectors in lockstep; cur_filter_i / last_filter_i are the
        # filter's positions in the (differently sized) pruned conv layers
        for layer_name in incrementally_updated_layers:
            cur_filter_i, last_filter_i = 0, 0
            for i, (w_factor, last_w_factor) in enumerate(zip(pruning_info[layer_name], last_pruning_info[layer_name])):
                if w_factor > 0 and last_w_factor > 0:  # the filter is shared
                    # copy conv filter, BN affine + running stats, and boosting weight
                    cur_conv2d, last_conv2d = get_module(pruned_model, layer_name + '.0'), get_module(last_submodel, layer_name + '.0')
                    cur_conv2d.weight.data[cur_filter_i] = last_conv2d.weight.data[last_filter_i]
                    cur_bn, last_bn = get_module(pruned_model, layer_name + '.1'), get_module(last_submodel, layer_name + '.1')
                    cur_bn.weight.data[cur_filter_i] = last_bn.weight.data[last_filter_i]
                    cur_bn.bias.data[cur_filter_i] = last_bn.bias.data[last_filter_i]
                    cur_bn.running_mean.data[cur_filter_i] = last_bn.running_mean.data[last_filter_i]
                    cur_bn.running_var.data[cur_filter_i] = last_bn.running_var.data[last_filter_i]
                    cur_fw, last_fw = get_module(pruned_model, layer_name + '.2'), get_module(last_submodel, layer_name + '.2')
                    cur_fw.w.data[0, cur_filter_i] = last_fw.w.data[0, last_filter_i]
                if w_factor > 0:
                    cur_filter_i += 1
                if last_w_factor > 0:
                    last_filter_i += 1
        return pruned_model, pruning_info

    def absorb_sub_model(self, fbs_model: nn.Module, sub_model: nn.Module, pruning_info: dict, alpha=1.):
        """Merge an (adapted) submodel's weights back into the full FBS model.

        For each kept filter, the FBS model's conv/BN parameters are updated to
        `(1 - alpha) * fbs + alpha * sub`; alpha=0 is a no-op, alpha=1 copies.

        Args:
            pruning_info: layer name -> gate vector, as returned by the
                extract_* methods (non-zero entries mark kept filters).
        """
        if alpha == 0.:
            return
        for layer_name, feature_boosting_w in pruning_info.items():
            unpruned_filters_index = feature_boosting_w.nonzero(as_tuple=True)[0]
            fbs_layer = get_module(fbs_model, layer_name)
            sub_model_layer = get_module(sub_model, layer_name)
            # fi_in_sub_layer indexes the compact pruned conv; fi_in_fbs_layer
            # is the same filter's index in the full-width FBS conv
            for fi_in_sub_layer, fi_in_fbs_layer in enumerate(unpruned_filters_index):
                fbs_layer.raw_conv2d.weight.data[fi_in_fbs_layer] = (1. - alpha) * fbs_layer.raw_conv2d.weight.data[fi_in_fbs_layer] + \
                    alpha * sub_model_layer[0].weight.data[fi_in_sub_layer]
                # blend BN affine params and running statistics alike
                for k in ['weight', 'bias', 'running_mean', 'running_var']:
                    getattr(fbs_layer.bn, k).data[fi_in_fbs_layer] = (1. - alpha) * getattr(fbs_layer.bn, k).data[fi_in_fbs_layer] + \
                        alpha * getattr(sub_model_layer[1], k).data[fi_in_sub_layer]
class DAFBSSubModelExtractor(FBSSubModelExtractor):
    """Domain-adaptive submodel extractor using a hardest-sample heuristic.

    Runs the gated model once over the whole batch, then for every dynamic
    conv layer adopts the gate vector of the batch's most uncertain sample
    (the one with the highest prediction entropy).
    """

    def __init__(self) -> None:
        super().__init__()

    def generate_pruning_strategy(self, fbs_model: nn.Module, samples: torch.Tensor):
        # A single inference pass populates each layer's `cached_w` inside the
        # model and caches the batch output here for get_final_w() to consume.
        fbs_model.eval()
        with torch.no_grad():
            self.cur_output = fbs_model(samples)

    def get_final_w(self, fbs_model: nn.Module, samples: torch.Tensor, layer_name: str, w: torch.Tensor):
        logits = self.cur_output.logits
        # per-sample prediction entropy: -(sum over classes) p * log p
        per_sample_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
        # keep the gate row of the hardest (most uncertain) sample
        hardest = per_sample_entropy.argmax()
        return w[hardest]
def tent_as_detector(model, x, num_iters=1, lr=1e-4, l1_wd=0., strategy='ours', untrained_ratio=0.8):
    """Use a few TENT adaptation steps to find adaptation-insensitive filters.

    A copy of `model` is adapted with TENT on batch `x`; for every BatchNorm
    layer, the filters whose affine params moved the LEAST are marked as
    'untrained' (to be frozen by SGDF.step later).

    Args:
        model: model to probe (a deepcopy is adapted; the original is untouched).
        x: input batch driving the TENT steps.
        num_iters, lr, l1_wd: TENT optimizer settings.
        strategy: 'ours' (least-moved filters), 'random', 'inversed_ours'
            (most-moved), or 'none'.
        untrained_ratio: fraction of filters per layer to mark (was a
            hard-coded 0.8).

    Returns:
        {bn_layer_name: {'untrained_filters_index': index tensor or None,
                         'conv_name': name of the preceding conv layer}}

    Raises:
        ValueError: for an unrecognized `strategy` (previously this fell
            through and crashed with a NameError).
    """
    model = deepcopy(model)
    before_model = deepcopy(model)  # pristine copy to diff against
    from methods.tent import tent
    optimizer = torch.optim.SGD(
        model.parameters(), lr=lr, weight_decay=l1_wd)
    from models.resnet_cifar.model_manager import ResNetCIFARManager
    tented_model = tent.Tent(model, optimizer, ResNetCIFARManager, steps=num_iters)
    tent.configure_model(model)
    tented_model(x)  # adaptation happens inside the Tent forward
    filters_sen_info = {}
    last_conv_name = None
    # named_modules()/modules() iterate in the same order, so m1 and m2 are
    # the same module in the adapted vs. pristine copy
    for (name, m1), m2 in zip(model.named_modules(), before_model.modules()):
        if isinstance(m1, nn.Conv2d):
            last_conv_name = name  # conv immediately preceding the next BN
        if not isinstance(m1, nn.BatchNorm2d):
            continue
        with torch.no_grad():
            # per-filter movement of the BN affine parameters under TENT
            features_weight_diff = (m1.weight.data - m2.weight.data).abs()
            features_bias_diff = (m1.bias.data - m2.bias.data).abs()
            features_diff = features_weight_diff + features_bias_diff
        features_diff_order = features_diff.argsort(descending=False)  # least-moved first
        num_untrained = int(len(features_diff) * untrained_ratio)
        if strategy == 'ours':
            untrained_filters_index = features_diff_order[:num_untrained]
        elif strategy == 'random':
            untrained_filters_index = torch.randperm(len(features_diff))[:num_untrained]
        elif strategy == 'inversed_ours':
            untrained_filters_index = features_diff_order.flip(0)[:num_untrained]
        elif strategy == 'none':
            # NOTE(review): None later reaches SGDF.step's `d_p[None] = 0.`,
            # which zeroes the WHOLE update via an unsqueezed view -- if the
            # intent is "freeze nothing", an empty index list is needed; confirm.
            untrained_filters_index = None
        else:
            raise ValueError(f'unknown strategy: {strategy}')
        filters_sen_info[name] = dict(untrained_filters_index=untrained_filters_index, conv_name=last_conv_name)
    return filters_sen_info
class SGDF(torch.optim.SGD):
    """SGD whose step can freeze ('leave untrained') a subset of filters.

    Identical to the classic torch.optim.SGD update (weight decay, momentum,
    nesterov) except that, per parameter, the update rows listed in
    `filters_sen_info` / `conv_filters_sen_info` are zeroed before applying.
    """

    def step(self, model, conv_filters_sen_info, filters_sen_info, closure=None):
        """Performs a single optimization step.

        Arguments:
            model: the module owning the parameters; `model.named_parameters()`
                must align 1:1 (same order) with this optimizer's param group,
                since zip() below is used to recover each parameter's name.
            conv_filters_sen_info: {conv layer name: {'untrained_filters_index': ...}}.
            filters_sen_info: {bn layer name: {'untrained_filters_index': ...}}.
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            # assert len([i for i in model.named_parameters()]) == len([j for j in group['params']])
            for (name, _), p in zip(model.named_parameters(), group['params']):
                if p.grad is None:
                    continue
                # drop the trailing '.weight'/'.bias' to get the layer name
                layer_name = '.'.join(name.split('.')[0:-1])
                if layer_name in filters_sen_info.keys():
                    untrained_filters_index = filters_sen_info[layer_name]['untrained_filters_index']
                elif layer_name in conv_filters_sen_info.keys():
                    untrained_filters_index = conv_filters_sen_info[layer_name]['untrained_filters_index']
                else:
                    untrained_filters_index = []  # no freezing info: update everything
                # --- standard torch.optim.SGD update below ---
                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                # zero the update of frozen filters (dim 0 = output filters).
                # NOTE(review): with momentum and no nesterov, d_p IS the
                # momentum buffer, so this also zeroes the buffer in place --
                # confirm that persistent zeroing is intended.
                d_p[untrained_filters_index] = 0.
                p.add_(d_p, alpha=-group['lr'])
        return loss
if __name__ == '__main__':
    set_random_seed(0)
    import sys
    # Used below but commented out in the top-of-file imports -- import them
    # here so the script actually runs (fix: previously a NameError).
    from methods.utils.data import get_source_dataloader
    from models.resnet_cifar.model_manager import ResNetCIFARManager

    tag = sys.argv[1]
    # how much of the adapted submodel is merged back into the FBS model
    alpha = 0.2
    # NOTE(review): `tag` and `fbs_model_path` read the same argv slot -- confirm intended.
    fbs_model_path = sys.argv[1]

    cur_time_str = get_cur_time_str()
    res_save_dir = f'logs/experiments_trial/CIFAR100C/ours_fbs_more_challenging/{cur_time_str[0:8]}/{cur_time_str[8:]}-{tag}'
    os.makedirs(res_save_dir)
    import shutil
    # snapshot this method's source next to the results for reproducibility
    shutil.copytree(os.path.dirname(__file__),
                    os.path.join(res_save_dir, 'method'), ignore=shutil.ignore_patterns('*.pt', '*.pth', 'log', '__pycache__'))
    logger.info(f'res save dir: {res_save_dir}')

    model = torch.load(fbs_model_path)
    # baseline model (only used by the commented-out baseline evaluation below)
    xgf_model = torch.load('logs/experiments_trial/CIFAR100C/ours_dynamic_filters/20220731/224212-cifar10_svhn_raw/last_model.pt')

    # repeating schedule: 3 domains, reversed, cycled 10 times
    train_dataloaders = [
        get_source_dataloader(dataset_name, 128, 4, 'train', True, None, True) for dataset_name in ['SVHN', 'CIFAR10', 'SVHN']
    ][::-1] * 10
    test_dataloaders = [
        get_source_dataloader('USPS', 128, 4, 'test', False, False, False),
        get_source_dataloader('STL10-wo-monkey', 128, 4, 'test', False, False, False),
        get_source_dataloader('MNIST', 128, 4, 'test', False, False, False),
    ][::-1] * 10
    # label offsets aligning each dataset's classes into the shared label space
    y_offsets = [10, 0, 10][::-1] * 10
    domain_names = ['USPS', 'STL10', 'MNIST'][::-1] * 10

    def _im_loss(output):
        # SHOT's information-maximization loss: mean per-sample entropy minus
        # the entropy of the mean prediction (confident AND diverse outputs)
        softmax_out = nn.Softmax(dim=1)(output)
        entropy = torch.sum(-softmax_out * torch.log(softmax_out + 1e-5), dim=1)
        entropy_loss = torch.mean(entropy)
        msoftmax = softmax_out.mean(dim=0)
        entropy_loss -= torch.sum(-msoftmax * torch.log(msoftmax + 1e-5))
        return entropy_loss

    baseline_before, baseline_after, ours_before, ours_after = [], [], [], []
    last_pruned_model, last_pruning_info = None, None
    for ti, (test_dataloader, y_offset) in enumerate(zip(test_dataloaders, y_offsets)):
        samples, labels = next(iter(test_dataloader))
        samples, labels = samples.cuda(), labels.cuda()
        labels += y_offset

        def bn_cal(_model: nn.Module):
            # re-estimate BN running statistics over ~one epoch of training batches
            for n, m in _model.named_modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.reset_running_stats()
                    m.train()
            for _ in range(100):  # ~one epoch
                x, y = next(train_dataloaders[ti])
                x = x.cuda()
                _model(x)  # fix: previously forwarded `samples`, leaving the loaded batch unused

        def shot(_model: nn.Module, lr=6e-4, num_iters_scale=1, wd=0.):
            # SHOT-style unsupervised adaptation: train all layers except the classifier head
            _model.requires_grad_(True)
            _model.linear.requires_grad_(False)
            optimizer = torch.optim.SGD([p for p in _model.parameters() if p.requires_grad], lr=lr, momentum=0.9, weight_decay=wd)
            for _ in range(100 * num_iters_scale):
                _model.train()
                output = ResNetCIFARManager.forward(_model, samples)
                loss = _im_loss(output)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        def shot_w_part_filter(_model: nn.Module, lr=6e-4, num_iters_scale=1, wd=0.):
            # SHOT, but adaptation-insensitive filters (found via TENT probing)
            # are kept frozen by the SGDF optimizer
            _model.requires_grad_(True)
            _model.linear.requires_grad_(False)
            optimizer = SGDF([p for p in _model.parameters() if p.requires_grad], lr=lr, momentum=0.9, weight_decay=wd)
            filters_sen_info = tent_as_detector(_model, samples, strategy='ours')
            conv_filters_sen_info = {v['conv_name']: v for _, v in filters_sen_info.items()}
            for _ in range(100 * num_iters_scale):
                _model.train()
                output = ResNetCIFARManager.forward(_model, samples)
                loss = _im_loss(output)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step(_model, conv_filters_sen_info, filters_sen_info)

        def tent(_model: nn.Module):
            # TENT: entropy minimization updating only BN affine parameters
            from methods.tent import tent
            _model = tent.configure_model(_model)
            params, param_names = tent.collect_params(_model)
            optimizer = torch.optim.Adam(params, lr=1e-4)
            tent_model = tent.Tent(_model, optimizer, ResNetCIFARManager, steps=1)
            tent.configure_model(_model)
            tent_model(samples)

        def tent_configure_bn(_model):
            """Configure BN layers for TENT: enable grad + force batch statistics."""
            for m in _model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.requires_grad_(True)
                    # force use of batch stats in train and eval modes
                    m.track_running_stats = False
                    m.running_mean = None
                    m.running_var = None
            return _model  # fix: previously returned the global `model` by mistake

        def sl(_model: nn.Module, lr=6e-4, num_iters_scale=1, wd=0.):
            # supervised fine-tuning on the labeled test batch (upper-bound reference)
            _model.requires_grad_(True)
            _model.linear.requires_grad_(False)
            optimizer = torch.optim.SGD([p for p in _model.parameters() if p.requires_grad], lr=lr, momentum=0.9, weight_decay=wd)
            for _ in range(100 * num_iters_scale):
                _model.train()
                loss = ResNetCIFARManager.forward_to_gen_loss(_model, samples, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        model_extractor = DAFBSSubModelExtractor()
        # fix: extract_submodel_via_a_sample asserts a 4-dim batch of size 1,
        # so pass samples[0:1] instead of the 3-dim samples[0]
        model1 = model_extractor.extract_submodel_via_a_sample(model, samples[0:1])
        pruned_model, pruning_info = model_extractor.extract_submodel_via_samples_and_last_submodel(model, samples, None, None)

        acc = ResNetCIFARManager.get_accuracy(pruned_model, test_dataloader, 'cuda', y_offset)
        print(acc)
        ours_before += [acc]
        # alternative DA methods for ablation: tent(pruned_model),
        # bn_cal(pruned_model), shot(pruned_model), sl(pruned_model)
        shot_w_part_filter(pruned_model, 6e-4, 1, 1e-3)
        acc = ResNetCIFARManager.get_accuracy(pruned_model, test_dataloader, 'cuda', y_offset)
        print(acc)
        ours_after += [acc]

        last_pruned_model, last_pruning_info = deepcopy(pruned_model), deepcopy(pruning_info)
        # merge the adapted submodel's weights back into the full FBS model
        model_extractor.absorb_sub_model(model, pruned_model, pruning_info, alpha)

        # baseline (L1-pruned model + SHOT), kept for reference:
        # acc = ResNetCIFARManager.get_accuracy(xgf_model, test_dataloader, 'cuda', y_offset)
        # baseline_before += [acc]
        # shot(xgf_model)
        # acc = ResNetCIFARManager.get_accuracy(xgf_model, test_dataloader, 'cuda', y_offset)
        # baseline_after += [acc]

    import matplotlib.pyplot as plt
    from visualize.util import *
    set_figure_settings(3)

    def avg(arr):
        return sum(arr) / len(arr)

    # plt.plot(list(range(len(test_dataloaders))), baseline_before, lw=2, linestyle='--', color=BLUE, label=f'L1 before DA ({avg(baseline_before):.4f})')
    # plt.plot(list(range(len(test_dataloaders))), baseline_after, lw=2, linestyle='-', color=BLUE, label=f'L1 after DA ({avg(baseline_after):.4f})')
    plt.plot(list(range(len(test_dataloaders))), ours_before, lw=2, linestyle='--', color=RED, label=f'ours before DA ({avg(ours_before):.4f})')
    plt.plot(list(range(len(test_dataloaders))), ours_after, lw=2, linestyle='-', color=RED, label=f'ours after DA ({avg(ours_after):.4f})')
    plt.xlabel('domains')
    plt.ylabel('accuracy')
    plt.xticks(list(range(len(domain_names))), domain_names, rotation=90)
    plt.legend(loc=2, bbox_to_anchor=(1.05, 1.0), fontsize=16)
    plt.tight_layout()
    plt.savefig(os.path.join(res_save_dir, 'main.png'), dpi=300)
    plt.clf()
    torch.save((baseline_before, baseline_after, ours_before, ours_after), os.path.join(res_save_dir, 'main.png.data'))