import re

import numpy as np
import onnxruntime
import torch
from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer

torch.set_num_threads(1)  # limit torch to a single intra-op thread

def flatten_list(nested_list):
    """Recursively flatten an arbitrarily nested list into a single flat list."""
    flat_list = []
    for element in nested_list:
        if isinstance(element, list):
            flat_list.extend(flatten_list(element))
        else:
            flat_list.append(element)
    return flat_list
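

# Wraps the exported ONNX affinity-prediction head: it takes concatenated
# protein + molecule embeddings and returns binding-affinity estimates.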
class PredictionModule:
    def __init__(self, model_path="models/affinity_predictor0734-seed2101.onnx"):
        self.session = onnxruntime.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name

        # Constants for undoing the target normalization applied during training.
        self.mean = 6.51286529169358
        self.scale = 1.5614094578916633

    def convert_to_affinity(self, normalized):
        """Map a normalized model output back to affinity values.

        For example, normalized = 0.0 yields neg_log10_affinity_M ~= 6.513,
        i.e. an affinity of roughly 0.307 uM.
        """
        return {
            "neg_log10_affinity_M": (normalized * self.scale) + self.mean,
            "affinity_uM": (10 ** 6) * (10 ** (-((normalized * self.scale) + self.mean))),
        }

    def predict(self, batch_data):
        """Run predictions on a batch of feature tensors."""
        # Stack the torch feature tensors into a single numpy array for ONNX Runtime.
        batch_data = np.array([t.numpy() for t in batch_data])

        affinities = []
        for feature in batch_data:
            # The exported graph also expects a 'TrainingMode' flag; keep it off for inference.
            affinity_normalized = self.session.run(
                None, {self.input_name: [feature], 'TrainingMode': np.array(False)}
            )[0][0][0]
            affinities.append(self.convert_to_affinity(affinity_normalized))

        return affinities
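

# Plapt pairs a ProtBERT protein encoder with a ChemBERTa molecule encoder and
# feeds their concatenated pooled embeddings to the PredictionModule above.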
class Plapt:
    def __init__(self, prediction_module_path="models/affinity_predictor0734-seed2101.onnx",
                 caching=True, device='cuda'):
        # Fall back to CPU when CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Protein encoder (ProtBERT); expects uppercase, space-separated residues.
        self.prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        self.prot_encoder = BertModel.from_pretrained("Rostlab/prot_bert").to(self.device)

        # Molecule encoder (ChemBERTa), operating directly on SMILES strings.
        self.mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.mol_encoder = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(self.device)

        # Optional cache of concatenated features, keyed by the stringified SMILES batch.
        self.caching = caching
        self.cache = {}

        self.prediction_module = PredictionModule(prediction_module_path)

    def set_prediction_module(self, prediction_module_path):
        self.prediction_module = PredictionModule(prediction_module_path)

    @staticmethod
    def preprocess_sequence(seq):
        # Replace ambiguous residues (U, Z, O, B) with X and space-separate the
        # amino acids, as the ProtBERT tokenizer expects.
        return " ".join(re.sub(r"[UZOB]", "X", seq))

    def tokenize(self, mol_smiles):
        mol_tokens = self.mol_tokenizer(mol_smiles,
                                        padding=True,
                                        max_length=278,
                                        truncation=True,
                                        return_tensors='pt')
        return mol_tokens

    def tokenize_prot(self, prot_seq):
        prot_tokens = self.prot_tokenizer(self.preprocess_sequence(prot_seq),
                                          padding=True,
                                          max_length=3200,
                                          truncation=True,
                                          return_tensors='pt')
        return prot_tokens

    @staticmethod
    def make_batches(iterable, n=1):
        """Yield successive chunks of at most n items from iterable."""
        length = len(iterable)
        for ndx in range(0, length, n):
            yield iterable[ndx:min(ndx + n, length)]

    def predict_affinity(self, prot_seq, mol_smiles, batch_size=2):
        input_strs = mol_smiles

        # Encode the protein once and reuse its representation for every molecule.
        prot_tokens = self.tokenize_prot(prot_seq)
        with torch.no_grad():
            prot_representations = self.prot_encoder(**prot_tokens.to(self.device)).pooler_output.cpu()
        prot_representations = prot_representations.squeeze(0)
        # Replicate the protein embedding per batch slot; zip() below truncates
        # to the actual batch length, so a short final batch is handled correctly.
        prot_representations = [prot_representations for i in range(batch_size)]

        affinities = []
        for batch in self.make_batches(input_strs, batch_size):
            batch_key = str(batch)

            if self.caching and batch_key in self.cache:
                features = self.cache[batch_key]
            else:
                mol_tokens = self.tokenize(batch)
                with torch.no_grad():
                    mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()
                mol_representations = [mol_representations[i, :] for i in range(mol_representations.shape[0])]

                # Concatenate protein and molecule embeddings to form the predictor input.
                features = [torch.cat((prot, mol), dim=0) for prot, mol in
                            zip(prot_representations, mol_representations)]

                if self.caching:
                    self.cache[batch_key] = features

            affinities.extend(self.prediction_module.predict(features))

        return affinities

    def score_candidates(self, target_protein, mol_smiles, batch_size=2):
        # Encode the target protein once.
        target_tokens = self.prot_tokenizer([self.preprocess_sequence(target_protein)],
                                            padding=True,
                                            max_length=3200,
                                            truncation=True,
                                            return_tensors='pt')
        with torch.no_grad():
            target_representation = self.prot_encoder(**target_tokens.to(self.device)).pooler_output.cpu()

        # Note: candidates are scored one molecule at a time; batch_size is currently unused.
        affinities = []
        for mol in mol_smiles:
            mol_tokens = self.mol_tokenizer(mol,
                                            padding=True,
                                            max_length=278,
                                            truncation=True,
                                            return_tensors='pt')
            with torch.no_grad():
                mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()

            features = torch.cat((target_representation[0], mol_representations[0]), dim=0)
            affinities.extend(self.prediction_module.predict([features]))

        return affinities

    def get_cached_features(self):
        """Return all cached feature tensors as plain Python lists."""
        return [tensor.tolist() for tensor in flatten_list(list(self.cache.values()))]

    def clear_cache(self):
        self.cache = {}
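

# Minimal usage sketch (illustrative, not part of the original module): the
# protein sequence and SMILES strings below are hypothetical placeholders,
# chosen only to show the expected call pattern and return shape.
if __name__ == "__main__":
    plapt = Plapt()  # assumes the default ONNX predictor path exists locally

    protein = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # hypothetical target sequence
    smiles = ["CCO", "CC(=O)Oc1ccccc1C(=O)O"]     # ethanol, aspirin

    # Each result is a dict with 'neg_log10_affinity_M' and 'affinity_uM'.
    for smi, res in zip(smiles, plapt.score_candidates(protein, smiles)):
        print(smi, res)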