import torch # 1. 加载完整的状态字典 state_dict_path = '/home/auralray/Downloads/ImageMol.pth.tar' pretrained_state_dict = torch.load(state_dict_path)['state_dict'] # print(pretrained_state_dict) # 2. 定义源模块的前缀 # 根据你的实际键名,可能是 'embedding_layer.' 或 'self.embedding_layer.' # 我们以 'embedding_layer.' 为例进行演示 old_prefix = "embedding_layer." # 3. 提取并清理键名 q_encoder_weights = {} for key, value in pretrained_state_dict.items(): if key.startswith(old_prefix): # 移除前缀,得到模块内部的相对键名(如 '0.weight', '1.bias') # 新模型 q_encoder 内部的层结构和参数命名与 embedding_layer 内部一致 new_key = key[len(old_prefix):] q_encoder_weights[new_key] = value print(f"成功提取并清理了 {len(q_encoder_weights)} 个键。") # print(q_encoder_weights) torch.save(q_encoder_weights, './ImageMolEncoder.pth')