File size: 2,578 Bytes
5e90249 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | import xgboost as xgb
import torch
import numpy as np
from transformers import AutoModelForMaskedLM
import warnings
import numpy as np
from rdkit import rdBase
# Import-time side effects: turn off RDKit's 'rdApp.error' log channel and
# ignore deprecation / user / future warnings for the whole process.
rdBase.DisableLog('rdApp.error')
for _ignored_category in (DeprecationWarning, UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
# Keep the module namespace free of the loop variable.
del _ignored_category
class Solubility:
    """Score peptide SMILES with an XGBoost solubility classifier on PeptideCLM embeddings.

    Each sequence is tokenized, encoded by a RoFormer encoder (by default the
    backbone of ``aaronfeller/PeptideCLM-23M-all``), mean-pooled over the token
    axis, and the pooled vectors are scored by a pre-trained XGBoost booster.
    """

    def __init__(self, tokenizer, base_path, device=None, emb_model=None, model_path=None):
        """Load the XGBoost booster and the embedding model.

        Args:
            tokenizer: Callable invoked as ``tokenizer(seq, return_tensors='pt')``;
                its returned mapping of tensors is passed as keyword arguments to
                the embedding model.
            base_path: Directory that contains the ``TR2-D2`` checkout. The booster
                is read from
                ``{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/solubility-xgboost.json``
                unless ``model_path`` is given.
            device: Torch device for the embedding model. Defaults to CUDA when
                available, otherwise CPU.
            emb_model: Optional preloaded embedding model. When omitted, the
                ``.roformer`` encoder of ``aaronfeller/PeptideCLM-23M-all`` is
                loaded with ``from_pretrained``.
            model_path: Optional explicit path to the booster JSON file. It overrides
                the path derived from ``base_path``.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        if model_path is None:
            model_path = f'{base_path}/TR2-D2/tr2d2-pep/scoring/functions/classifiers/solubility-xgboost.json'
        self.predictor = xgb.Booster(model_file=model_path)
        if emb_model is None:
            # Keep only the encoder; generate_embeddings reads its last_hidden_state.
            emb_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
        self.emb_model = emb_model.to(self.device).eval()
        self.tokenizer = tokenizer

    def generate_embeddings(self, sequences):
        """Return one mean-pooled encoder embedding per input sequence.

        Sequences are tokenized and encoded one at a time, so there is no padding.
        The mean therefore covers every token the tokenizer emits for that sequence.

        Args:
            sequences: Iterable of sequence strings accepted by ``self.tokenizer``.

        Returns:
            np.ndarray stacking the per-sequence vectors (presumably shape
            ``(n, hidden_dim)`` - assumes ``last_hidden_state`` is
            ``(1, seq_len, hidden_dim)``). An empty 1-D array is returned when
            ``sequences`` is empty.
        """
        embeddings = []
        # Inference only: disable autograd once for the whole loop.
        with torch.no_grad():
            for sequence in sequences:
                tokenized = self.tokenizer(sequence, return_tensors='pt')
                tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
                output = self.emb_model(**tokenized)
                # Mean pooling across the sequence-length axis (dim 1), then drop the batch axis.
                embedding = output.last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()
                embeddings.append(embedding)
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Predict solubility scores for a batch of sequences.

        Args:
            input_seqs: Sequence strings. Any iterable is accepted; it is
                materialized into a list first.

        Returns:
            np.ndarray of booster predictions, one per sequence. A zero-length
            float array is returned when ``input_seqs`` is empty.
        """
        input_seqs = list(input_seqs)
        features = self.generate_embeddings(input_seqs)
        if len(features) == 0:
            return np.zeros(len(input_seqs))
        # Keep XGBoost inputs finite: NaN -> 0, +/-inf -> large finite values,
        # then clamp everything into the float32 range.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        return self.predictor.predict(xgb.DMatrix(features))

    def __call__(self, input_seqs: list):
        """Score ``input_seqs``; equivalent to :meth:`get_scores`."""
        return self.get_scores(input_seqs)
def unittest(tokenizer=None, base_path=None):
    """Smoke-test ``Solubility`` on one hard-coded peptide SMILES and print its score.

    ``Solubility`` requires both a tokenizer and a base path, so they must be
    supplied here. If either is missing, a ``ValueError`` explains what to pass.

    Args:
        tokenizer: Tokenizer handed to ``Solubility``.
        base_path: Directory containing the ``TR2-D2`` checkout, handed to ``Solubility``.

    Raises:
        ValueError: If ``tokenizer`` or ``base_path`` is not provided.
    """
    if tokenizer is None or base_path is None:
        raise ValueError(
            "unittest() requires a tokenizer and the base_path that contains the TR2-D2 checkout"
        )
    solubility = Solubility(tokenizer=tokenizer, base_path=base_path)
    seq = ["NCC(=O)N[C@H](CS)C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)N[C@@H](c1ccc(cc1)F)C(=O)N[C@@H]([C@H](CC)C)C(=O)N[C@@H](CCCO)C(=O)N[C@@H](CC1=CN=C-N1)C(=O)N[C@@H](CCC(=O)O)C(=O)N[C@@H](CO)C(=O)O"]
    scores = solubility(input_seqs=seq)
    print(scores)


if __name__ == '__main__':
    # NOTE: without a tokenizer and base_path this raises ValueError; call
    # unittest(tokenizer, base_path) to actually run the smoke test.
    unittest()