| import torch | |
| # 1. 加载完整的状态字典 | |
| state_dict_path = '/home/auralray/Downloads/ImageMol.pth.tar' | |
| pretrained_state_dict = torch.load(state_dict_path)['state_dict'] | |
| # print(pretrained_state_dict) | |
| # 2. 定义源模块的前缀 | |
| # 根据你的实际键名,可能是 'embedding_layer.' 或 'self.embedding_layer.' | |
| # 我们以 'embedding_layer.' 为例进行演示 | |
| old_prefix = "embedding_layer." | |
| # 3. 提取并清理键名 | |
| q_encoder_weights = {} | |
| for key, value in pretrained_state_dict.items(): | |
| if key.startswith(old_prefix): | |
| # 移除前缀,得到模块内部的相对键名(如 '0.weight', '1.bias') | |
| # 新模型 q_encoder 内部的层结构和参数命名与 embedding_layer 内部一致 | |
| new_key = key[len(old_prefix):] | |
| q_encoder_weights[new_key] = value | |
| print(f"成功提取并清理了 {len(q_encoder_weights)} 个键。") | |
| # print(q_encoder_weights) | |
| torch.save(q_encoder_weights, './ImageMolEncoder.pth') |