import os
import sys
import warnings

import numpy as np
import torch
import xgboost as xgb
from rdkit import Chem, rdBase, DataStructs
from transformers import AutoTokenizer, EsmModel

# Silence RDKit molecule-parsing errors and common library warning noise so
# batch scoring output stays readable.
rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
class Solubility:
    """Predict protein solubility scores for amino-acid sequences.

    Sequences are embedded with ESM-2 (mean-pooled last hidden states) and
    scored by a pre-trained XGBoost booster.
    """

    def __init__(
        self,
        model_file='/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_solubility.json',
        batch_size=8,
    ):
        """Load the XGBoost classifier and the ESM-2 encoder.

        Args:
            model_file: Path to the saved XGBoost booster (default keeps the
                original hard-coded checkpoint location).
            batch_size: Number of sequences embedded per forward pass.
        """
        self.predictor = xgb.Booster(model_file=model_file)
        self.batch_size = batch_size

        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.model.eval()
        # Move the encoder to the device once here, instead of re-assigning
        # self.model inside the embedding loop on every batch (original bug).
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)

    def generate_embeddings(self, sequences):
        """Generate mean-pooled ESM embeddings for protein sequences.

        Args:
            sequences: List of amino-acid sequence strings.

        Returns:
            np.ndarray of shape (len(sequences), hidden_dim), or an empty
            array when `sequences` is empty.
        """
        embeddings = []

        for start in range(0, len(sequences), self.batch_size):
            batch = sequences[start:start + self.batch_size]

            inputs = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)

            last_hidden_states = outputs.last_hidden_state

            # Mean-pool over real tokens only: zero out padding positions,
            # sum, then divide by the per-sequence token count. NOTE(review):
            # the count includes the special BOS/EOS tokens — same as the
            # original behavior.
            attention_mask = inputs['attention_mask'].unsqueeze(-1)
            summed = (last_hidden_states * attention_mask).sum(dim=1)
            lengths = attention_mask.sum(dim=1)
            batch_embeddings = summed / lengths

            embeddings.append(batch_embeddings.cpu().numpy())

        return np.vstack(embeddings) if embeddings else np.array([])

    def get_scores(self, input_seqs: list):
        """Return solubility scores for `input_seqs` (zeros if the list is empty)."""
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        # Sanitize embeddings before handing them to XGBoost: NaNs -> 0,
        # and clamp to the float32 range.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        scores = self.predictor.predict(xgb.DMatrix(features))
        return scores

    def __call__(self, input_seqs: list):
        """Alias for get_scores so the instance is directly callable."""
        return self.get_scores(input_seqs)
| | |
def unittest():
    """Smoke test: score two example peptide sequences and print the result."""
    example_peptides = [
        "GLSKGCFGLKLDRIGSMSGLGC",
        "RGLSDGFLKLKMGISGSLGC",
    ]
    scorer = Solubility()
    print(scorer(input_seqs=example_peptides))
| | |
# Run the smoke test only when executed as a script, not on import.
if __name__ == '__main__':
    unittest()