"""Nonfouling scorer: predicts nonfouling probability for peptide SMILES.

Pipeline: PeptideCLM (RoFormer) mean-pooled embeddings -> XGBoost classifier.
"""
import sys
import warnings

sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles')

import numpy as np
import torch
import xgboost as xgb
from rdkit import rdBase
from transformers import AutoModelForMaskedLM

from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


class Nonfouling:
    """Score peptide SMILES strings for nonfouling propensity.

    Embeds each sequence with a pretrained PeptideCLM RoFormer encoder
    (mean-pooled over tokens) and classifies the embedding with a
    pretrained XGBoost booster.
    """

    def __init__(
        self,
        device,
        model_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/nonfouling-xgboost.json',
        emb_model_name='aaronfeller/PeptideCLM-23M-all',
        vocab_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_vocab.txt',
        splits_file='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_splits.txt',
    ):
        """Load the XGBoost booster, embedding model, and tokenizer.

        Args:
            device: torch device string/object the embedding model runs on.
            model_file: path to the serialized XGBoost booster (default kept
                for backward compatibility with existing callers).
            emb_model_name: HuggingFace model id of the PeptideCLM encoder.
            vocab_file: SMILES-pair-encoding vocabulary file.
            splits_file: SMILES-pair-encoding merges/splits file.
        """
        self.predictor = xgb.Booster(model_file=model_file)
        # Only the RoFormer backbone is needed; the MLM head is discarded.
        self.emb_model = AutoModelForMaskedLM.from_pretrained(emb_model_name).roformer.to(device)
        self.tokenizer = SMILES_SPE_Tokenizer(vocab_file, splits_file)
        self.device = device

    def generate_embeddings(self, sequences):
        """Return a (len(sequences), hidden_dim) array of mean-pooled embeddings.

        Sequences are encoded one at a time (variable lengths, no padding
        needed); each embedding is the mean of the encoder's last hidden
        state over the token dimension.
        """
        embeddings = []
        for sequence in sequences:
            tokenized = self.tokenizer(sequence, return_tensors='pt').to(self.device)
            with torch.no_grad():
                output = self.emb_model(**tokenized)
            # Mean pooling across sequence length.
            embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Return an array of nonfouling probabilities, one per input sequence.

        Returns an all-zero array for empty input. NaN features are zeroed
        and values are clipped to the float32 range before prediction, since
        XGBoost rejects non-finite inputs.
        """
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)
        if len(features) == 0:
            return scores
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        features = xgb.DMatrix(features)
        # Probability of the peptide being nonfouling.
        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        """Score sequences and return the result as a torch tensor."""
        scores = self.get_scores(input_seqs)
        return torch.tensor(scores)


def unittest():
    """Smoke test: score one hard-coded peptide SMILES on cuda:6."""
    nf = Nonfouling(device='cuda:6')
    seq = ["N[C@@H](CC(=O)O)-N[C@@H](C[C@@H](C))C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@H](CO)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CO)C(=O)N[C@H](CO)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CCN)C(=O)N[C@@H](CCCC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CCC(=O)N)C(=O)O"]
    scores = nf(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    unittest()