import math
import os
import sys
from datetime import timedelta
from typing import Iterable, Optional

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from einops import rearrange
from scipy.special import softmax
from torch import nn
from tqdm import tqdm

import utils
from aurora import Batch, Metadata
from aurora.normalisation import (
    normalise_atmos_var,
    normalise_surf_var,
    unnormalise_atmos_var,
    unnormalise_surf_var,
)
from dataset import score, utils_data


def calculate_the_importance_label_dynamic(cuda_device, grad_type, optimizer=None):
    """Accumulate a per-parameter importance score from the gradients
    currently stored on the optimizer's parameters. `grad_type` selects the
    aggregation: "absolute" uses |g|, "square" uses g^2 (an empirical
    Fisher-style score).
    """
    gradients_dict = {}
    for fish_params in optimizer.param_groups:
        for fish_param, fish_name in zip(fish_params['params'], fish_params['names']):
            gradients_dict[fish_name] = torch.zeros_like(fish_param).to(cuda_device)

    if grad_type == "absolute":
        grad_method = torch.abs
    elif grad_type == "square":
        grad_method = torch.square
    else:
        raise ValueError(f"Unknown grad_type: {grad_type!r}")

    for fish_params in optimizer.param_groups:
        for fish_param, fish_name in zip(fish_params['params'], fish_params['names']):
            # Parameters without a gradient contribute nothing to the score.
            if fish_param.grad is None:
                continue
            gradients_dict[fish_name] += grad_method(fish_param.grad).data

    return gradients_dict
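

# Note: both functions in this module assume each optimizer param group
# carries a 'names' list parallel to 'params'. A minimal sketch of such a
# setup (`model` is a placeholder):
#
#   named = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
#   optimizer = torch.optim.AdamW(
#       [{"params": [p for _, p in named], "names": [n for _, n in named]}],
#       lr=1e-4,
#   )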


def create_mask_gradient_dynamic(keep_ratio, grad_type='absolute', optimizer=None, weight=None):
    """Build a binary mask over all named parameters that keeps roughly
    `keep_ratio` of them: classifier parameters are always kept, and the
    remaining budget goes to the positions with the largest noise-perturbed
    gradient importance. `weight` scales the multiplicative noise and must
    be provided.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gradients = calculate_the_importance_label_dynamic(device, grad_type, optimizer=optimizer)

    sizes = {}
    tensors = []

    classifier_size = 0
    all_params_size = 0

    classifier_mask_dict = {}

    for k, v in gradients.items():
        # Classifier parameters are always kept: give them an all-ones mask.
        if "classifier" in k:
            classifier_size += v.numel()
            classifier_mask_dict[k] = torch.ones_like(v, device=device)
        else:
            sizes[k] = v.shape
            tensors.append(v.view(-1))

        all_params_size += v.numel()

    tensors = torch.cat(tensors, 0)

    # Remaining budget after reserving the classifier parameters.
    keep_num = int(all_params_size * keep_ratio) - classifier_size
    assert keep_num > 0

    # Normalise importances to [0, 1] and perturb them with multiplicative
    # uniform noise scaled by `weight`, making the selection stochastic.
    tensors = tensors / tensors.max()
    tensors_noise = tensors * weight * torch.rand_like(tensors)

    top_pos = torch.topk(tensors_noise, keep_num)[1]

    masks = torch.zeros_like(tensors_noise, device=device)
    masks[top_pos] = 1

    assert masks.long().sum() == len(top_pos)

    # Un-flatten the selection back into per-parameter mask tensors.
    mask_dict = {}
    now_idx = 0
    for k, v in sizes.items():
        end_idx = now_idx + int(torch.prod(torch.tensor(v)).item())
        mask_dict[k] = masks[now_idx:end_idx].reshape(v).to(device)
        now_idx = end_idx

    assert now_idx == len(masks)

    mask_dict.update(classifier_mask_dict)

    # Sanity pass: count kept positions in the classifier versus the
    # pretrained weights; these totals can be used to verify that the
    # realised keep ratio matches `keep_ratio`.
    classifier_size = 0
    all_params_size = 0
    pretrain_weight_size = 0
    for k, v in mask_dict.items():
        if "classifier" in k:
            classifier_size += (v == 1).sum().item()
        else:
            pretrain_weight_size += (v == 1).sum().item()

        all_params_size += v.numel()

    return mask_dict
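

# A minimal usage sketch, assuming gradients have already been accumulated
# with backward() and that the optimizer was built with 'names' as above
# (`model`, `batch`, `compute_loss`, and the hyperparameters are placeholders):
#
#   optimizer.zero_grad()
#   loss = compute_loss(model, batch)   # hypothetical loss helper
#   loss.backward()
#   mask_dict = create_mask_gradient_dynamic(
#       keep_ratio=0.005,    # keep ~0.5% of all parameters
#       grad_type="square",
#       optimizer=optimizer,
#       weight=1.0,          # noise scale for the stochastic selection
#   )
#   # Zero out masked-out gradients before each subsequent optimizer step:
#   for group in optimizer.param_groups:
#       for p, n in zip(group["params"], group["names"]):
#           if p.grad is not None:
#               p.grad.mul_(mask_dict[n].to(p.grad.device))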