File size: 962 Bytes
acbef3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
# 1. 加载完整的状态字典
state_dict_path = '/home/auralray/Downloads/ImageMol.pth.tar'
pretrained_state_dict = torch.load(state_dict_path)['state_dict']
# print(pretrained_state_dict)
# 2. 定义源模块的前缀
# 根据你的实际键名,可能是 'embedding_layer.' 或 'self.embedding_layer.'
# 我们以 'embedding_layer.' 为例进行演示
old_prefix = "embedding_layer."
# 3. 提取并清理键名
q_encoder_weights = {}
for key, value in pretrained_state_dict.items():
if key.startswith(old_prefix):
# 移除前缀,得到模块内部的相对键名(如 '0.weight', '1.bias')
# 新模型 q_encoder 内部的层结构和参数命名与 embedding_layer 内部一致
new_key = key[len(old_prefix):]
q_encoder_weights[new_key] = value
print(f"成功提取并清理了 {len(q_encoder_weights)} 个键。")
# print(q_encoder_weights)
torch.save(q_encoder_weights, './ImageMolEncoder.pth') |