import sys
sys.path.append('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom')
import torch
import torch.nn as nn
import os
import numpy as np
from tqdm import tqdm
from src.model.pretrain_modules import (
    ESM2Model, SmilesModel, ESM3Model, ESMC600MModel, ProCyonModel,
    GearNetModel, ProLLAMAModel, ProSTModel, ProtGPT2Model, ProTrekModel,
    SaPortModel, VenusPLMModel, ProSST2048Model, ProGen2Model, ProstT5Model,
    ProtT5, DPLMModel, OntoProteinModel, ANKHBase, PGLMModel, BIOModel
)

MODEL_ZOOM_PATH = '/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom'
os.environ["TOKENIZERS_PARALLELISM"] = "true"

class PretrainModelInterface(nn.Module):
    """
    setup_model: register the pre-trained model.
    construct_batch: build the model's input batch with padding; if BOS/EOS tokens are added to the sequence, the structure must be padded to match here as well.
    forward: call the selected model to extract per-residue embeddings; post-processing strips the BOS/EOS embeddings so that only amino-acid embeddings are returned.
    """

    def __init__(self, pretrain_model_name, batch_size=64, max_length=1022, device='cuda', sequence_only=False, task_type=None):
        super().__init__()
        self.pretrain_model_name = pretrain_model_name
        self.sequence_only = sequence_only
        self.batch_size = batch_size
        self.task_type = task_type
        self.device = device
        # Slice bounds used to strip BOS/EOS embeddings; overridden per model in setup_model.
        self.start, self.end = 1, -1
        self.setup_model()

    def setup_model(self):
        """
        Set up the pre-trained model based on the specified name, e.g. one of:
        ['esm2_650m', 'esm3_1.4b', 'esmc_600m', 'procyon', 'prollama', 'progen2', 'prostt5', 'protgpt2', 'protrek', 'saport', 'gearnet', 'prost', 'prosst2048', 'venusplm']
        """
        device = self.device
        self.smiles_model = SmilesModel(device)
        if self.pretrain_model_name == 'esm2_650m':
            self.pretrain_model = ESM2Model(device)
        elif self.pretrain_model_name == 'esm2_35m':
            self.pretrain_model = ESM2Model(device, model_path='esm2_35m')
        elif self.pretrain_model_name == 'esm2_150m':
            self.pretrain_model = ESM2Model(device, model_path='esm2_150m')
        elif self.pretrain_model_name == 'esm2_3b':
            self.pretrain_model = ESM2Model(device, model_path='esm2_3b')
        elif self.pretrain_model_name == 'esm2_15b':
            self.pretrain_model = ESM2Model(device, model_path='esm2_15b')
        elif self.pretrain_model_name == 'esm3_1.4b':
            self.pretrain_model = ESM3Model(device)
        elif self.pretrain_model_name == 'esmc_600m':
            self.pretrain_model = ESMC600MModel(device)
        elif self.pretrain_model_name == 'procyon':
            self.pretrain_model = ProCyonModel(device)
        elif self.pretrain_model_name == 'prollama':
            self.pretrain_model = ProLLAMAModel(device)
        elif self.pretrain_model_name == 'progen2':
            self.pretrain_model = ProGen2Model(device)
        elif self.pretrain_model_name == 'prostt5':
            self.pretrain_model = ProstT5Model(device)
        elif self.pretrain_model_name == 'protgpt2':
            self.pretrain_model = ProtGPT2Model(device)
            # Keep position 0: no leading BOS embedding to strip for this model.
            self.start, self.end = 0, -1
        elif self.pretrain_model_name == 'protrek_35m':
            self.pretrain_model = ProTrekModel(device, model_path='protrek_35m')
        elif self.pretrain_model_name == 'protrek':
            self.pretrain_model = ProTrekModel(device)
        elif self.pretrain_model_name == 'saport':
            self.pretrain_model = SaPortModel(device)
        elif self.pretrain_model_name == 'saport_1.3b':
            self.pretrain_model = SaPortModel(device, model_path='saprot_1.3b')
        elif self.pretrain_model_name == 'saport_35m':
            self.pretrain_model = SaPortModel(device, model_path='saprot_35m')
        elif self.pretrain_model_name == 'gearnet':
            self.pretrain_model = GearNetModel(device)
        elif self.pretrain_model_name == 'prost':
            self.pretrain_model = ProSTModel(device)
        elif self.pretrain_model_name == 'prosst2048':
            self.pretrain_model = ProSST2048Model(device)
        elif self.pretrain_model_name == 'venusplm':
            self.pretrain_model = VenusPLMModel(device)
        elif self.pretrain_model_name == 'prott5':
            self.pretrain_model = ProtT5(device)
        elif self.pretrain_model_name == 'dplm':
            self.pretrain_model = DPLMModel(device)
        elif self.pretrain_model_name == 'dplm_150m':
            self.pretrain_model = DPLMModel(device, model_path='dplm_150m')
        elif self.pretrain_model_name == 'dplm_3b':
            self.pretrain_model = DPLMModel(device, model_path='dplm_3b')
        elif self.pretrain_model_name == 'ontoprotein':
            self.pretrain_model = OntoProteinModel(device)
        elif self.pretrain_model_name == "ankh_base":
            self.pretrain_model = ANKHBase(device)
        elif self.pretrain_model_name == "pglm":
            self.pretrain_model = PGLMModel(device)
        elif self.pretrain_model_name == "pglm-3b":
            self.pretrain_model = PGLMModel(device, model_path="proteinglm-3b-mlm")
| | elif self.pretrain_model_name == "pretrain_bio": |
| | self.pretrian_model = BIOModel(device,model_path="/nas/shared/kilab/wangyujia/ProtT3/all_checkpoints/stage2_06301657/last.ckpt/converted.ckpt") |

    def setup_peft(self, peft_type="lora", **kwargs):
        if getattr(self, "pretrain_model", None) is None:
            raise RuntimeError("Pre-trained model is not initialized; please initialize it first.")
        self.pretrain_model.setup_peft(
            peft_type=peft_type,
            **kwargs
        )
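    # Example call (hypothetical kwargs; the accepted options depend on each
    # pretrain module's own setup_peft implementation):
    #   interface.setup_peft(peft_type="lora", r=8, lora_alpha=16)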

    @torch.no_grad()
    def inference_datasets(self, data, task_name=None):
        self.pretrain_model.eval()
        processed_data = []
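        # Sample schema (inferred from usage below): each sample is a dict with
        # 'seq' (str), where '|' in 'seq' marks a protein pair; structure models
        # also read 'pdb_path' and 'X'; 'smiles' is optional. Pairs are embedded
        # per side and re-joined with a one-row PAD vector.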
        for i in tqdm(range(0, len(data), self.batch_size), desc='Extracting embeddings'):
            if "|" in data[0]['seq']:
                samples_A, samples_B = [], []
                for sample in data[i:i + self.batch_size]:
                    sample_A = {key: value for key, value in sample.items() if key != 'seq'}
                    sample_B = {key: value for key, value in sample.items() if key != 'seq'}
                    sample_A['seq'] = sample['seq'].split('|')[0]
                    sample_B['seq'] = sample['seq'].split('|')[1]
                    if 'pdb_path' in sample:
                        sample_A['pdb_path'] = sample['pdb_path'].split('|')[0]
                        sample_B['pdb_path'] = sample['pdb_path'].split('|')[1]
                        sample_A['X'] = sample['X'][0]
                        sample_B['X'] = sample['X'][1]
                    samples_A.append(sample_A)
                    samples_B.append(sample_B)

                batch_A = self.pretrain_model.construct_batch(samples_A)
                batch_B = self.pretrain_model.construct_batch(samples_B)
                results_A = self.pretrain_model(batch_A, task_type=self.task_type)
                results_B = self.pretrain_model(batch_B, task_type=self.task_type)
                results = []
                for idx in range(len(results_A)):
                    result = {}
                    for key in results_A[idx].keys():
                        if key == 'embedding':
                            # A single all-ones row separates the two chain embeddings.
                            PAD = torch.ones_like(results_A[idx]['embedding'])[:1]
                            result[key] = torch.cat([results_A[idx]['embedding'], PAD, results_B[idx]['embedding']], dim=0)
                        elif key == 'attention_mask':
                            SEP = torch.tensor([True], device=results_A[idx]['attention_mask'].device)
                            result[key] = torch.cat([results_A[idx]['attention_mask'], SEP, results_B[idx]['attention_mask']], dim=0) == 1
                        else:
                            result[key] = results_A[idx][key]
                    results.append(result)
            else:
                samples = data[i:i + self.batch_size]
                batch = self.pretrain_model.construct_batch(samples)
                results = self.pretrain_model(batch, task_type=self.task_type)

                if samples[0].get('smiles') is not None:
                    batch_smi = self.smiles_model.construct_batch(samples)
                    for idx in range(len(samples)):
                        results[idx]['smiles'] = batch_smi['smiles'][idx]

            processed_data.extend(results)

        return processed_data

    def forward(self, data):
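        # Mirrors the pair handling in inference_datasets: '|' in 'seq'
        # splits a sample into two chains that are embedded separately.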
| | if "|" in data[0]['seq']: |
| | samples_A, samples_B = [], [] |
| | for sample in data: |
| | sample_A = {key: value for key, value in sample.items() if key != 'seq'} |
| | sample_B = {key: value for key, value in sample.items() if key != 'seq'} |
| | sample_A['seq'] = sample['seq'].split('|')[0] |
| | sample_B['seq'] = sample['seq'].split('|')[1] |
| | if 'pdb_path' in sample: |
| | sample_A['pdb_path'] = sample['pdb_path'].split('|')[0] |
| | sample_B['pdb_path'] = sample['pdb_path'].split('|')[1] |
| | sample_A['X'] = sample['X'][0] |
| | sample_B['X'] = sample['X'][1] |
| | samples_A.append(sample_A) |
| | samples_B.append(sample_B) |

            batch_A = self.pretrain_model.construct_batch(samples_A)
            batch_B = self.pretrain_model.construct_batch(samples_B)
            embedding_A = self.pretrain_model(batch_A, task_type=self.task_type, post_process=False)[:, self.start:self.end, :]
            embedding_B = self.pretrain_model(batch_B, task_type=self.task_type, post_process=False)[:, self.start:self.end, :]
            bs, hidden_dim = embedding_A.shape[0], embedding_A.shape[-1]
            # A single all-ones vector per sample separates the two chains.
            PAD = torch.ones((bs, 1, hidden_dim), device=embedding_A.device, dtype=embedding_A.dtype)
            embedding = torch.cat([embedding_A, PAD, embedding_B], dim=1).contiguous()
            labels = torch.stack(batch_A["label"]).to(embedding.device).to(embedding.dtype)
            PAD_MASK = torch.ones((bs, 1), device=embedding_A.device, dtype=batch_A['attention_mask'].dtype)
            attention_mask = torch.cat(
                [batch_A['attention_mask'][:, self.start:self.end], PAD_MASK, batch_B['attention_mask'][:, self.start:self.end]], dim=1
            ) == 1
        else:
            batch = self.pretrain_model.construct_batch(data)
            embedding = self.pretrain_model(batch, task_type=self.task_type, post_process=False)[:, self.start:self.end, :]

            labels = torch.stack(batch["label"]).to(embedding.device).to(embedding.dtype)

            attention_mask = batch["attention_mask"][:, self.start:self.end] == 1

        batch_smi = None
        if data[0].get('smiles') is not None:
            batch_smi = self.smiles_model.construct_batch(data)
            batch_smi = torch.stack(batch_smi["smiles"]).contiguous().to(embedding.device).to(embedding.dtype)

        if self.task_type == "contact":
            # Trim contact-map labels to the post-slicing sequence length.
            true_length = attention_mask.shape[-1]
            labels = labels[:, :true_length, :true_length]
        elif self.task_type == "residual_classification":
            true_length = attention_mask.shape[-1]
            labels = labels[:, :true_length]
        if self.pretrain_model_name == 'prostt5':
            # ProstT5 batches carry two entries per sample, hence the stride-2 selection.
            attention_mask = attention_mask[::2]
        return embedding, labels, attention_mask, batch_smi
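

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the sample schema below is an
    # assumption based on how this interface reads samples; real inputs come
    # from the benchmark's own data loaders.
    data = [
        {"seq": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"},
        {"seq": "MALWMRLLPLLALLALWGPDPAAA"},
    ]
    interface = PretrainModelInterface("esm2_650m", batch_size=2)
    results = interface.inference_datasets(data)
    # Each result is a dict with at least 'embedding' and 'attention_mask'.
    print(len(results), results[0]["embedding"].shape)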