| import os | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | |
| import torch | |
| import pytorch_lightning as pl | |
| from transformers import AutoTokenizer, AutoModelWithLMHead, EsmModel | |
| from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef, accuracy_score | |
| import yaml | |
| from easydict import EasyDict as edict | |
| from models import * | |
| import pdb | |
| import warnings | |
| warnings.filterwarnings("ignore", category=FutureWarning) | |
| from transformers.utils import logging | |
| logging.set_verbosity_error() | |
def parse_motif(motif: str) -> list:
    """Expand a comma-separated position specification into a list of ints.

    Each comma-separated entry is either a single integer (``"7"``) or an
    inclusive range written ``start-end`` (``"3-5"`` expands to 3, 4, 5).
    Whitespace around entries is ignored. Blank entries (an empty string, a
    trailing comma, ``"1,,2"``) are skipped rather than raising.

    Args:
        motif: Specification such as ``"1-3, 7, 10-12"``. A bare integer is
            also accepted, because a YAML scalar like ``45`` is loaded as an
            ``int`` rather than a ``str``.

    Returns:
        The positions in the order they appear; duplicates are kept. A range
        whose start exceeds its end contributes nothing.

    Raises:
        ValueError: If an entry is neither an integer nor a ``start-end``
            pair of integers.
    """
    positions = []
    # str() lets an int loaded from YAML go through the same parsing path.
    for token in str(motif).split(','):
        token = token.strip()
        if not token:
            # Blank entry (empty input or stray comma): nothing to add.
            continue
        if '-' in token:
            start, end = map(int, token.split('-'))
            positions.extend(range(start, end + 1))
        else:
            positions.append(int(token))
    return positions
class BindEvaluator(pl.LightningModule):
    """Scores every target position for binding against a binder molecule.

    The binder is embedded with a frozen ChemBERTa encoder and the target
    with a frozen ESM-2 encoder. Both embeddings pass through
    ``RepeatedModule``, then a final cross-attention in which the target
    attends to the binder, a feed-forward block, and a linear projection to
    a single logit per target position.
    """

    def __init__(self, cfg):
        """Build the encoders and heads from ``cfg.model`` hyperparameters.

        Args:
            cfg: Attribute-style config. Reads ``cfg.model.n_layers``,
                ``d_model``, ``d_hidden``, ``n_head``, ``d_k``, ``d_v``,
                ``d_inner`` and ``dropout``.
        """
        super().__init__()
        hp = cfg.model

        # Pretrained encoders are used as fixed feature extractors.
        esm = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model = esm.eval()
        for weight in self.esm_model.parameters():
            weight.requires_grad = False

        chemberta = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa_zinc250k_v2_40k")
        # Keep only the encoder backbone; the LM head is discarded.
        self.chemberta_model = chemberta.roberta.eval()
        for weight in self.chemberta_model.parameters():
            weight.requires_grad = False

        self.repeated_module = RepeatedModule(
            hp.n_layers,
            hp.d_model,
            hp.d_hidden,
            hp.n_head,
            hp.d_k,
            hp.d_v,
            hp.d_inner,
            dropout=hp.dropout,
        )
        self.final_attention_layer = MultiHeadAttentionSequence(
            hp.n_head,
            hp.d_model,
            hp.d_k,
            hp.d_v,
            dropout=hp.dropout,
        )
        self.final_ffn = FFN(hp.d_model, hp.d_inner, dropout=hp.dropout)
        self.output_projection_prot = nn.Linear(hp.d_model, 1)

        # Order per the original author: binding-site weight, then
        # non-binding-site weight. Not referenced elsewhere in this class.
        self.class_weights = torch.tensor([7.803861120886587, 0.5342284711965681])

        self.cfg = cfg
        self._total_steps = None

    def forward(self, binder_tokens, target_tokens):
        """Return one raw logit per target token position.

        Args:
            binder_tokens: Tokenizer output for the binder; unpacked into the
                ChemBERTa encoder and must contain ``"attention_mask"``.
            target_tokens: Tokenizer output for the target; unpacked into the
                ESM encoder and must contain ``"attention_mask"``.

        Returns:
            Output of the final linear layer, whose last dimension is 1
            (presumably ``[B, Lp, 1]`` — confirm against ``models``).
        """
        binder_mask = binder_tokens["attention_mask"]  # [B, Ls] per original author
        target_mask = target_tokens["attention_mask"]  # [B, Lp] per original author

        binder_emb = self.chemberta_model(**binder_tokens).last_hidden_state
        target_emb = self.esm_model(**target_tokens).last_hidden_state

        # The module also returns four attention lists, unused here.
        prot_enc, sequence_enc, _, _, _, _ = self.repeated_module(
            binder_emb,
            target_emb,
            peptide_mask=binder_mask,
            protein_mask=target_mask,
        )

        # Final cross-attention: target queries attend to binder keys/values.
        attended, _ = self.final_attention_layer(
            prot_enc,
            sequence_enc,
            sequence_enc,
            key_padding_mask=binder_mask,
            query_padding_mask=target_mask,
        )

        hidden = self.final_ffn(attended, padding_mask=target_mask)
        return self.output_projection_prot(hidden)
