import torch
import os
import numpy as np
import math
import sys
from typing import Iterable, Optional
import utils
from scipy.special import softmax
from dataset import utils_data, score
from einops import rearrange
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from aurora import Batch, Metadata
from aurora.normalisation import normalise_surf_var, normalise_atmos_var, unnormalise_surf_var, unnormalise_atmos_var
from datetime import timedelta
import pandas as pd
from tqdm import tqdm
def calculate_the_importance_label_dynamic(cuda_device, grad_type, optimizer=None):
    """Accumulate per-parameter importance scores from the gradients currently
    stored on the optimizer's parameters. Gradients must already be populated
    (e.g. by a preceding backward pass)."""
    gradients_dict = {}
    for fish_params in optimizer.param_groups:
        for fish_param, fish_name in zip(fish_params['params'], fish_params['names']):
            # if 'backbone' in fish_name and 'bias' in fish_name:
            gradients_dict[fish_name] = torch.zeros_like(fish_param).to(cuda_device)

    if grad_type == "absolute":
        grad_method = torch.abs
    elif grad_type == "square":
        grad_method = torch.square
    else:
        raise ValueError(f"Unsupported grad_type: {grad_type}")

    for fish_params in optimizer.param_groups:
        for fish_param, fish_name in zip(fish_params['params'], fish_params['names']):
            # if 'backbone' in fish_name and 'bias' in fish_name:
            gradients_dict[fish_name] += grad_method(fish_param.grad).data

    return gradients_dict
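

# Illustrative sketch, not part of the original code: the function above assumes each
# optimizer param group carries a parallel 'names' list and that .grad is already
# populated by a backward pass. A hypothetical setup (model, compute_loss and batch
# are placeholder names) might look like:
#
#     names, params = zip(*model.named_parameters())
#     optimizer = torch.optim.AdamW([{"params": list(params), "names": list(names)}], lr=1e-4)
#     loss = compute_loss(model, batch)
#     loss.backward()
#     importance = calculate_the_importance_label_dynamic("cuda", "square", optimizer=optimizer)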
def create_mask_gradient_dynamic(keep_ratio, grad_type='absolute', optimizer=None, weight=None):
    """Build a binary mask over the optimizer's parameters, keeping roughly
    `keep_ratio` of them according to noised gradient importance. Classifier
    parameters are always kept fully trainable; `weight` scales the random noise."""
    original_device = "cuda" if torch.cuda.is_available() else "cpu"
    cuda_device = "cuda" if torch.cuda.is_available() else "cpu"

    importance_method = calculate_the_importance_label_dynamic
    gradients = importance_method(cuda_device, grad_type, optimizer=optimizer)

    # add sizes and aggregate tensors
    sizes = {}
    tensors = []
    classifier_size = 0
    all_params_size = 0
    classifier_mask_dict = {}
    for k, v in gradients.items():
        # don't count classifier layer, they should be all trainable
        if "classifier" in k:
            classifier_size += torch.prod(torch.tensor(v.shape)).item()
            classifier_mask_dict[k] = torch.ones_like(v).to(original_device)
        else:
            sizes[k] = v.shape
            tensors.append(v.view(-1))
        all_params_size += torch.prod(torch.tensor(v.shape)).item()

    tensors = torch.cat(tensors, 0)

    keep_num = int(all_params_size * keep_ratio) - classifier_size
    assert keep_num > 0

    # normalise the importances and perturb them with random noise before the top-k selection
    tensors = tensors / tensors.max()
    tensors_noise = tensors * weight * torch.rand_like(tensors)
    top_pos = torch.topk(tensors_noise, keep_num)[1]

    masks = torch.zeros_like(tensors_noise, device=cuda_device)
    masks[top_pos] = 1
    assert masks.long().sum() == len(top_pos)

    # split the flat mask back into per-parameter masks
    mask_dict = {}
    now_idx = 0
    for k, v in sizes.items():
        end_idx = now_idx + torch.prod(torch.tensor(v))
        mask_dict[k] = masks[now_idx: end_idx].reshape(v).to(original_device)
        now_idx = end_idx
    assert now_idx == len(masks)

    # Add the classifier's mask to mask_dict
    mask_dict.update(classifier_mask_dict)

    # Print the parameters for checking
    classifier_size = 0
    all_params_size = 0
    pretrain_weight_size = 0
    for k, v in mask_dict.items():
        if "classifier" in k:
            classifier_size += (v == 1).sum().item()
        else:
            pretrain_weight_size += (v == 1).sum().item()
        all_params_size += torch.prod(torch.tensor(v.shape)).item()
    # print(pretrain_weight_size, classifier_size, all_params_size)
    # print(f"trainable parameters: {(pretrain_weight_size + classifier_size) / all_params_size * 100} %")

    return mask_dict
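

# Illustrative usage sketch, not part of the original code: one common way to apply such
# a mask is to multiply each parameter's gradient by its mask entry before the optimizer
# step, so only the selected entries are updated. The optimizer is assumed to carry a
# parallel 'names' list per param group, as in the sketch above:
#
#     mask_dict = create_mask_gradient_dynamic(keep_ratio=0.005, grad_type="square",
#                                              optimizer=optimizer, weight=1.0)
#     for group in optimizer.param_groups:
#         for p, name in zip(group["params"], group["names"]):
#             if p.grad is not None and name in mask_dict:
#                 p.grad.mul_(mask_dict[name].to(p.grad.device))
#     optimizer.step()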