# Commit 295b1cd by Tong Chen: "add files"
import sys
import os
sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles')
import xgboost as xgb
import torch
import numpy as np
from transformers import AutoModelForMaskedLM
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
import warnings
import numpy as np
from rdkit import Chem, rdBase, DataStructs
rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
class Solubility:
    """Score SMILES strings with a pre-trained XGBoost solubility model.

    Each SMILES string is tokenized with ``SMILES_SPE_Tokenizer``, encoded by
    the ``.roformer`` encoder of a masked-LM checkpoint, mean-pooled over its
    tokens, and the pooled vector is scored by an XGBoost booster.

    Args:
        device: Torch device (e.g. ``'cuda:0'`` or ``'cpu'``) for the encoder
            and the tokenized inputs.
        model_path: Path of the XGBoost booster file.
        emb_model_name: Hugging Face model id (or local path) of the masked-LM
            checkpoint whose ``.roformer`` encoder produces the embeddings.
        vocab_path: Vocabulary file passed to ``SMILES_SPE_Tokenizer``.
        splits_path: Splits file passed to ``SMILES_SPE_Tokenizer``.
    """

    def __init__(
        self,
        device,
        model_path='/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/solubility-xgboost.json',
        emb_model_name='aaronfeller/PeptideCLM-23M-all',
        vocab_path='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_vocab.txt',
        splits_path='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_splits.txt',
    ):
        self.device = device
        self.predictor = xgb.Booster(model_file=model_path)
        # Keep only the encoder (.roformer); the masked-LM head is unused.
        # eval() is called explicitly so inference never runs in training mode.
        self.emb_model = (
            AutoModelForMaskedLM.from_pretrained(emb_model_name)
            .roformer.to(device)
            .eval()
        )
        self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)

    def generate_embeddings(self, sequences):
        """Return one mean-pooled encoder embedding per SMILES string.

        Sequences are tokenized and encoded one at a time, so no padding is
        introduced and each mean covers only that sequence's own tokens.

        Args:
            sequences: Iterable of SMILES strings.

        Returns:
            ``np.ndarray`` with one row per sequence, or an empty array of
            shape ``(0,)`` when ``sequences`` is empty.
        """
        embeddings = []
        with torch.no_grad():
            for sequence in sequences:
                tokenized = self.tokenizer(sequence, return_tensors='pt').to(self.device)
                output = self.emb_model(**tokenized)
                # Mean pooling across sequence length, then drop the batch-of-one axis.
                pooled = output.last_hidden_state.mean(dim=1).squeeze(0)
                embeddings.append(pooled.cpu().numpy())
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Predict one solubility score per SMILES string.

        Args:
            input_seqs: List of SMILES strings. A single ``str`` is also
                accepted and treated as a one-element list.

        Returns:
            ``np.ndarray`` of float32 scores aligned with ``input_seqs``;
            empty when ``input_seqs`` is empty.
        """
        if isinstance(input_seqs, str):
            # Iterating a bare str would embed each character as its own sequence.
            input_seqs = [input_seqs]
        features = self.generate_embeddings(input_seqs)
        if len(features) == 0:
            # float32 keeps the empty case consistent with XGBoost's predictions.
            return np.zeros(0, dtype=np.float32)
        # Sanitize non-finite activations before handing them to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
        return self.predictor.predict(xgb.DMatrix(features))

    def __call__(self, input_seqs: list):
        """Score ``input_seqs`` and return the scores as a CPU ``torch.Tensor``."""
        return torch.tensor(self.get_scores(input_seqs))
def unittest(device=None):
    """Smoke-test ``Solubility`` on one peptide SMILES string and print the score.

    Args:
        device: Torch device to run on. When omitted, ``'cuda:6'`` is used if
            that GPU exists (the original hard-coded choice), otherwise
            ``'cuda'`` if any GPU is available, otherwise ``'cpu'``.

    Returns:
        The score tensor produced by ``Solubility``.
    """
    if device is None:
        if torch.cuda.device_count() > 6:
            device = 'cuda:6'
        elif torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'
    solubility = Solubility(device=device)
    seq = ["N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H]([C@H](O)C)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](C)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)O"]
    scores = solubility(input_seqs=seq)
    print(scores)
    return scores
if __name__ == '__main__':
    # Run the smoke test only when this file is executed directly, not on import.
    unittest()