def compute_metrics(true_residues, predicted_residues, length):
    """Score predicted positions against true positions as binary labels.

    Both index collections are turned into 0/1 label lists of size
    ``length`` and compared with scikit-learn metrics.

    Args:
        true_residues: Iterable of list indices labelled positive.
        predicted_residues: Iterable of list indices predicted positive.
        length: Number of positions in each label list.

    Returns:
        Tuple ``(accuracy, f1, mcc)``.

    Raises:
        IndexError: If an index falls outside the ``length``-sized list.
    """

    def _to_labels(indices):
        # Start with all negatives, then flip the listed positions.
        labels = [0] * length
        for idx in indices:
            labels[idx] = 1
        return labels

    y_true = _to_labels(true_residues)
    y_pred = _to_labels(predicted_residues)

    return (
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred),
        matthews_corrcoef(y_true, y_pred),
    )
def main():
    """Predict binding sites for the binder/target pair given in ./config.yaml.

    Reads ``inference.ckpt``, ``inference.binder``, ``inference.target``,
    ``inference.threshold`` and ``inference.ground_truth`` from the config,
    loads the checkpointed ``BindEvaluator``, and prints the token positions
    whose sigmoid probability is at least the threshold. When a ground-truth
    motif is configured, also prints the mean probability over the motif
    positions and a specificity score computed from the above-threshold
    positions outside the motif.
    """
    CONFIG_PATH = './config.yaml'
    with open(CONFIG_PATH, 'r') as f:
        cfg = edict(yaml.safe_load(f))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = BindEvaluator.load_from_checkpoint(cfg.inference.ckpt, cfg=cfg, map_location=device)
    model.eval()

    binder_seq = cfg.inference.binder
    target_seq = cfg.inference.target

    smiles_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa_zinc250k_v2_40k")
    esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    binder = smiles_tokenizer(binder_seq, return_tensors='pt', padding=True).to(device)
    target = esm_tokenizer(target_seq, return_tensors='pt', padding=True).to(device)

    # Inference only: without no_grad the trainable heads would record an
    # autograd graph that is never used, costing memory for nothing.
    with torch.no_grad():
        prediction = model(binder, target).squeeze(-1)
        probs = torch.sigmoid(prediction)

    threshold = cfg.inference.threshold

    # Only the first row of the batch is reported. One vectorized nonzero
    # replaces a Python loop that synced with the device on every element.
    binding_site = (probs[0] >= threshold).nonzero(as_tuple=True)[0].tolist()
    print("Predicted Binding Sites: ", binding_site)

    if cfg.inference.ground_truth is not None:
        print("Ground Truth Binding Sites: ", cfg.inference.ground_truth)
        motifs = parse_motif(cfg.inference.ground_truth)
        motif_scores = probs[:, motifs].mean(dim=1)

        # Set membership keeps the complement computation linear.
        motif_positions = set(motifs)
        non_motif_idx = [i for i in range(probs.shape[1]) if i not in motif_positions]
        non_motif_probs = probs[:, non_motif_idx]

        # Above-threshold positions outside the motif count as false hits.
        mask = non_motif_probs >= threshold
        count = mask.sum(dim=-1)
        specificity = 1 - count / len(target_seq)

        print(f"Motif Score: {motif_scores}")
        print(f"Specificity Score: {specificity}")
# Run inference only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Xet Storage Details
- Size: 5.7 kB
- Xet hash: 8d9b9b0f02832b39ab0ee1d009ecddad28175520e13ce40bb487b576cb267755

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.