import warnings

import numpy as np
import torch
import xgboost as xgb
from rdkit import rdBase
from transformers import AutoTokenizer, EsmModel

# Silence RDKit parse errors and noisy deprecation/user/future warnings
rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


class Solubility:
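    """Solubility scorer: mean-pooled ESM-2 embeddings of each input sequence
    are fed to a pretrained XGBoost classifier, whose predictions are returned
    as per-sequence scores."""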
    def __init__(self):
        # XGBoost model trained to predict solubility from ESM-2 embeddings
        self.predictor = xgb.Booster(model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json')

        # ESM-2 (650M) feature extractor; place it on the target device once
        # here rather than re-moving it on every batch during embedding
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.device)
        self.model.eval()

    def generate_embeddings(self, sequences):
        """Generate mean-pooled ESM-2 embeddings for a list of protein sequences."""
        embeddings = []

        # Embed in small batches to bound memory use
        batch_size = 8
        for i in range(0, len(sequences), batch_size):
            batch_sequences = sequences[i:i + batch_size]

            # Pad to a rectangular batch; truncate sequences beyond the model limit
            inputs = self.tokenizer(
                batch_sequences,
                padding=True,
                truncation=True,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Inference only: no gradients needed
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Per-token representations: (batch, seq_len, hidden_dim)
            last_hidden_states = outputs.last_hidden_state

            # Mean-pool over real tokens only: zero out padding positions, then
            # divide the summed hidden states by each sequence's token count
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            masked_hidden_states = last_hidden_states * attention_mask
            sum_hidden_states = masked_hidden_states.sum(dim=1)
            seq_lengths = attention_mask.sum(dim=1)
            batch_embeddings = sum_hidden_states / seq_lengths

            embeddings.append(batch_embeddings.cpu().numpy())

        # Stack per-batch arrays into one (n_sequences, hidden_dim) matrix
        if embeddings:
            return np.vstack(embeddings)
        return np.array([])
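    # ESM-2 t33 650M hidden states are 1280-dimensional, so each sequence is
    # summarized as a single 1280-d feature vector for the classifier below.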

    def get_scores(self, input_seqs: list):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Replace NaNs and clamp to the float32 range before scoring
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return scores

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
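    # Note on get_scores: Booster.predict returns probabilities when the model
    # was trained with a binary:logistic objective; the objective is not
    # recorded in this file, so the exact output range is checkpoint-dependent.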


def unittest():
    solubility = Solubility()
    sequences = [
        "GLSKGCFGLKLDRIGSMSGLGC",
        "RGLSDGFLKLKMGISGSLGC"
    ]

    scores = solubility(input_seqs=sequences)
    print(scores)


if __name__ == '__main__':
    unittest()
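# Minimal import-side usage sketch (the module name `solubility` below is
# illustrative, not part of the original file):
#
#     from solubility import Solubility
#
#     scorer = Solubility()
#     scores = scorer(["GLSKGCFGLKLDRIGSMSGLGC"])  # one float per sequence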