File size: 962 Bytes
acbef3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch

# 1. 加载完整的状态字典
state_dict_path = '/home/auralray/Downloads/ImageMol.pth.tar' 
pretrained_state_dict = torch.load(state_dict_path)['state_dict']
# print(pretrained_state_dict)

# 2. 定义源模块的前缀
# 根据你的实际键名,可能是 'embedding_layer.' 或 'self.embedding_layer.'
# 我们以 'embedding_layer.' 为例进行演示
old_prefix = "embedding_layer." 

# 3. 提取并清理键名
q_encoder_weights = {}
for key, value in pretrained_state_dict.items():
    if key.startswith(old_prefix):
        # 移除前缀,得到模块内部的相对键名(如 '0.weight', '1.bias')
        # 新模型 q_encoder 内部的层结构和参数命名与 embedding_layer 内部一致
        new_key = key[len(old_prefix):]
        q_encoder_weights[new_key] = value

print(f"成功提取并清理了 {len(q_encoder_weights)} 个键。")
# print(q_encoder_weights)
torch.save(q_encoder_weights, './ImageMolEncoder.pth')