# flexpert / Flexpert-Design / src/interface/pretrain_interface.py
# Uploaded by Honzus24 -- "initial commit" (7968cb0)
import torch
from omegaconf import OmegaConf
from transformers import AutoTokenizer, EsmForMaskedLM
import torch.nn.functional as F
class PretrainInterface(torch.nn.Module):
def __init__(self, name):
super().__init__()
self.name = name
if name == "ESM35M":
self.esm_dim = 480
self.tokenizer = AutoTokenizer.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t12_35M_UR50D")
self.pretrain_model = EsmForMaskedLM.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t12_35M_UR50D")
if name == "ESM650M":
self.esm_dim = 1280
self.tokenizer = AutoTokenizer.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t33_650M_UR50D/snapshots/08e4846e537177426273712802403f7ba8261b6c")
self.pretrain_model = EsmForMaskedLM.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t33_650M_UR50D/snapshots/08e4846e537177426273712802403f7ba8261b6c")
if name == "ESM3B":
self.esm_dim = 2560
self.tokenizer = AutoTokenizer.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t36_3B_UR50D/snapshots/476b639933c8baad5ad09a60ac1a87f987b656fc")
self.pretrain_model = EsmForMaskedLM.from_pretrained("/huyuqi/model_zoom/transformers/models--facebook--esm2_t36_3B_UR50D/snapshots/476b639933c8baad5ad09a60ac1a87f987b656fc")
if name == "vanilla":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/configs/10-18T01-15-36-project.yaml")
pretrain_args.diffusion = False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/checkpoints/best-epoch=14-val_loss=0.314.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict, strict=False)
# if name == "LFQ":
# from step1_VQ.model_interface import MInterface
# pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/LFQ_seg_linear/configs/10-17T15-46-37-project.yaml")
# pretrain_args.diffusion = False
# self.pretrain_model = MInterface(**pretrain_args)
# ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/LFQ_seg_linear/checkpoints/best-epoch=14-val_loss=0.161.pth', map_location=torch.device('cpu'))
# state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
# self.pretrain_model.load_state_dict(state_dict, strict=False)
if name == "softgroup-1":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftGroup/softgroup-1/configs/12-16T14-57-28-project.yaml")
pretrain_args.diffusion = False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftGroup/softgroup-1/checkpoints/best-epoch=13-val_loss=0.111.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == "softgroup-2":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-2/configs/10-24T12-51-57-project.yaml")
pretrain_args.diffusion = False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-2/checkpoints/best-epoch=14-val_loss=0.067.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == "softgroup-3":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-3/configs/10-25T00-04-15-project.yaml")
pretrain_args.diffusion = False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-3/checkpoints/best-epoch=14-val_loss=0.063.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == "softgroup-4":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_32_vectors/configs/10-19T01-03-55-project.yaml")
pretrain_args.diffusion = False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_32_vectors/checkpoints/best-epoch=14-val_loss=0.056.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict, strict=False)
if name == "softgroup-5":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-5-gzy/configs/10-27T17-15-56-project.yaml")
pretrain_args.diffusion = False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup-5-gzy/checkpoints/best-epoch=14-val_loss=0.039.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == "softgroup-6":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/configs/10-28T01-28-50-project.yaml")
pretrain_args.diffusion = False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/checkpoints/best-epoch=14-val_loss=0.011.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == "softgroup_128_group":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/configs/10-28T01-28-50-project.yaml")
pretrain_args.diffusion = False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoftGroup/softgroup_128_group/checkpoints/best-epoch=14-val_loss=0.011.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == "diff-softgroup-1":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-rm-dist/configs/12-17T14-19-21-project.yaml")
pretrain_args.diffusion = True
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-rm-dist/checkpoints/best-epoch=12-val_loss=0.496.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == "diff-softgroup-4":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq32/configs/12-19T01-54-15-project.yaml")
pretrain_args.diffusion = True
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq32/checkpoints/best-epoch=13-val_loss=0.184.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == "diff-softgroup-5":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq64/configs/12-19T01-57-07-project.yaml")
pretrain_args.diffusion = True
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq64/checkpoints/best-epoch=13-val_loss=0.100.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == "diff-softgroup-6":
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq128/configs/12-19T10-47-37-project.yaml")
pretrain_args.diffusion = True
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/DiffESMSoftGroup/diff-softgroup-vq128/checkpoints/best-epoch=13-val_loss=0.081.pth', map_location=torch.device('cpu'))
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'vanilla-1':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/configs/10-18T01-15-37-project.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMVQ/base/checkpoints/best-epoch=14-val_loss=0.314.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'soft-1':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoft/soft_rerun/configs/12-10T12-38-16-project.yaml")
pretrain_args.diffusion=False
pretrain_args.attn_type = 'raw'
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoft/soft_rerun/checkpoints/best-epoch=14-val_loss=0.018.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'soft_64_vecs':
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoft/soft_vq_num64/configs/10-19T11-11-58-project.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMSoft/soft_vq_num64/checkpoints/best-epoch=14-val_loss=8.768.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'LFQ':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/vanilla_L1loss/configs/10-24T01-36-37-project.yaml")
pretrain_args.diffusion = False
pretrain_args.attn_type = 'raw'
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/DiffSDS/Pretrain_lightning/results/ESMFVQ/vanilla_L1loss/checkpoints/best-epoch=14-val_loss=11.328.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict, strict=False)
if name == 'SCQ-mlp3-vqdim32':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp3-vqdim32/configs/12-22T07-52-47-project.yaml")
pretrain_args.diffusion = False
pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 32, 3, False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp3-vqdim32/checkpoints/best-epoch=14-val_loss=0.376.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'SCQ-mlp3-vqdim32-sphere':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp3-vqdim32-sphere/configs/12-22T10-44-46-project.yaml")
pretrain_args.diffusion = False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp3-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.454.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'SCQ-mlp6-vqdim32-sphere':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp6BN-vqdim32-sphere/configs/12-22T18-28-04-project.yaml")
pretrain_args.diffusion = False
pretrain_args.attn_type = 'raw'
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp6BN-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.148.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'SCQ-mlp2-vqdim32':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp2-vqdim32/configs/12-22T00-21-35-project.yaml")
pretrain_args.diffusion = False
pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 32, 2, False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-mlp2-vqdim32/checkpoints/best-epoch=14-val_loss=0.362.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'SCQ-mlp2-vqdim32-sphere':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere-vqdim32/configs/12-22T00-06-35-project.yaml")
pretrain_args.diffusion = False
pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 32, 2, True
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere-vqdim32/checkpoints/best-epoch=14-val_loss=0.338.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'SCQ-mlp2-vqdim16':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional/configs/12-21T13-13-11-project.yaml")
pretrain_args.diffusion = False
pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 16, 2, False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional/checkpoints/best-epoch=14-val_loss=0.094.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'SCQ-mlp2-vqdim16-sphere':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere/configs/12-21T16-38-57-project.yaml")
pretrain_args.diffusion = False
pretrain_args.vq_dim, pretrain_args.condition_layer, pretrain_args.sphere = 16, 2, True
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq16-conditional-sphere/checkpoints/best-epoch=14-val_loss=1.080.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'SCQ-vq8-mlp6-vqdim16-sphere':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq8-mlp6BN-vqdim32-sphere/configs/12-23T05-15-56-project.yaml")
pretrain_args.diffusion = False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-vq8-mlp6BN-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.892.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'SCQ-mlp9-vqdim32-sphere':
from step1_VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp9BN-vqdim32-sphere/configs/12-23T16-20-07-project.yaml")
pretrain_args.diffusion = False
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step1_VQ/results/ESMSoftBV/SoftBV-mlp9BN-vqdim32-sphere/checkpoints/best-epoch=14-val_loss=0.151.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'AF2VQ':
from step3_AF2VQ.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step3_AF2VQ/results/AF2VQ_softgroup16/configs/12-13T07-59-50-project.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load("/huyuqi/xmyu/VQProteinFormer/step3_AF2VQ/results/AF2VQ_softgroup16/checkpoints/best-epoch=11-val_loss=0.812.pth")
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == "ProGLM":
self.vq_dim=480
from step2_ProGLM.model.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_1127/version_4/hparams.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_1127/checkpoints/best-epoch=08-valid_acc=0.804.ckpt', map_location=torch.device('cpu'))['state_dict']
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'ProGLM_softgroup_af2db':
from step2_ProGLM.model.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_2/version_3/hparams.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/DiffSDS/Inpainting_representation/results/softgroup_bin_2/checkpoints/best-epoch=13-valid_acc=0.863.ckpt', map_location=torch.device('cpu'))['state_dict']
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'ProGLM_SoftVQ_cath':
from step2_ProGLM.model.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftVQ_epoch15_pad300/configs/12-25T01-20-35-project.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftVQ_epoch15_pad300/checkpoints/best-epoch=27-valid_acc=0.001.pth')
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'ProGLM_SoftCVQ_cath':
from step2_ProGLM.model.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE/configs/12-25T01-42-37-project.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE/checkpoints/best-epoch=14-valid_acc=0.614.pth')
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'ProGLM_SoftCVQ_cath_inpaint':
from step2_ProGLM.model.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE_inpaint/configs/12-25T07-47-52-project.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_pad300_BCE_inpaint/checkpoints/best-epoch=14-valid_acc=0.616.pth')
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'ProGLM_SoftCVQ_AF2DB':
from step2_ProGLM.model.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_AF2DB/configs/12-25T13-01-12-project.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_epoch15_AF2DB/checkpoints/best-epoch=14-valid_acc=0.631.pth')
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'ProGLM_SoftCVQ_ESM1B_CATH':
from step2_ProGLM.model.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_ESM1B_CATH_lr5e-5/configs/12-25T16-02-35-project.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGLM_SoftCVQ_ESM1B_CATH_lr5e-5/checkpoints/best-epoch=14-valid_acc=0.616.pth')
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'ProGLM_SoftCVQ_CATH':
from step2_ProGLM.model.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH/configs/12-26T08-13-41-project.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH/checkpoints/best-epoch=14-gpt_acc=0.758.pth')
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'ProGLM_SoftCVQ_CATH_epoch10k':
from step2_ProGLM.model.model_interface import MInterface
pretrain_args = OmegaConf.load("/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH_epoch1000/configs/12-27T02-36-49-project.yaml")
self.pretrain_model = MInterface(**pretrain_args)
ckpt = torch.load('/huyuqi/xmyu/VQProteinFormer/step2_ProGLM/results/ProGPT_SoftCVQ_CATH_epoch10000_resume/checkpoints/best-epoch=1887-gpt_loss=0.181.pth')
state_dict = {k.replace('_forward_module.', ''): v for k, v in ckpt.items()}
self.pretrain_model.load_state_dict(state_dict)
if name == 'GearNet':
from model.PretrainGearNet import PretrainGearNet_Model
self.pretrain_model = PretrainGearNet_Model()
self.pretrain_model.eval()
def get_vq_id(self, seqs, angles, attn_mask):
# if ('softgroup' in self.name) or ('LFQ' in self.name):
# h_input = self.pretrain_model.model.input(seqs.squeeze(-1), angles)
# h_enc = self.pretrain_model.model.ProteinEnc(h_input, attn_mask, None).last_hidden_state
# vq_id, e_enc = self.pretrain_model.model.VQLayer.get_vq(h_enc, attn_mask, temperature=1e-5)
# return F.pad(vq_id, [0,1,0,0])
h_input = self.pretrain_model.model.input(seqs.squeeze(-1), angles)
h_enc = self.pretrain_model.model.ProteinEnc(h_input, attn_mask, None).last_hidden_state
vq_id, e_enc = self.pretrain_model.model.VQLayer.get_vq(h_enc, attn_mask, temperature=1e-5)
return vq_id
def forward(self, batch):
if self.name in ["ESM35M", "ESM650M", "ESM3B"]:
seqs, attn_mask = batch['seqs'], batch['attn_mask']
outputs = self.pretrain_model.model(input_ids=seqs[:,:,0], attention_mask=attn_mask)
pretrain_embedding = outputs.hidden_states
pretrain_embedding = pretrain_embedding.reshape(-1,self.esm_dim)[attn_mask.view(-1)==1]
return pretrain_embedding
if self.name in ["softgroup_128_group"]:
seqs, angles, attn_mask = batch['seqs'], batch['angles'] , batch['attn_mask']
vq_id = self.pretrain_model.model.get_vqid(seqs[...,0], angles, attn_mask)
return vq_id
if self.name in ["ProGLM"]:
vq_id, attn_mask, seg, pos = batch['vq_id'], batch['attn_mask'], batch['seg'], batch['pos']
feat = self.pretrain_model.model.get_feat(vq_id, attn_mask, seg, pos)
feat = feat.reshape(-1,self.vq_dim)[attn_mask.view(-1)==1]
return feat
if self.name in ["GearNet"]:
seqs = batch['seqs']
batch = batch['batch']
attn_mask = batch['attn_mask']
for idx in range(seqs.shape[0]):
seq_str = self.pretrain_featurizer.ESM_tokenizer.decode(seqs[idx,attn_mask[idx,:].bool(),0])
seq_strs.append(seq_str.split(" "))
seq_strs = sum(seq_strs, [])
node_index = torch.arange(batch.batch.shape[0], device=batch.batch.device)
node2graph = batch.batch
chain_id = torch.ones_like(batch.batch)
pretrain_embedding = self.pretrain_gearnet_model(seq_strs, node_index, node2graph, chain_id, batch.pos)
return pretrain_embedding