"""Inspect attention-mask weights inside a safetensors checkpoint.

Loads every tensor of a checkpoint into CPU memory and prints the name and
shape of the ``attn.mask`` weight(s) belonging to one chosen layer.

NOTE(review): an earlier commented-out workflow in this file injected
Xavier-initialized ``model.layers.<i>.attn.mask.weight`` tensors of shape
[128, 2048] (bfloat16) into a base checkpoint and re-saved it with
``save_file``; that dead code was removed, but ``save_file`` stays imported
for when that workflow is revived.
"""
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Mask head layout (head count, low-rank dim) used when reshaping the mask
# weight for similarity inspection — TODO confirm against the model config.
HEAD = 8
RANK = 4

# Cluster-local checkpoint path of the trained mask-GDN model.
CHECKPOINT_PATH = "/mnt/jfzn/msj/train_exp/mask_gdn_hrr4/model.safetensors"


def load_checkpoint(path):
    """Read all tensors from the safetensors file at *path* into a dict.

    Tensors are materialized on CPU so the script runs without a GPU.
    """
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors


def print_mask_shapes(tensors, layer=19):
    """Print the name and shape of every attn-mask weight of *layer*.

    Matches the exact ``.{layer}.`` path segment instead of a bare substring
    (the original ``'19' in key`` test would also fire on any key that merely
    contains the digits 19 anywhere), so layer selection is unambiguous.
    """
    layer_marker = f".{layer}."
    for key, tensor in tensors.items():
        if "mask" in key and layer_marker in key:
            print(key)
            print(tensor.shape)


if __name__ == "__main__":
    # Original behavior: load the checkpoint, report layer 19's mask weight.
    print_mask_shapes(load_checkpoint(CHECKPOINT_PATH), layer=19)