File size: 7,066 Bytes
04579ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import torch
from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel
import re
import onnxruntime
import numpy as np
torch.set_num_threads(1)
def flatten_list(nested_list):
    """Flatten an arbitrarily nested list into a single flat list.

    Non-list elements are emitted in left-to-right traversal order;
    nested lists are expanded recursively.
    """
    flat = []
    for item in nested_list:
        if isinstance(item, list):
            flat += flatten_list(item)
        else:
            flat += [item]
    return flat

class PredictionModule:
    """ONNX affinity-regression head plus target de-normalization.

    Loads a pre-trained ONNX model that maps a concatenated
    protein+molecule feature vector to a normalized affinity score,
    and converts that score back to physical units.
    """

    def __init__(self, model_path="models/affinity_predictor0734-seed2101.onnx"):
        self.session = onnxruntime.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name

        # Normalization parameters of the training targets (mean / scale
        # used to standardize -log10 affinities during training).
        self.mean = 6.51286529169358
        self.scale = 1.5614094578916633

    def convert_to_affinity(self, normalized):
        """De-normalize a model output into -log10(M) and micromolar units."""
        neg_log10 = (normalized * self.scale) + self.mean
        return {
            "neg_log10_affinity_M": neg_log10,
            "affinity_uM": (10**6) * (10**(-neg_log10)),
        }

    def predict(self, batch_data):
        """Run predictions on a batch of data."""
        # Stack the torch feature tensors into one numpy array up front.
        stacked = np.array([tensor.numpy() for tensor in batch_data])

        results = []
        for single_feature in stacked:
            # The ONNX graph expects a batch dimension plus a TrainingMode flag.
            feed = {self.input_name: [single_feature], 'TrainingMode': np.array(False)}
            normalized = self.session.run(None, feed)[0][0][0]
            results.append(self.convert_to_affinity(normalized))

        return results

class Plapt:
    """Protein-ligand binding affinity predictor.

    Encodes proteins with ProtBERT and molecules with ChemBERTa, concatenates
    the pooled representations of each (protein, molecule) pair, and scores
    the pair with an ONNX regression head (``PredictionModule``).
    """

    def __init__(self, prediction_module_path = "models/affinity_predictor0734-seed2101.onnx", caching=True, device='cuda'):
        # Fall back to CPU when CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Protein tokenizer and encoder (ProtBERT expects space-separated,
        # case-sensitive residue tokens).
        self.prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        self.prot_encoder = BertModel.from_pretrained("Rostlab/prot_bert").to(self.device)

        # Molecule tokenizer and encoder (ChemBERTa over SMILES strings).
        self.mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.mol_encoder = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(self.device)

        self.caching = caching
        # Maps str((prot_seq, mol_batch)) -> list of concatenated feature tensors.
        self.cache = {}

        # Load the prediction module ONNX model.
        self.prediction_module = PredictionModule(prediction_module_path)

    def set_prediction_module(self, prediction_module_path):
        """Swap in a different ONNX prediction head."""
        self.prediction_module = PredictionModule(prediction_module_path)

    @staticmethod
    def preprocess_sequence(seq):
        """ProtBERT preprocessing: map rare residues U/Z/O/B to X and
        space-separate every character."""
        return " ".join(re.sub(r"[UZOB]", "X", seq))

    def tokenize(self, mol_smiles):
        """Tokenize one SMILES string or a batch of SMILES strings."""
        mol_tokens = self.mol_tokenizer(mol_smiles,
                                        padding=True,
                                        max_length=278,
                                        truncation=True,
                                        return_tensors='pt')
        return mol_tokens

    def tokenize_prot(self, prot_seq):
        """Preprocess and tokenize a protein sequence."""
        prot_tokens = self.prot_tokenizer(self.preprocess_sequence(prot_seq),
                                          padding=True,
                                          max_length=3200,
                                          truncation=True,
                                          return_tensors='pt')

        return prot_tokens

    @staticmethod
    def make_batches(iterable, n=1):
        """Yield successive slices of *iterable* of length at most *n*."""
        length = len(iterable)
        for ndx in range(0, length, n):
            yield iterable[ndx:min(ndx + n, length)]

    def predict_affinity(self, prot_seq, mol_smiles, batch_size=2):
        """Predict binding affinities of one protein against many molecules.

        Args:
            prot_seq: amino-acid sequence of the target protein.
            mol_smiles: list of SMILES strings.
            batch_size: number of molecules encoded per forward pass.

        Returns:
            A list of dicts (one per molecule, in input order) with keys
            "neg_log10_affinity_M" and "affinity_uM".
        """
        # Encode the protein once and reuse its pooled representation.
        prot_tokens = self.tokenize_prot(prot_seq)
        with torch.no_grad():
            prot_representation = self.prot_encoder(**prot_tokens.to(self.device)).pooler_output.cpu().squeeze(0)

        affinities = []
        for batch in self.make_batches(mol_smiles, batch_size):
            # BUG FIX: the cache key must include the protein, not just the
            # molecule batch. The cached features embed the protein
            # representation, so keying on molecules alone returned stale
            # (wrong-protein) features when the same molecules were later
            # scored against a different protein.
            batch_key = str((prot_seq, batch))

            if self.caching and batch_key in self.cache:
                # Use cached features if available.
                features = self.cache[batch_key]
            else:
                # Tokenize and encode the molecule batch.
                mol_tokens = self.tokenize(batch)
                with torch.no_grad():
                    mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()

                # One concatenated (protein ++ molecule) feature vector per molecule.
                features = [torch.cat((prot_representation, mol_representations[i, :]), dim=0)
                            for i in range(mol_representations.shape[0])]

                if self.caching:
                    self.cache[batch_key] = features

            affinities.extend(self.prediction_module.predict(features))

        return affinities

    def score_candidates(self, target_protein, mol_smiles, batch_size=2):
        """Score candidate molecules against a single target protein.

        Molecules are encoded one at a time and results are not cached.
        ``batch_size`` is currently unused and kept for interface
        compatibility.

        Returns the same list-of-dicts format as ``predict_affinity``.
        """
        # Reuse the shared tokenization helpers instead of duplicating the
        # tokenizer configuration inline.
        target_tokens = self.tokenize_prot(target_protein)
        with torch.no_grad():
            target_representation = self.prot_encoder(**target_tokens.to(self.device)).pooler_output.cpu()

        affinities = []
        for mol in mol_smiles:
            mol_tokens = self.tokenize(mol)
            with torch.no_grad():
                mol_representation = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()

            features = torch.cat((target_representation[0], mol_representation[0]), dim=0)
            affinities.extend(self.prediction_module.predict([features]))

        return affinities

    def get_cached_features(self):
        """Return every cached feature tensor as a plain nested list."""
        return [tensor.tolist() for tensor in flatten_list(list(self.cache.values()))]

    def clear_cache(self):
        """Discard all cached features."""
        self.cache = {}