import torch
from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel
import re
import onnxruntime
import numpy as np

# Restrict PyTorch to a single CPU thread
torch.set_num_threads(1)

def flatten_list(nested_list):
    """Recursively flatten an arbitrarily nested list into a flat list."""
    flat_list = []
    for element in nested_list:
        if isinstance(element, list):
            flat_list.extend(flatten_list(element))
        else:
            flat_list.append(element)
    return flat_list
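
# e.g. flatten_list([[1, [2]], 3]) -> [1, 2, 3]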

class PredictionModule:
    """Wraps the ONNX affinity predictor and denormalizes its output."""

    def __init__(self, model_path="models/affinity_predictor0734-seed2101.onnx"):
        self.session = onnxruntime.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name
        # Normalization scaling parameters
        self.mean = 6.51286529169358
        self.scale = 1.5614094578916633

    def convert_to_affinity(self, normalized):
        """Denormalize a model output into pK (-log10 M) and micromolar affinity."""
        neg_log10_affinity_M = (normalized * self.scale) + self.mean
        return {
            "neg_log10_affinity_M": neg_log10_affinity_M,
            "affinity_uM": (10**6) * (10**(-neg_log10_affinity_M)),
        }
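
    # Worked example from the constants above: a raw model output of 0.0
    # denormalizes to neg_log10_affinity_M = 6.5129, i.e. an affinity of
    # 10**6 * 10**(-6.5129) ≈ 0.307 uM.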

    def predict(self, batch_data):
        """Run predictions on a batch of concatenated feature tensors."""
        # Convert each torch tensor to a numpy array
        batch_data = np.array([t.numpy() for t in batch_data])
        # Process each feature in the batch individually and collect results
        affinities = []
        for feature in batch_data:
            # Run the model on the single feature, adding a leading batch dimension
            affinity_normalized = self.session.run(
                None,
                {self.input_name: np.expand_dims(feature, axis=0),
                 'TrainingMode': np.array(False)},
            )[0][0][0]
            # Append the denormalized result
            affinities.append(self.convert_to_affinity(affinity_normalized))
        return affinities
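
# Standalone usage of PredictionModule (a sketch; assumes the default ONNX file
# exists and that `features` is a list of 1-D torch tensors, each a protein
# embedding concatenated with a molecule embedding, as built by Plapt below):
#
#     pm = PredictionModule()
#     results = pm.predict(features)
#     # results -> [{"neg_log10_affinity_M": ..., "affinity_uM": ...}, ...]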

class Plapt:
    def __init__(self, prediction_module_path="models/affinity_predictor0734-seed2101.onnx",
                 caching=True, device='cuda'):
        # Set device for computation, falling back to CPU if CUDA is unavailable
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # Load protein tokenizer and encoder (ProtBERT)
        self.prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        self.prot_encoder = BertModel.from_pretrained("Rostlab/prot_bert").to(self.device)
        # Load molecule tokenizer and encoder (ChemBERTa)
        self.mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.mol_encoder = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(self.device)
        self.caching = caching
        self.cache = {}
        # Load the prediction module ONNX model
        self.prediction_module = PredictionModule(prediction_module_path)

    def set_prediction_module(self, prediction_module_path):
        self.prediction_module = PredictionModule(prediction_module_path)

    @staticmethod
    def preprocess_sequence(seq):
        # ProtBERT expects space-separated residues, with the rare amino acids
        # U, Z, O, B mapped to X
        return " ".join(re.sub(r"[UZOB]", "X", seq))

    def tokenize(self, mol_smiles):
        # Tokenize molecule SMILES strings
        mol_tokens = self.mol_tokenizer(mol_smiles,
                                        padding=True,
                                        max_length=278,
                                        truncation=True,
                                        return_tensors='pt')
        return mol_tokens

    def tokenize_prot(self, prot_seq):
        # Tokenize a (preprocessed) protein sequence
        prot_tokens = self.prot_tokenizer(self.preprocess_sequence(prot_seq),
                                          padding=True,
                                          max_length=3200,
                                          truncation=True,
                                          return_tensors='pt')
        return prot_tokens

    # Batch helper
    @staticmethod
    def make_batches(iterable, n=1):
        length = len(iterable)
        for ndx in range(0, length, n):
            yield iterable[ndx:min(ndx + n, length)]
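
    # e.g. list(Plapt.make_batches([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]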

    def predict_affinity(self, prot_seq, mol_smiles, batch_size=2):
        input_strs = mol_smiles
        # Encode the target protein once and reuse it for every molecule batch
        prot_tokens = self.tokenize_prot(prot_seq)
        with torch.no_grad():
            prot_representations = self.prot_encoder(**prot_tokens.to(self.device)).pooler_output.cpu()
        prot_representations = prot_representations.squeeze(0)
        # Repeat the protein representation so it zips against each molecule
        # representation in a batch (zip truncates on a final, shorter batch)
        prot_representations = [prot_representations for i in range(batch_size)]
        affinities = []
        for batch in self.make_batches(input_strs, batch_size):
            batch_key = str(batch)  # Convert batch to a string to use as a dictionary key
            if self.caching and batch_key in self.cache:
                # Use cached features if available. NOTE: the cache is keyed by
                # the molecule batch only, so it is only valid as long as the
                # same protein is used.
                features = self.cache[batch_key]
            else:
                # Tokenize and encode the batch, then cache the results
                mol_tokens = self.tokenize(batch)
                with torch.no_grad():
                    mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()
                mol_representations = [mol_representations[i, :] for i in range(mol_representations.shape[0])]
                features = [torch.cat((prot, mol), dim=0) for prot, mol in
                            zip(prot_representations, mol_representations)]
                if self.caching:
                    self.cache[batch_key] = features
            affinities.extend(self.prediction_module.predict(features))
        return affinities
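
    # Sketch of a direct call (placeholder inputs; assumes one protein sequence
    # and a list of SMILES strings):
    #
    #     plapt = Plapt(device="cpu")
    #     out = plapt.predict_affinity("MKTAYIAKQR", ["CCO", "c1ccccc1"])
    #     # out -> one {"neg_log10_affinity_M", "affinity_uM"} dict per molecule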

    def score_candidates(self, target_protein, mol_smiles, batch_size=2):
        # Encode the single target protein once
        target_tokens = self.prot_tokenizer([self.preprocess_sequence(target_protein)],
                                            padding=True,
                                            max_length=3200,
                                            truncation=True,
                                            return_tensors='pt')
        with torch.no_grad():
            target_representation = self.prot_encoder(**target_tokens.to(self.device)).pooler_output.cpu()
        affinities = []
        for mol in mol_smiles:
            # Encode each candidate molecule one at a time
            mol_tokens = self.mol_tokenizer(mol,
                                            padding=True,
                                            max_length=278,
                                            truncation=True,
                                            return_tensors='pt')
            with torch.no_grad():
                mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()
            # Concatenate protein and molecule embeddings into one feature vector
            features = torch.cat((target_representation[0], mol_representations[0]), dim=0)
            affinities.extend(self.prediction_module.predict([features]))
        return affinities

    def get_cached_features(self):
        # Return all cached feature tensors as plain Python lists
        return [tensor.tolist() for tensor in flatten_list(list(self.cache.values()))]

    def clear_cache(self):
        self.cache = {}
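
# A minimal end-to-end usage sketch (the protein fragment and SMILES below are
# illustrative placeholders, and the model path assumes the repository's
# default layout).
if __name__ == "__main__":
    plapt = Plapt()  # falls back to CPU automatically if CUDA is unavailable

    # Hypothetical target fragment scored against two well-known molecules
    protein = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
    smiles = [
        "CC(=O)Oc1ccccc1C(=O)O",          # aspirin
        "CN1C=NC2=C1C(=O)N(C)C(=O)N2C",   # caffeine
    ]

    for mol, result in zip(smiles, plapt.score_candidates(protein, smiles)):
        print(f"{mol}: {result['affinity_uM']:.3f} uM "
              f"(neg_log10_affinity_M = {result['neg_log10_affinity_M']:.2f})")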